/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Estimates the cost of executing HOP DAGs and statement blocks with federated instructions.
 *
 * <p>For every {@link Hop} (or {@link HopRel} plan alternative) a {@link FederatedCost} is built from a
 * read cost, an input transfer cost, an output transfer cost, a compute cost and the accumulated cost of
 * its inputs. Statement blocks combine the costs of their hop DAGs: loop costs get a repetition factor
 * applied and the if/else bodies of an if-block are averaged.
 *
 * <p>The tuning fields are public and non-final, so they can be changed on an individual instance.
 */
public class FederatedCostEstimator {
    /** Value passed to the hops' input/output memory-estimate getters (presumably a fallback estimate). */
    public int DEFAULT_MEMORY_ESTIMATE = 8;
    /** Repetition factor applied to while-loop costs; for-loops use their own estimated repetitions. */
    public int DEFAULT_ITERATION_NUMBER = 15;
    /** Assumed worker network bandwidth in bytes per second (2^30, i.e. 1 GiB/s). */
    public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024.0 * 1024 * 1024;
    /** Assumed compute throughput of a single worker process in FLOPS (2.5 * 2^30). */
    public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5 * 1024 * 1024 * 1024;
    /** Assumed number of parallel processes per worker. */
    public double WORKER_DEGREE_OF_PARALLELISM = 8.0;
    /** Assumed read bandwidth in bytes per second (3.5 * 2^30). */
    public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5 * 1024 * 1024 * 1024;
    /** When true, every newly estimated hop cost is printed to stdout. */
    public boolean printCosts = false;

    /**
     * Estimates the federated cost of a whole DML program.
     *
     * @param dmlProgram program whose top-level statement blocks are estimated
     * @return a new cost object into which the total of every top-level statement block is added
     *         via {@code addInputTotalCost}
     */
    public FederatedCost costEstimate(DMLProgram dmlProgram) {
        FederatedCost programTotalCost = new FederatedCost();
        for (StatementBlock stmBlock : dmlProgram.getStatementBlocks()) {
            programTotalCost.addInputTotalCost(this.costEstimate(stmBlock).getTotal());
        }
        return programTotalCost;
    }

    /**
     * Estimates the cost of a single statement block, recursing into nested blocks.
     * <ul>
     *   <li>while: predicate cost plus body cost, then {@code addRepetitionCost(DEFAULT_ITERATION_NUMBER)};</li>
     *   <li>if: half of the summed if- and else-body costs, plus the predicate cost;</li>
     *   <li>for: cost of the from/to/increment hops plus body cost, then
     *       {@code addRepetitionCost(getEstimateReps())};</li>
     *   <li>function: cost of all top-level program blocks plus the function body cost;</li>
     *   <li>any other block: cost of its hop DAG roots.</li>
     * </ul>
     *
     * @param sb statement block to estimate
     * @return cost estimate of the block
     */
    private FederatedCost costEstimate(StatementBlock sb) {
        if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock whileSB = (WhileStatementBlock) sb;
            // NOTE(review): this is the same FederatedCost object that costEstimate(Hop) stores on the
            // predicate hop via setFederatedCost, so the additions and the repetition factor applied below
            // presumably also change that hop's cached cost (the for/if branches accumulate into a fresh
            // object instead). Confirm this aliasing is intended.
            FederatedCost whileSBCost = this.costEstimate(whileSB.getPredicateHops());
            for (Statement statement : whileSB.getStatements()) {
                WhileStatement whileStatement = (WhileStatement) statement;
                for (StatementBlock bodyBlock : whileStatement.getBody()) {
                    whileSBCost.addInputTotalCost(this.costEstimate(bodyBlock));
                }
            }
            whileSBCost.addRepetitionCost(this.DEFAULT_ITERATION_NUMBER);
            return whileSBCost;
        }
        if (sb instanceof IfStatementBlock) {
            IfStatementBlock ifSB = (IfStatementBlock) sb;
            FederatedCost ifSBCost = new FederatedCost();
            for (Statement statement : ifSB.getStatements()) {
                IfStatement ifStatement = (IfStatement) statement;
                for (StatementBlock ifBodySB : ifStatement.getIfBody()) {
                    ifSBCost.addInputTotalCost(this.costEstimate(ifBodySB));
                }
                for (StatementBlock elseBodySB : ifStatement.getElseBody()) {
                    ifSBCost.addInputTotalCost(this.costEstimate(elseBodySB));
                }
            }
            // Halve the combined branch cost, i.e. weight each branch as if taken half of the time.
            ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost() / 2.0);
            ifSBCost.addInputTotalCost(this.costEstimate(ifSB.getPredicateHops()));
            return ifSBCost;
        }
        if (sb instanceof ForStatementBlock) {
            ForStatementBlock forSB = (ForStatementBlock) sb;
            ArrayList<Hop> predicateHops = new ArrayList<Hop>();
            predicateHops.add(forSB.getFromHops());
            predicateHops.add(forSB.getToHops());
            // May add null if the block has no increment hop; costEstimate(ArrayList) skips null entries.
            predicateHops.add(forSB.getIncrementHops());
            FederatedCost forSBCost = this.costEstimate(predicateHops);
            for (Statement statement : forSB.getStatements()) {
                ForStatement forStatement = (ForStatement) statement;
                for (StatementBlock forStatementBlockBody : forStatement.getBody()) {
                    forSBCost.addInputTotalCost(this.costEstimate(forStatementBlockBody));
                }
            }
            forSBCost.addRepetitionCost(forSB.getEstimateReps());
            return forSBCost;
        }
        if (sb instanceof FunctionStatementBlock) {
            FederatedCost funcCost = this.addInitialInputCost(sb);
            FunctionStatementBlock funcSB = (FunctionStatementBlock) sb;
            for (Statement statement : funcSB.getStatements()) {
                FunctionStatement funcStatement = (FunctionStatement) statement;
                for (StatementBlock funcStatementBody : funcStatement.getBody()) {
                    funcCost.addInputTotalCost(this.costEstimate(funcStatementBody));
                }
            }
            return funcCost;
        }
        return this.costEstimate(sb.getHops());
    }

    /**
     * Creates a new cost object and adds the total estimate of every top-level statement block of the
     * DML program that {@code sb} belongs to. Note that all top-level blocks are re-estimated on each call.
     *
     * @param sb statement block whose owning program is inspected
     * @return new cost object seeded with the totals of the program's top-level statement blocks
     */
    private FederatedCost addInitialInputCost(StatementBlock sb) {
        FederatedCost basicCost = new FederatedCost();
        for (StatementBlock childSB : sb.getDMLProg().getStatementBlocks()) {
            basicCost.addInputTotalCost(this.costEstimate(childSB).getTotal());
        }
        return basicCost;
    }

    /**
     * Adds the cost estimates of a list of DAG roots into a new cost object.
     *
     * <p>A {@code null} list and {@code null} entries are skipped instead of raising a
     * NullPointerException. The for-loop branch adds its from/to/increment hops unconditionally, and a
     * statement block without hops would otherwise crash the estimation.
     *
     * @param roots DAG roots to estimate; may be {@code null} or contain {@code null}
     * @return new cost object accumulating the estimates of all non-null roots
     */
    private FederatedCost costEstimate(ArrayList<Hop> roots) {
        FederatedCost basicCost = new FederatedCost();
        if (roots == null) {
            return basicCost;
        }
        for (Hop root : roots) {
            if (root != null) {
                basicCost.addInputTotalCost(this.costEstimate(root));
            }
        }
        return basicCost;
    }

    /**
     * Estimates the cost of the DAG rooted at {@code root}, memoizing the result on each hop.
     *
     * <p>An already-initialized hop returns its stored cost. Otherwise, the cost is built from the read cost,
     * the input and output transfer costs, the compute cost and the cost of inputs that were not yet
     * estimated. The new cost is stored on the hop via {@code setFederatedCost}.
     *
     * @param root root hop of the DAG
     * @return cost estimate of the DAG
     */
    public FederatedCost costEstimate(Hop root) {
        if (root.federatedCostInitialized()) {
            return root.getFederatedCost();
        }
        boolean hasFederatedInput = root.someInputFederated();
        // Only the consumer that triggers an input's first estimation counts its cost; later consumers see it
        // as initialized and add 0, so shared inputs are not counted twice.
        double inputCosts = root.getInput().stream()
            .mapToDouble(in -> in.federatedCostInitialized() ? 0.0 : this.costEstimate(in).getTotal())
            .sum();
        double inputTransferCost = this.inputTransferCostEstimate(hasFederatedInput, root);
        int numWorkers = hasFederatedInput
            ? (int) root.getInput().stream().filter(Hop::hasFederatedOutput).count()
            : 0;
        double computingCost = this.computeTimeEstimate(
            ComputeCost.getHOPComputeCost(root), hasFederatedInput, numWorkers);
        // Local output produced from federated work (or by a federated data op) must be shipped back.
        double outputTransferCost = root.hasLocalOutput() && (hasFederatedInput || root.isFederatedDataOp())
            ? root.getOutputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE) / this.WORKER_NETWORK_BANDWIDTH_BYTES_PS
            : 0.0;
        double readCost = root.getInputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE) / this.WORKER_READ_BANDWIDTH_BYTES_PS;
        FederatedCost rootFedCost = new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
        root.setFederatedCost(rootFedCost);
        if (this.printCosts) {
            printCosts(root);
        }
        return rootFedCost;
    }

    /**
     * Estimates the cost of a HopRel plan alternative, using {@code hopRelMemo} to detect alternatives
     * already costed for the same hop and federated output.
     *
     * @param root       plan alternative to estimate
     * @param hopRelMemo memo of plan alternatives keyed by hop ID
     * @return the cost stored on {@code root} on a memo hit, otherwise a newly built cost estimate
     */
    public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> hopRelMemo) {
        long rootHopId = root.hopRef.getHopID();
        if (hopRelMemo.containsKey(rootHopId)
            && hopRelMemo.get(rootHopId).stream().anyMatch(h -> h.fedOut == root.fedOut)) {
            return root.getCostObject();
        }
        // NOTE(review): this checks the wrapped Hop's FOUT flag, whereas numWorkers and the transfer costs
        // below use the HopRel's own federated output. If they disagree, numWorkers may be 0 while
        // hasFederatedInput is true; computeTimeEstimate guards that case instead of dividing by zero.
        boolean hasFederatedInput = root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
        // An input's cost is counted once per consumer hop: the cost pointer records that this root has
        // already included it. This relies on the stream running sequentially, in order.
        double inputCosts = root.inputDependency.stream().mapToDouble(in -> {
            double inCost = in.existingCostPointer(rootHopId) ? 0.0 : this.costEstimate(in, hopRelMemo).getTotal();
            in.addCostPointer(rootHopId);
            return inCost;
        }).sum();
        double inputTransferCost = this.inputTransferCostEstimate(hasFederatedInput, root);
        int numWorkers = hasFederatedInput
            ? (int) root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count()
            : 0;
        double computingCost = this.computeTimeEstimate(
            ComputeCost.getHOPComputeCost(root.hopRef), hasFederatedInput, numWorkers);
        double outputTransferCost = root.hasLocalOutput() && (hasFederatedInput || root.hopRef.isFederatedDataOp())
            ? root.hopRef.getOutputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE) / this.WORKER_NETWORK_BANDWIDTH_BYTES_PS
            : 0.0;
        double readCost = root.hopRef.getInputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE) / this.WORKER_READ_BANDWIDTH_BYTES_PS;
        return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
    }

    /**
     * Converts a raw compute cost into a time estimate. When the operation has federated inputs, the work
     * is divided across the {@code numWorkers} workers holding federated data. Otherwise, or when no such
     * worker was counted (which would divide by zero and yield Infinity/NaN), it is divided across the
     * parallel processes of a single worker.
     *
     * @param flops             raw compute cost as returned by ComputeCost
     * @param hasFederatedInput whether the operation is considered to run at federated workers
     * @param numWorkers        number of inputs with federated output
     * @return estimated compute time
     */
    private double computeTimeEstimate(double flops, boolean hasFederatedInput, int numWorkers) {
        if (hasFederatedInput && numWorkers > 0) {
            return flops / ((double) numWorkers * this.WORKER_DEGREE_OF_PARALLELISM * this.WORKER_COMPUTE_BANDWIDTH_FLOPS);
        }
        return flops / (this.WORKER_DEGREE_OF_PARALLELISM * this.WORKER_COMPUTE_BANDWIDTH_FLOPS);
    }

    /**
     * Input transfer cost of a HopRel. This is zero without federated input. For a federated data op, the
     * FOUT inputs are transferred; otherwise the LOUT inputs are transferred.
     */
    private double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root) {
        if (!hasFederatedInput) {
            return 0.0;
        }
        double transferredBytes = root.inputDependency.stream()
            .filter(input -> root.hopRef.isFederatedDataOp() ? input.hasFederatedOutput() : input.hasLocalOutput())
            .mapToDouble(in -> in.hopRef.getOutputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE))
            .sum();
        return transferredBytes / this.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
    }

    /**
     * Input transfer cost of a Hop. This is zero without federated input. For a federated data op, the
     * federated-output inputs are transferred; otherwise the local-output inputs are transferred.
     */
    private double inputTransferCostEstimate(boolean hasFederatedInput, Hop root) {
        if (!hasFederatedInput) {
            return 0.0;
        }
        double transferredBytes = root.getInput().stream()
            .filter(input -> root.isFederatedDataOp() ? input.hasFederatedOutput() : input.hasLocalOutput())
            .mapToDouble(in -> in.getOutputMemEstimate(this.DEFAULT_MEMORY_ESTIMATE))
            .sum();
        return transferredBytes / this.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
    }

    /** Prints debugging information and the stored federated cost of {@code root} to stdout. */
    private static void printCosts(Hop root) {
        System.out.println("===============================");
        System.out.println(root);
        System.out.println("Is federated: " + root.isFederated());
        System.out.println("Has federated output: " + root.hasFederatedOutput());
        System.out.println(root.getText());
        System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
        System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
        System.out.println(root.getFederatedCost().toString());
        System.out.println("===============================");
    }
}

