/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.RDDInfo;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.SingletonLookupHashMap;
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;

public class CompressionSPInstruction
extends UnarySPInstruction {
    private static final Log LOG = LogFactory.getLog((String)CompressionSPInstruction.class.getName());
    private final int _singletonLookupID;

    private CompressionSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr, int singletonLookupID) {
        super(SPInstruction.SPType.Compression, op, in, out, opcode, istr);
        this._singletonLookupID = singletonLookupID;
    }

    public static CompressionSPInstruction parseInstruction(String str) {
        InstructionUtils.checkNumFields(str, 2, 3);
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        if (parts.length == 4) {
            int treeNodeID = Integer.parseInt(parts[3]);
            return new CompressionSPInstruction(null, in1, out, opcode, str, treeNodeID);
        }
        return new CompressionSPInstruction(null, in1, out, opcode, str, 0);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        Object mappingFunction;
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        if (this._singletonLookupID == 0) {
            mappingFunction = new CompressionFunction();
        } else {
            WTreeRoot root = (WTreeRoot)SingletonLookupHashMap.getMap().get(this._singletonLookupID);
            CostEstimatorBuilder costBuilder = new CostEstimatorBuilder(root);
            mappingFunction = new CompressionWorkloadFunction(costBuilder);
        }
        JavaPairRDD out = in.mapValues((Function)mappingFunction);
        if (LOG.isTraceEnabled()) {
            in.persist(StorageLevel.MEMORY_AND_DISK());
            out.persist(StorageLevel.MEMORY_AND_DISK());
            long sparkSizeIn = 0L;
            long sparkSizeOut = 0L;
            long blockSizesIn = CompressionSPInstruction.reduceSizes(in.mapValues((Function)new SizeFunction()).collect());
            long blockSizesOut = CompressionSPInstruction.reduceSizes(out.mapValues((Function)new SizeFunction()).collect());
            for (RDDInfo info : sec.getSparkContext().sc().getRDDStorageInfo()) {
                if (info.id() == out.id()) {
                    sparkSizeOut = info.memSize();
                    continue;
                }
                if (info.id() != in.id()) continue;
                sparkSizeIn = info.memSize();
            }
            StringBuilder sb = new StringBuilder();
            sb.append("Spark Compression Instruction sizes:");
            sb.append(String.format("\nSBCompress: InSize:       %16d", sparkSizeIn));
            sb.append(String.format("\nSBCompress: InBlockSize:  %16d", blockSizesIn));
            sb.append(String.format("\nSBCompress: OutSize:      %16d", sparkSizeOut));
            sb.append(String.format("\nSBCompress: OutBlockSize: %16d", blockSizesOut));
            LOG.trace((Object)sb.toString());
        }
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.input1.getName(), this.output.getName());
    }

    public static Long reduceSizes(List<Tuple2<MatrixIndexes, Long>> in) {
        long sum = 0L;
        for (Tuple2<MatrixIndexes, Long> e : in) {
            sum += ((Long)e._2()).longValue();
        }
        return sum;
    }

    public static class SizeFunction
    implements Function<MatrixBlock, Long> {
        private static final long serialVersionUID = 1L;

        public Long call(MatrixBlock arg0) throws Exception {
            return arg0.getInMemorySize();
        }
    }

    public static class CompressionWorkloadFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -65288330833922L;
        final CostEstimatorBuilder costBuilder;

        public CompressionWorkloadFunction(CostEstimatorBuilder costBuilder) {
            this.costBuilder = costBuilder;
        }

        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setIsInSparkInstruction();
            return (MatrixBlock)CompressedMatrixBlockFactory.compress(arg0, InfrastructureAnalyzer.getLocalParallelism(), csb, this.costBuilder).getLeft();
        }
    }

    public static class CompressionFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -6528833083609423922L;

        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setIsInSparkInstruction().setCostType(CostEstimatorFactory.CostType.MEMORY);
            return (MatrixBlock)CompressedMatrixBlockFactory.compress(arg0, csb).getLeft();
        }
    }
}

