/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.data;

import java.io.Serializable;
import java.util.ArrayList;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlockFactory;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBlock;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Wrapper around an array of Spark broadcast variables, each holding a
 * {@link PartitionedBlock}, together with the {@link DataCharacteristics}
 * describing the broadcast data.
 *
 * <p>When more than one broadcast is present, block lookups first compute
 * which broadcast partition holds the requested block (based on
 * {@link #computeBlocksPerPartition}) and then delegate to that partition.
 * Block indexes passed to {@link #getBlock(int, int)} are 1-based.
 *
 * @param <T> the cache block type stored in the partitioned blocks
 */
public class PartitionedBroadcast<T extends CacheBlock>
implements Serializable {
    private static final long serialVersionUID = 7041959166079438401L;

    /** Budget used to derive the number of blocks per broadcast partition (240 * 1024 * 1024). */
    protected static final long BROADCAST_PARTSIZE = 240L * 1024 * 1024;

    private Broadcast<PartitionedBlock<T>>[] _pbc = null;
    private DataCharacteristics _dc;

    /**
     * Creates an empty instance without broadcasts or data characteristics.
     * Accessors that dereference these fields must not be called on such an
     * instance; {@link #destroy()} is safe.
     */
    public PartitionedBroadcast() {
    }

    /**
     * Creates a partitioned broadcast over the given broadcast variables.
     *
     * @param broadcasts the broadcast partitions, in partition order
     * @param dc         the data characteristics of the broadcast data
     */
    public PartitionedBroadcast(Broadcast<PartitionedBlock<T>>[] broadcasts, DataCharacteristics dc) {
        this._pbc = broadcasts;
        this._dc = dc;
    }

    /**
     * @return the underlying broadcast array (not copied)
     */
    public Broadcast<PartitionedBlock<T>>[] getBroadcasts() {
        return this._pbc;
    }

    /**
     * @return the number of rows as reported by the data characteristics
     */
    public long getNumRows() {
        return this._dc.getRows();
    }

    /**
     * @return the number of columns as reported by the data characteristics
     */
    public long getNumCols() {
        return this._dc.getCols();
    }

    /**
     * @return the number of row blocks, narrowed to {@code int}
     */
    public int getNumRowBlocks() {
        return (int)this._dc.getNumRowBlocks();
    }

    /**
     * @return the number of column blocks, narrowed to {@code int}
     */
    public int getNumColumnBlocks() {
        return (int)this._dc.getNumColBlocks();
    }

    /**
     * @return the data characteristics passed at construction
     */
    public DataCharacteristics getDataCharacteristics() {
        return this._dc;
    }

    /**
     * Computes how many blocks go into one broadcast partition by dividing
     * {@link #BROADCAST_PARTSIZE} by the effective block height and width,
     * where each is the smaller of the dimension and the block length.
     *
     * @param rlen number of rows
     * @param clen number of columns
     * @param blen block length
     * @return the number of blocks per partition, narrowed to {@code int}
     * @throws ArithmeticException if an effective block dimension is zero
     */
    public static int computeBlocksPerPartition(long rlen, long clen, long blen) {
        return (int)(BROADCAST_PARTSIZE / Math.min(rlen, blen) / Math.min(clen, blen));
    }

    /**
     * N-dimensional variant of {@link #computeBlocksPerPartition(long, long, long)}:
     * divides {@link #BROADCAST_PARTSIZE} successively by the effective block
     * extent of every dimension.
     *
     * @param dims the extent of each dimension
     * @param blen block length
     * @return the number of blocks per partition, narrowed to {@code int}
     * @throws ArithmeticException if an effective block extent is zero
     */
    public static int computeBlocksPerPartition(long[] dims, int blen) {
        long blocksPerPartition = BROADCAST_PARTSIZE;
        for (long dim : dims) {
            blocksPerPartition /= Math.min(dim, (long)blen);
        }
        return (int)blocksPerPartition;
    }

    /**
     * Returns the block at the given 1-based block indexes.
     *
     * @param rowIndex 1-based row block index
     * @param colIndex 1-based column block index
     * @return the block obtained from the partition that holds it
     */
    public T getBlock(int rowIndex, int colIndex) {
        int pix = 0;
        if (this._pbc.length > 1) {
            int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(this._dc.getRows(), this._dc.getCols(), this._dc.getBlocksize());
            // compute the row-major linear block index in long so that the
            // multiplication cannot wrap around for very large block grids
            long ix = (long)(rowIndex - 1) * this.getNumColumnBlocks() + (colIndex - 1);
            pix = (int)(ix / numPerPart);
        }
        return this._pbc[pix].value().getBlock(rowIndex, colIndex);
    }

    /**
     * Returns the block at the given n-dimensional block index.
     *
     * @param ix block index per dimension, as expected by
     *           {@code UtilFunctions.computeBlockNumber}
     * @return the block obtained from the partition that holds it
     */
    public T getBlock(int[] ix) {
        int pix = 0;
        if (this._pbc.length > 1) {
            long[] dims = this._dc.getDims();
            int blen = this._dc.getBlocksize();
            int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(dims, blen);
            pix = (int)(UtilFunctions.computeBlockNumber(ix, dims, blen) / (long)numPerPart);
        }
        return this._pbc[pix].value().getBlock(ix);
    }

    /**
     * Slices the cell range {@code [rl, ru] x [cl, cu]} out of the broadcast
     * data. The range is mapped to 1-based block indexes (both bounds are
     * included when determining the touched blocks), every touched block is
     * sliced via {@code OperationsOnMatrixValues.performSlice}, and the
     * resulting pieces are merged into the first piece.
     *
     * @param rl    lower row bound (1-based)
     * @param ru    upper row bound
     * @param cl    lower column bound (1-based)
     * @param cu    upper column bound
     * @param block only passed to {@code CacheBlockFactory.getPairList};
     *              presumably selects the pair list type - confirm against the factory
     * @return the merged slice
     * @throws IndexOutOfBoundsException if no slice piece was produced
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public T slice(long rl, long ru, long cl, long cu, T block) {
        int lrl = (int)rl;
        int lru = (int)ru;
        int lcl = (int)cl;
        int lcu = (int)cu;
        int blen = this._dc.getBlocksize();

        // typed view on the factory-provided list so that slices can be appended
        ArrayList<Pair<?, ?>> allBlks = (ArrayList<Pair<?, ?>>)CacheBlockFactory.getPairList(block);

        // 1-based block index ranges covered by the requested cell range
        int start_iix = (lrl - 1) / blen + 1;
        int end_iix = (lru - 1) / blen + 1;
        int start_jix = (lcl - 1) / blen + 1;
        int end_jix = (lcu - 1) / blen + 1;

        for (int iix = start_iix; iix <= end_iix; ++iix) {
            for (int jix = start_jix; jix <= end_jix; ++jix) {
                // fresh range per call, as performSlice is not known to leave it untouched
                IndexRange ixrange = new IndexRange(rl, ru, cl, cu);
                allBlks.addAll(OperationsOnMatrixValues.performSlice(ixrange, blen, iix, jix, this.getBlock(iix, jix)));
            }
        }

        // merge all partial slices into the first one
        CacheBlock ret = (CacheBlock)allBlks.get(0).getValue();
        for (int i = 1; i < allBlks.size(); ++i) {
            ret = (CacheBlock)ret.merge((CacheBlock)allBlks.get(i).getValue(), false);
        }
        return (T)ret;
    }

    /**
     * Hands every broadcast partition to
     * {@code SparkExecutionContext.cleanupBroadcastVariable}. Does nothing if
     * this instance holds no broadcast array.
     */
    public void destroy() {
        if (this._pbc == null) {
            return;
        }
        for (Broadcast<PartitionedBlock<T>> bvar : this._pbc) {
            SparkExecutionContext.cleanupBroadcastVariable(bvar);
        }
    }
}

