/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class AggUnaryOp
extends MultiThreadedHop {
    private static final boolean ALLOW_UNARYAGG_WO_FINAL_AGG = true;
    private Types.AggOp _op;
    private Types.Direction _direction;

    /**
     * No-argument constructor. Being private, it is only reachable from inside this
     * class; {@link #clone()} uses it and then populates the fields of the new instance.
     */
    private AggUnaryOp() {
    }

    /**
     * Creates a unary aggregate hop over a single input hop.
     * <p>
     * Besides storing the aggregation operation and direction, the constructor wires
     * the hop into the DAG: {@code inp} is inserted at position 0 of this hop's input
     * list and this hop is appended to {@code inp}'s parent list.
     *
     * @param l   forwarded unchanged to the superclass constructor
     * @param dt  forwarded unchanged to the superclass constructor
     * @param vt  forwarded unchanged to the superclass constructor
     * @param o   aggregation operation of this hop
     * @param idx aggregation direction of this hop
     * @param inp the hop whose output is aggregated
     */
    public AggUnaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.AggOp o, Types.Direction idx, Hop inp) {
        super(l, dt, vt);
        _op = o;
        _direction = idx;

        // register the edge in both directions: child -> our inputs, us -> child's parents
        getInput().add(0, inp);
        inp.getParent().add(this);
    }

    /**
     * Verifies that this hop has exactly one input; the check itself (and whatever it
     * does on failure) is delegated to {@code HopsException.check}.
     */
    @Override
    public void checkArity() {
        final int arity = _input.size();
        HopsException.check(arity == 1, this, "should have arity 1 but has arity %d", arity);
    }

    /**
     * @return the aggregation operation of this hop
     */
    public Types.AggOp getOp() {
        return _op;
    }

    /**
     * Replaces the aggregation operation of this hop.
     *
     * @param op the new aggregation operation
     */
    public void setOp(Types.AggOp op) {
        _op = op;
    }

    /**
     * @return the aggregation direction of this hop
     */
    public Types.Direction getDirection() {
        return _direction;
    }

    /**
     * Replaces the aggregation direction of this hop.
     *
     * @param direction the new aggregation direction
     */
    public void setDirection(Types.Direction direction) {
        _direction = direction;
    }

    /**
     * Decides whether this aggregate is eligible for GPU execution.
     * <p>
     * Returns {@code false} when the accelerator flag is off, or when either the ternary
     * aggregate rewrite or the outer-chain CP rewrite applies. Otherwise it returns
     * {@code true} for SUM, SUM_SQ, MAX, MIN, MEAN and VAR in any of the directions
     * RowCol/Row/Col, and for PROD only in direction RowCol.
     *
     * @return {@code true} if the operation/direction combination is accepted for GPU
     * @throws RuntimeException wrapping any {@code HopsException} raised by the rewrite checks
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        try {
            // GPU is declined whenever one of the two rewrites would be used instead
            if (isTernaryAggregateRewriteApplicable() || isUnaryAggregateOuterCPRewriteApplicable()) {
                return false;
            }

            boolean anyDirection = _direction == Types.Direction.RowCol
                || _direction == Types.Direction.Row
                || _direction == Types.Direction.Col;
            boolean opForAnyDirection = _op == Types.AggOp.SUM
                || _op == Types.AggOp.SUM_SQ
                || _op == Types.AggOp.MAX
                || _op == Types.AggOp.MIN
                || _op == Types.AggOp.MEAN
                || _op == Types.AggOp.VAR;
            boolean fullProduct = _op == Types.AggOp.PROD && _direction == Types.Direction.RowCol;

            return (opForAnyDirection && anyDirection) || fullProduct;
        }
        catch (HopsException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds (or returns the already built) lop for this unary aggregate.
     * <p>
     * The execution type from {@code optFindExecType()} selects the branch:
     * <ul>
     *   <li>CP / GPU / FED: ternary aggregate rewrite, outer-chain rewrite (always created
     *       with exec type CP) or a plain multi-threaded {@code PartialAggregate};</li>
     *   <li>SPARK: ternary aggregate rewrite, outer-chain rewrite or a
     *       {@code PartialAggregate} with a Spark aggregation type; for scalar outputs the
     *       latter two are wrapped in a {@code CAST_AS_SCALAR} {@code UnaryCP};</li>
     *   <li>anything else is rejected with a {@code HopsException}.</li>
     * </ul>
     *
     * @return the lop registered for this hop
     * @throws HopsException wrapping any exception raised during construction, including
     *                       the unrecognized-exec-type error raised by this method itself
     */
    @Override
    public Lop constructLops() {
        // lops are built once and then reused
        if (getLops() != null) {
            return getLops();
        }

        try {
            Types.ExecType et = optFindExecType();
            Hop input = getInput().get(0);

            if (et == Types.ExecType.CP || et == Types.ExecType.GPU || et == Types.ExecType.FED) {
                Lop agg;
                if (isTernaryAggregateRewriteApplicable()) {
                    agg = constructLopsTernaryAggregateRewrite(et);
                } else if (isUnaryAggregateOuterCPRewriteApplicable()) {
                    BinaryOp outer = (BinaryOp) getInput().get(0);
                    Lop left = outer.getInput().get(0).constructLops();
                    Lop right = outer.getInput().get(1).constructLops();
                    agg = new UAggOuterChain(left, right, _op, _direction, outer.getOp(), Types.DataType.MATRIX, getValueType(), Types.ExecType.CP);
                    PartialAggregate.setDimensionsBasedOnDirection(agg, getDim1(), getDim2(), input.getBlocksize(), _direction);
                    if (getDataType() == Types.DataType.SCALAR) {
                        // the chain yields a matrix; cast it down for scalar outputs
                        UnaryCP cast = new UnaryCP(agg, Types.OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
                        cast.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
                        setLineNumbers(cast);
                        agg = cast;
                    }
                } else {
                    int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                    agg = new PartialAggregate(input.constructLops(), _op, _direction, getDataType(), getValueType(), et, k);
                }

                setOutputDimensions(agg);
                setLineNumbers(agg);
                setLops(agg);
                if (getDataType() == Types.DataType.SCALAR) {
                    agg.getOutputParameters().setDimensions(1L, 1L, getBlocksize(), getNnz());
                }
            } else if (et == Types.ExecType.SPARK) {
                if (isTernaryAggregateRewriteApplicable()) {
                    Lop ternary = constructLopsTernaryAggregateRewrite(et);
                    setOutputDimensions(ternary);
                    setLineNumbers(ternary);
                    setLops(ternary);
                } else if (isUnaryAggregateOuterSPRewriteApplicable()) {
                    BinaryOp outer = (BinaryOp) getInput().get(0);
                    Lop left = outer.getInput().get(0).constructLops();
                    Lop right = outer.getInput().get(1).constructLops();
                    UAggOuterChain chain = new UAggOuterChain(left, right, _op, _direction, outer.getOp(), Types.DataType.MATRIX, getValueType(), Types.ExecType.SPARK);
                    PartialAggregate.setDimensionsBasedOnDirection(chain, getDim1(), getDim2(), input.getBlocksize(), _direction);
                    setLineNumbers(chain);
                    setLops(chain);
                    if (getDataType() == Types.DataType.SCALAR) {
                        UnaryCP cast = new UnaryCP(chain, Types.OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
                        cast.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
                        setLineNumbers(cast);
                        setLops(cast);
                    }
                } else {
                    boolean needAgg = AggUnaryOp.requiresAggregation(input, _direction);
                    AggBinaryOp.SparkAggType aggType = getSparkUnaryAggregationType(needAgg);
                    PartialAggregate partial = new PartialAggregate(input.constructLops(), _op, _direction, input._dataType, getValueType(), aggType, et);
                    partial.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getBlocksize());
                    setLineNumbers(partial);
                    setLops(partial);
                    if (getDataType() == Types.DataType.SCALAR) {
                        UnaryCP cast = new UnaryCP(partial, Types.OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
                        cast.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
                        setLineNumbers(cast);
                        setLops(cast);
                    }
                }
            } else {
                // thrown inside the try on purpose: it is wrapped by the handler below
                throw new HopsException("ExecType " + et + " not recognized in " + this.toString());
            }
        }
        catch (Exception e) {
            throw new HopsException(printErrorLocation() + "In AggUnary Hop, error constructing Lops ", e);
        }

        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /**
     * @return the operator label {@code "ua(" + op + direction + ")"}, built from the
     *         {@code toString()} of the aggregation operation and direction
     */
    @Override
    public String getOpString() {
        StringBuilder label = new StringBuilder("ua(");
        label.append(_op.toString());
        label.append(_direction.toString());
        label.append(")");
        return label.toString();
    }

    /**
     * @return always {@code true} for this hop type
     */
    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Estimates the output size for the given dimensions.
     * <p>
     * When {@link #isGPUEnabled()} holds, the estimate uses a sparsity of 1.0 (i.e. the
     * output is sized as fully dense); otherwise the sparsity is derived from
     * {@code dim1}, {@code dim2} and {@code nnz}.
     *
     * @param dim1 number of output rows
     * @param dim2 number of output columns
     * @param nnz  number of non-zeros used for the sparsity computation
     * @return the size estimate from {@code OptimizerUtils.estimateSizeExactSparsity}
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity;
        if (isGPUEnabled()) {
            sparsity = 1.0;
        } else {
            sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Estimates the intermediate memory needed by this aggregate, depending on the
     * aggregation operation and direction.
     * <p>
     * Summary of what is charged (0 where nothing is listed):
     * <ul>
     *   <li>MAX/MIN, Col: {@code dim2 * 4};</li>
     *   <li>SUM/SUM_SQ: a 2 x dim2 block (Col, with the output sparsity) or a dense
     *       dim1 x 2 block (Row);</li>
     *   <li>MEAN: same shape as SUM but with 3 instead of 2;</li>
     *   <li>VAR: when GPU-enabled, twice the size estimate of the input plus one extra
     *       vector for Col/Row; otherwise same shape as SUM but with 5 instead of 2;</li>
     *   <li>MAXINDEX/MININDEX: three dense 1 x (input columns) blocks when the outer-chain
     *       CP rewrite applies, otherwise a dense dim1 x 2 block.</li>
     * </ul>
     *
     * @param dim1 number of output rows
     * @param dim2 number of output columns
     * @param nnz  number of non-zeros used for the sparsity computation
     * @return the estimated intermediate memory
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        final double outSparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        final boolean colwise = _direction == Types.Direction.Col;
        final boolean rowwise = _direction == Types.Direction.Row;

        double mem = 0.0;
        switch (_op) {
            case MAX:
            case MIN: {
                if (colwise) {
                    mem = dim2 * 4L;
                }
                break;
            }
            case SUM:
            case SUM_SQ: {
                if (colwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(2L, dim2, outSparsity);
                } else if (rowwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(dim1, 2L, 1.0);
                }
                break;
            }
            case MEAN: {
                if (colwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(3L, dim2, outSparsity);
                } else if (rowwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(dim1, 3L, 1.0);
                }
                break;
            }
            case VAR: {
                if (isGPUEnabled()) {
                    // GPU path is sized from the input dimensions, not the output ones
                    long inRows = getInput().get(0).getDim1();
                    long inCols = getInput().get(0).getDim2();
                    mem = 2L * OptimizerUtils.estimateSize(inRows, inCols);
                    if (colwise) {
                        mem += (double) OptimizerUtils.estimateSize(inRows, 1L);
                    } else if (rowwise) {
                        mem += (double) OptimizerUtils.estimateSize(1L, inCols);
                    }
                } else if (colwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(5L, dim2, outSparsity);
                } else if (rowwise) {
                    mem = OptimizerUtils.estimateSizeExactSparsity(dim1, 5L, 1.0);
                }
                break;
            }
            case MAXINDEX:
            case MININDEX: {
                Hop in = getInput().get(0);
                if (isUnaryAggregateOuterCPRewriteApplicable()) {
                    mem = 3L * OptimizerUtils.estimateSizeExactSparsity(1L, in.getDim2(), 1.0);
                } else {
                    mem = OptimizerUtils.estimateSizeExactSparsity(dim1, 2L, 1.0);
                }
                break;
            }
            default: {
                // every other operation keeps the initial estimate of 0
                break;
            }
        }
        return mem;
    }

    /**
     * Infers the output characteristics of this aggregate from the input statistics
     * recorded in {@code memo}.
     * <p>
     * A column-wise aggregate yields a {@code 1 x cols} result when the input's column
     * count is known; a row-wise aggregate yields a {@code rows x 1} result when the
     * input's row count is known. Both branches pass -1 for the block size and the
     * number of non-zeros.
     *
     * @param memo memo table queried for the statistics of the single input
     * @return the inferred characteristics, or {@code null} for any other direction or
     *         when the required input dimension is not known
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        Hop input = getInput().get(0);
        DataCharacteristics dc = memo.getAllInputStats(input);

        MatrixCharacteristics ret = null;
        if (_direction == Types.Direction.Col && dc.colsKnown()) {
            ret = new MatrixCharacteristics(1L, dc.getCols(), -1, -1L);
        } else if (_direction == Types.Direction.Row && dc.rowsKnown()) {
            // Use the same explicit (blocksize, nnz) = (-1, -1) form as the column-wise
            // branch. The former 3-argument call with the int literal -2 presumably
            // resolved to an int block-size overload and set the block size to -2.
            ret = new MatrixCharacteristics(dc.getRows(), 1L, -1, -1L);
        }
        return ret;
    }

    /**
     * Determines, stores and returns the execution type of this hop.
     * <p>
     * A forced execution type wins. Otherwise the type comes from the memory estimate
     * (memory-based optimization levels) or from a size heuristic on the input (CP if
     * the input is below the dimension threshold or is a vector, SPARK otherwise).
     * <p>
     * When {@code transitive} is set and the result is a non-forced CP, the hop is moved
     * to SPARK if its input is Spark-resident (a non-{@code DataOp} input with exec type
     * SPARK, or a {@code DataOp} that has only an RDD) and at least one of these holds:
     * this hop is the input's only parent, all other parents of the input run in SPARK,
     * or no aggregation is required for the input/direction.
     *
     * @param transitive whether the Spark refinement based on the input may be applied
     * @return the selected execution type
     */
    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        checkAndSetForcedPlatform();

        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector()) {
                _etype = Types.ExecType.CP;
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }

        if (transitive && _etype == Types.ExecType.CP && _etypeForced != Types.ExecType.CP) {
            // DataOp inputs are judged by their RDD state, all other inputs by their exec type
            boolean sparkInput = getInput(0) instanceof DataOp
                ? ((DataOp) getInput(0)).hasOnlyRDD()
                : getInput(0).optFindExecType() == Types.ExecType.SPARK;

            if (sparkInput
                && (getInput(0).getParent().size() == 1
                    || getInput(0).getParent().stream().filter(h -> h != this).allMatch(h -> h.optFindExecType(false) == Types.ExecType.SPARK)
                    || !AggUnaryOp.requiresAggregation(getInput(0), _direction))) {
                _etype = Types.ExecType.SPARK;
            }
        }

        updateETFed();
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /**
     * Tells whether aggregating {@code input} along {@code dir} needs an aggregation step.
     * <p>
     * No aggregation is required when the aggregated dimension is greater than 1 and at
     * most one block wide: rows for a column-wise aggregate, columns for a row-wise one.
     *
     * @param input the hop being aggregated
     * @param dir   the aggregation direction
     * @return {@code false} in the two single-block cases above, {@code true} otherwise
     */
    private static boolean requiresAggregation(Hop input, Types.Direction dir) {
        boolean colAggWithinOneBlock = input.getDim1() > 1L
            && input.getDim1() <= (long) input.getBlocksize()
            && dir == Types.Direction.Col;
        if (colAggWithinOneBlock) {
            return false;
        }

        boolean rowAggWithinOneBlock = input.getDim2() > 1L
            && input.getDim2() <= (long) input.getBlocksize()
            && dir == Types.Direction.Row;
        return !rowAggWithinOneBlock;
    }

    /**
     * Selects the Spark aggregation type for this hop.
     *
     * @param agg whether an aggregation step is required at all
     * @return {@code NONE} if {@code agg} is false; {@code SINGLE_BLOCK} if the output is
     *         a scalar or has known dimensions that both fit into one block;
     *         {@code MULTI_BLOCK} otherwise
     */
    private AggBinaryOp.SparkAggType getSparkUnaryAggregationType(boolean agg) {
        if (!agg) {
            return AggBinaryOp.SparkAggType.NONE;
        }

        boolean singleBlock = getDataType() == Types.DataType.SCALAR
            || (dimsKnown() && getDim1() <= (long) getBlocksize() && getDim2() <= (long) getBlocksize());
        return singleBlock ? AggBinaryOp.SparkAggType.SINGLE_BLOCK : AggBinaryOp.SparkAggType.MULTI_BLOCK;
    }

    /**
     * Checks whether this aggregate can be compiled into a ternary aggregate.
     * <p>
     * The rewrite is never applied when the accelerator flag is on. Otherwise it needs
     * sum-product rewrites to be allowed, a SUM in direction RowCol or Col, and an input
     * that is a {@code BinaryOp} with this hop as its single parent. That input must be
     * either a power with literal exponent 3, or a multiplication where
     * <ul>
     *   <li>the left operand is itself a multiplication and its two operands as well as
     *       the right operand are equal in size to the input, or</li>
     *   <li>the right operand is a multiplication with the symmetric size condition, or</li>
     *   <li>neither operand is a multiplication and both operands are equal in size.</li>
     * </ul>
     *
     * @return {@code true} if the ternary aggregate rewrite applies
     */
    private boolean isTernaryAggregateRewriteApplicable() {
        if (DMLScript.USE_ACCELERATOR) {
            return false;
        }

        boolean sumOverRowColOrCol = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES
            && _op == Types.AggOp.SUM
            && (_direction == Types.Direction.RowCol || _direction == Types.Direction.Col);
        if (!sumOverRowColOrCol) {
            return false;
        }

        Hop product = getInput().get(0);
        if (product.getParent().size() != 1 || !(product instanceof BinaryOp)) {
            return false;
        }

        BinaryOp bop = (BinaryOp) product;
        if (bop.getOp() == Types.OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp) {
            LiteralOp exponent = (LiteralOp) bop.getInput().get(1);
            return HopRewriteUtils.getIntValueSafe(exponent) == 3L;
        }
        if (bop.getOp() != Types.OpOp2.MULT) {
            return false;
        }

        Hop lhs = product.getInput().get(0);
        Hop rhs = product.getInput().get(1);
        if (lhs instanceof BinaryOp && ((BinaryOp) lhs).getOp() == Types.OpOp2.MULT) {
            return HopRewriteUtils.isEqualSize(lhs.getInput().get(0), product)
                && HopRewriteUtils.isEqualSize(lhs.getInput().get(1), product)
                && HopRewriteUtils.isEqualSize(rhs, product);
        }
        if (rhs instanceof BinaryOp && ((BinaryOp) rhs).getOp() == Types.OpOp2.MULT) {
            return HopRewriteUtils.isEqualSize(rhs.getInput().get(0), product)
                && HopRewriteUtils.isEqualSize(rhs.getInput().get(1), product)
                && HopRewriteUtils.isEqualSize(lhs, product);
        }
        return HopRewriteUtils.isEqualSize(lhs, rhs);
    }

    /**
     * @param opOp2 the binary operator to classify (may be {@code null})
     * @return {@code true} for LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL and NOTEQUAL;
     *         {@code false} for every other value including {@code null}
     */
    private static boolean isCompareOperator(Types.OpOp2 opOp2) {
        if (opOp2 == null) {
            return false;
        }
        switch (opOp2) {
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case EQUAL:
            case NOTEQUAL:
                return true;
            default:
                return false;
        }
    }

    /**
     * @return always {@code true} for this hop type
     */
    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Checks whether the Spark outer-chain rewrite applies, i.e. the input is an outer
     * {@code BinaryOp} whose right-hand side is small enough.
     * <p>
     * The size of the right-hand side is its exact size estimate when its dimensions are
     * known and its output memory estimate otherwise. For MAXINDEX/MININDEX twice that
     * size must be below both the Spark broadcast memory budget and the local memory
     * budget; for all other operations the decision is delegated to
     * {@code OptimizerUtils.checkSparkBroadcastMemoryBudget}.
     *
     * @return {@code true} if the rewrite applies
     */
    private boolean isUnaryAggregateOuterSPRewriteApplicable() {
        Hop input = getInput().get(0);
        if (!(input instanceof BinaryOp) || !((BinaryOp) input).isOuter()) {
            return false;
        }

        Hop right = input.getInput().get(1);
        double size = right.dimsKnown()
            ? (double) OptimizerUtils.estimateSize(right.getDim1(), right.getDim2())
            : right.getOutputMemEstimate();

        if (_op == Types.AggOp.MAXINDEX || _op == Types.AggOp.MININDEX) {
            double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
            double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
            return 2.0 * size < memBudgetExec && 2.0 * size < memBudgetLocal;
        }
        return OptimizerUtils.checkSparkBroadcastMemoryBudget(size);
    }

    /**
     * Checks whether the CP outer-chain rewrite applies: the input must be an outer
     * {@code BinaryOp} with a comparison operator, and this aggregate must be one of
     * MAXINDEX, MININDEX or SUM.
     *
     * @return {@code true} if the rewrite applies
     */
    private boolean isUnaryAggregateOuterCPRewriteApplicable() {
        Hop input = getInput().get(0);
        if (!(input instanceof BinaryOp) || !((BinaryOp) input).isOuter()) {
            return false;
        }

        boolean supportedAgg = _op == Types.AggOp.MAXINDEX
            || _op == Types.AggOp.MININDEX
            || _op == Types.AggOp.SUM;
        return supportedAgg && AggUnaryOp.isCompareOperator(((BinaryOp) input).getOp());
    }

    /**
     * Builds the {@code TernaryAggregate} lop (SUM over MULT of three operands) for an
     * input that passed {@link #isTernaryAggregateRewriteApplicable()}.
     * <p>
     * The three operand lops are picked from the binary input as follows:
     * <ul>
     *   <li>input is a power (asserted to be a power of 3): base, base, base;</li>
     *   <li>left operand is a multiplication: its two operands, then the right operand;</li>
     *   <li>left operand is a power of 2 and the right operand is not a multiplication:
     *       base, base, right operand;</li>
     *   <li>left operand is no {@code BinaryOp} and the right one is a multiplication:
     *       left operand, then the two operands of the right one;</li>
     *   <li>left operand is no {@code BinaryOp} and the right one is a power of 2:
     *       base, base, left operand;</li>
     *   <li>anything else: left operand, right operand and the literal 1.</li>
     * </ul>
     * The exec type of the new lop is taken from the binary input (GPU is mapped to CP).
     *
     * @param et execution type chosen for this hop; not read by this method
     * @return the ternary aggregate lop
     */
    private Lop constructLopsTernaryAggregateRewrite(Types.ExecType et) {
        BinaryOp root = (BinaryOp) getInput().get(0);
        Hop left = root.getInput().get(0);
        Hop right = root.getInput().get(1);

        Lop in1 = null;
        Lop in2 = null;
        Lop in3 = null;
        boolean handled = false;

        if (root.getOp() == Types.OpOp2.POW) {
            assert (HopRewriteUtils.isLiteralOfValue(right, 3.0)) : "this case can only occur with a power of 3";
            in1 = left.constructLops();
            in2 = in1;
            in3 = in1;
            handled = true;
        } else if (left instanceof BinaryOp) {
            BinaryOp leftOp = (BinaryOp) left;
            switch (leftOp.getOp()) {
                case MULT: {
                    in1 = left.getInput().get(0).constructLops();
                    in2 = left.getInput().get(1).constructLops();
                    in3 = right.constructLops();
                    handled = true;
                    break;
                }
                case POW: {
                    Hop exponent = leftOp.getInput().get(1);
                    boolean rightIsMult = right instanceof BinaryOp && ((BinaryOp) right).getOp() == Types.OpOp2.MULT;
                    if (!rightIsMult && HopRewriteUtils.isLiteralOfValue(exponent, 2.0)) {
                        in1 = leftOp.getInput().get(0).constructLops();
                        in2 = in1;
                        in3 = right.constructLops();
                        handled = true;
                    }
                    break;
                }
                default: {
                    break;
                }
            }
        } else if (right instanceof BinaryOp) {
            BinaryOp rightOp = (BinaryOp) right;
            switch (rightOp.getOp()) {
                case MULT: {
                    in1 = left.constructLops();
                    in2 = right.getInput().get(0).constructLops();
                    in3 = right.getInput().get(1).constructLops();
                    handled = true;
                    break;
                }
                case POW: {
                    Hop exponent = rightOp.getInput().get(1);
                    if (HopRewriteUtils.isLiteralOfValue(exponent, 2.0)) {
                        in1 = rightOp.getInput().get(0).constructLops();
                        in2 = in1;
                        in3 = left.constructLops();
                        handled = true;
                    }
                    break;
                }
                default: {
                    break;
                }
            }
        }

        if (!handled) {
            // plain two-operand product: pad the third operand with the literal 1
            in1 = left.constructLops();
            in2 = right.constructLops();
            in3 = new LiteralOp(1L).constructLops();
        }

        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        Types.ExecType inputExecType = root.optFindExecType();
        if (inputExecType == Types.ExecType.GPU) {
            inputExecType = Types.ExecType.CP;
        }
        return new TernaryAggregate(in1, in2, in3, Types.AggOp.SUM, Types.OpOp2.MULT, _direction, getDataType(), Types.ValueType.FP64, inputExecType, k);
    }

    /**
     * Refreshes the output dimensions from the current dimensions of the input.
     * <p>
     * Scalar outputs are left untouched. A column-wise aggregate becomes
     * {@code 1 x input columns}, a row-wise aggregate {@code input rows x 1}; any other
     * direction leaves the dimensions unchanged.
     */
    @Override
    public void refreshSizeInformation() {
        if (getDataType() == Types.DataType.SCALAR) {
            return;
        }

        Hop input = getInput().get(0);
        if (_direction == Types.Direction.Col) {
            setDim1(1L);
            setDim2(input.getDim2());
        } else if (_direction == Types.Direction.Row) {
            setDim1(input.getDim1());
            setDim2(1L);
        }
    }

    /**
     * @return {@code true} only for a RowCol aggregate whose operation is one of
     *         SUM, SUM_SQ, MIN, MAX, PROD, MEAN or VAR
     */
    @Override
    public boolean isTransposeSafe() {
        if (_direction != Types.Direction.RowCol || _op == null) {
            return false;
        }
        switch (_op) {
            case SUM:
            case SUM_SQ:
            case MIN:
            case MAX:
            case PROD:
            case MEAN:
            case VAR:
                return true;
            default:
                return false;
        }
    }

    /**
     * Creates a copy of this hop: a fresh instance is initialized through
     * {@code clone(this, false)} and then receives this hop's aggregation operation,
     * direction and maximum thread count.
     *
     * @return the new {@code AggUnaryOp}
     * @throws CloneNotSupportedException declared by the overridden signature
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        AggUnaryOp copy = new AggUnaryOp();
        copy.clone(this, false);

        // fields specific to this hop type
        copy._op = _op;
        copy._direction = _direction;
        copy._maxNumThreads = _maxNumThreads;
        return copy;
    }

    /**
     * Compares this hop with another one.
     *
     * @param that the hop to compare against
     * @return {@code true} if {@code that} is an {@code AggUnaryOp} with the same
     *         aggregation operation, direction and maximum thread count, and both hops
     *         reference the very same input object
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof AggUnaryOp)) {
            return false;
        }

        AggUnaryOp other = (AggUnaryOp) that;
        if (_op != other._op || _direction != other._direction) {
            return false;
        }
        if (_maxNumThreads != other._maxNumThreads) {
            return false;
        }
        // inputs are compared by identity, not structurally
        return getInput().get(0) == other.getInput().get(0);
    }
}

