/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Right matrix multiplication {@code m1 %*% m2} where the left-hand side is a compressed matrix.
 *
 * <p>Each column group of {@code m1} is multiplied with {@code m2}, producing one output column group per
 * input group. The result is first built as a (possibly overlapping) {@link CompressedMatrixBlock}. It is then
 * decompressed when overlapping output is not allowed or when the compressed form is larger than a dense block.
 */
public class CLALibRightMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

    /**
     * Right-multiplies a compressed matrix with {@code m2}. The config flag {@code sysds.compressed.overlapping}
     * decides whether an overlapping compressed result may be returned.
     *
     * @param m1  compressed left-hand side
     * @param m2  right-hand side; decompressed first if it is itself compressed
     * @param ret optional output block; only reused when {@code m2} is empty, otherwise replaced
     * @param k   degree of parallelism ({@code 1} runs single-threaded)
     * @return the product, with its non-zero count recomputed
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        final boolean allowOverlap = ConfigurationManager.getDMLConfig().getBooleanValue("sysds.compressed.overlapping");
        return rightMultByMatrix(m1, m2, ret, k, allowOverlap);
    }

    /**
     * Right-multiplies a compressed matrix with {@code m2}.
     *
     * @param m1           compressed left-hand side
     * @param m2           right-hand side; decompressed first if it is itself compressed
     * @param ret          optional output block; only reused when {@code m2} is empty, otherwise replaced
     * @param k            degree of parallelism ({@code 1} runs single-threaded)
     * @param allowOverlap whether an overlapping compressed result may be returned
     * @return the product, with its non-zero count recomputed
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k,
        boolean allowOverlap) {
        if (m2.isEmpty()) {
            LOG.trace("Empty right multiply");
            if (ret == null) {
                ret = new MatrixBlock(m1.getNumRows(), m2.getNumColumns(), 0L);
            } else {
                ret.reset(m1.getNumRows(), m2.getNumColumns(), 0L);
            }
        } else {
            if (m2 instanceof CompressedMatrixBlock) {
                m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM");
            }
            ret = rightMultByMatrixOverlapping(m1, m2, k);
            if (ret instanceof CompressedMatrixBlock) {
                ret = decompressIfNeeded((CompressedMatrixBlock) ret, allowOverlap);
            }
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Decides whether to keep the compressed product.
     *
     * @return {@code cRet} decompressed if overlap is disallowed or its in-memory size exceeds the dense size
     *         estimate; otherwise {@code cRet} itself
     */
    private static MatrixBlock decompressIfNeeded(CompressedMatrixBlock cRet, boolean allowOverlap) {
        if (!allowOverlap) {
            return cRet.getUncompressed("Overlapping not allowed");
        }
        final double compressedSize = cRet.getInMemorySize();
        final double uncompressedSize = MatrixBlock.estimateSizeDenseInMemory(cRet.getNumRows(), cRet.getNumColumns());
        if (compressedSize > uncompressedSize) {
            return cRet.getUncompressed(
                "Overlapping rep to big: " + compressedSize + " vs Uncompressed " + uncompressedSize);
        }
        return cRet;
    }

    /**
     * Multiplies every column group of {@code m1} with {@code that} and collects the resulting groups into a
     * compressed block. It is marked overlapping when more than one non-empty result group exists.
     */
    private static MatrixBlock rightMultByMatrixOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();
        final int rr = that.getNumRows();
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> retCg = new ArrayList<>();
        final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, cr);

        // Constant parts of SDC/const groups are pulled out into constV by filterGroups (when it returns a new list)
        // and multiplied once as a single const group below.
        final boolean containsSDC = CLALibUtils.containsSDCOrConst(colGroups);
        double[] constV = containsSDC ? new double[rr] : null;
        final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        if (colGroups == filteredGroups) {
            constV = null;
        }

        final boolean containsNull = k == 1
            ? rightMultByMatrixOverlappingSingleThread(filteredGroups, that, retCg)
            : rightMultByMatrixOverlappingMultiThread(filteredGroups, that, retCg, k);

        if (constV != null) {
            final AColGroup cRet = ColGroupFactory.genColGroupConst(rr, constV).rightMultByMatrix(that);
            if (cRet != null) {
                retCg.add(cRet);
            }
        }

        // Decide overlap before appending the empty group: the empty group only covers columns no other group
        // covers, so it never makes the representation overlapping.
        final boolean overlapping = retCg.size() > 1;
        // Complete the list before handing it to the block, instead of mutating it afterwards.
        addEmptyColumn(retCg, cr, rl, containsNull);
        ret.allocateColGroupList(retCg);
        if (overlapping) {
            ret.setOverlapping(true);
        }
        return ret;
    }

    /**
     * Sequentially multiplies each group with {@code that}, appending non-null results to {@code retCg}.
     *
     * @return true if any group produced a {@code null} result
     */
    private static boolean rightMultByMatrixOverlappingSingleThread(List<AColGroup> filteredGroups, MatrixBlock that,
        List<AColGroup> retCg) {
        boolean containsNull = false;
        for (AColGroup g : filteredGroups) {
            final AColGroup retG = g.rightMultByMatrix(that);
            if (retG != null) {
                retCg.add(retG);
            } else {
                containsNull = true;
            }
        }
        return containsNull;
    }

    /**
     * Multiplies each group with {@code that} in parallel, appending non-null results to {@code retCg} in the
     * order of {@code filteredGroups}. The executor is always shut down before returning.
     *
     * @return true if any group produced a {@code null} result
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static boolean rightMultByMatrixOverlappingMultiThread(List<AColGroup> filteredGroups, MatrixBlock that,
        List<AColGroup> retCg, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        boolean containsNull = false;
        try {
            final List<Callable<AColGroup>> tasks = new ArrayList<>(filteredGroups.size());
            for (AColGroup g : filteredGroups) {
                tasks.add(new RightMatrixMultTask(g, that));
            }
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                final AColGroup g = future.get();
                if (g != null) {
                    retCg.add(g);
                } else {
                    containsNull = true;
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            // Release the executor's threads; previously the pool was never shut down.
            pool.shutdown();
        }
        return containsNull;
    }

    /** If some group produced no result, appends an empty group covering all output columns left uncovered. */
    private static void addEmptyColumn(List<AColGroup> retCg, int cr, int rl, boolean containsNull) {
        if (!containsNull) {
            return;
        }
        final ColGroupEmpty cge = findEmptyColumnsAndMakeEmptyColGroup(retCg, cr, rl);
        if (cge != null) {
            retCg.add(cge);
        }
    }

    /**
     * Builds an empty column group over every column in {@code [0, nCols)} that none of {@code colGroups} covers.
     * The column indices are produced in ascending order. (The previous HashSet iteration was only ascending for
     * fewer than 65536 columns.)
     *
     * @return the empty group, or {@code null} if every column is covered
     */
    private static ColGroupEmpty findEmptyColumnsAndMakeEmptyColGroup(List<AColGroup> colGroups, int nCols, int nRows) {
        final boolean[] covered = new boolean[nCols];
        int nCovered = 0;
        for (AColGroup g : colGroups) {
            for (int c : g.getColIndices()) {
                if (!covered[c]) {
                    covered[c] = true;
                    nCovered++;
                }
            }
        }
        if (nCovered == nCols) {
            return null;
        }
        final int[] emptyColumns = new int[nCols - nCovered];
        int pos = 0;
        for (int c = 0; c < nCols; c++) {
            if (!covered[c]) {
                emptyColumns[pos++] = c;
            }
        }
        return new ColGroupEmpty(emptyColumns);
    }

    /** Task that right-multiplies a single column group with a shared right-hand side matrix. */
    private static class RightMatrixMultTask implements Callable<AColGroup> {
        private final AColGroup _colGroup;
        private final MatrixBlock _b;

        protected RightMatrixMultTask(AColGroup colGroup, MatrixBlock b) {
            _colGroup = colGroup;
            _b = b;
        }

        @Override
        public AColGroup call() {
            try {
                return _colGroup.rightMultByMatrix(_b);
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }
}

