/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
import org.apache.sysds.lops.compile.linearization.LinearizerDepthFirst;

public class LinearizerPipelineAware
extends IDagLinearizer {
    private static final int IGNORE_LIMIT = 0;
    private static final int HARD_LIMIT = 4;
    private static final int UPPER_BOUND = 10;

    @Override
    public List<Lop> linearize(List<Lop> v) {
        if (v.size() <= 0) {
            v.forEach(l -> l.setPipelineID(1));
            return new LinearizerDepthFirst().linearize(v);
        }
        List<Lop> roots = v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
        Integer pipelineId = 0;
        ArrayList<Lop> opList = new ArrayList<Lop>();
        HashMap<Integer, List<Lop>> pipelineMap = new HashMap<Integer, List<Lop>>();
        for (Lop r : roots) {
            pipelineId = LinearizerPipelineAware.depthFirst(r, pipelineId, opList, pipelineMap) + 1;
        }
        LinearizerPipelineAware.mergeSingleNodePipelines(pipelineMap);
        LinearizerPipelineAware.mergeSmallPipelines(pipelineMap);
        roots.forEach(Lop::resetVisitStatus);
        return opList;
    }

    private static int depthFirst(Lop root, int pipelineId, List<Lop> opList, Map<Integer, List<Lop>> pipelineMap) {
        if (root.isVisited()) {
            return root.getPipelineID();
        }
        root.setPipelineID(pipelineId);
        root.setVisited();
        if (pipelineMap.containsKey(pipelineId)) {
            pipelineMap.get(pipelineId).add(root);
        } else {
            ArrayList<Lop> lopList = new ArrayList<Lop>();
            lopList.add(root);
            pipelineMap.put(pipelineId, lopList);
        }
        ArrayList<Lop> children = root.getInputs();
        if (children.size() == 1) {
            Lop child = (Lop)children.get(0);
            pipelineId = Math.max(pipelineId, LinearizerPipelineAware.depthFirst(child, pipelineId, opList, pipelineMap));
        } else {
            for (int i = 0; i < children.size(); ++i) {
                Lop child = (Lop)children.get(i);
                if (child.getOutputs().size() == 1 || child.getOutputs().size() > 1 && child.getOutputs().stream().allMatch(o -> o == root)) {
                    LinearizerPipelineAware.depthFirst(child, root.getPipelineID(), opList, pipelineMap);
                    continue;
                }
                pipelineId = Math.max(pipelineId, LinearizerPipelineAware.depthFirst(child, pipelineId + 1, opList, pipelineMap));
            }
        }
        opList.add(root);
        return pipelineId;
    }

    private static void mergeSingleNodePipelines(Map<Integer, List<Lop>> map) {
        Map<Integer, List> pipelinesWithOneNode = map.entrySet().stream().filter(e -> ((List)e.getValue()).size() == 1).collect(Collectors.toMap(e -> (Integer)e.getKey(), e -> (List)e.getValue()));
        if (pipelinesWithOneNode.size() == 0) {
            return;
        }
        pipelinesWithOneNode.entrySet().stream().forEach(e -> {
            Lop lop = (Lop)((List)e.getValue()).get(0);
            if (lop.getOutputs().size() > 0) {
                lop.setPipelineID(lop.getOutputs().get(0).getPipelineID());
            } else if (lop.getInputs().size() > 0) {
                lop.setPipelineID(lop.getInputs().get(0).getPipelineID());
            }
            if (lop.getOutputs().size() > 0 || lop.getInputs().size() > 0) {
                ((List)map.get(lop.getPipelineID())).add(lop);
                map.remove(e.getKey());
            }
        });
    }

    private static void mergeSmallPipelines(Map<Integer, List<Lop>> map) {
        if (map.size() < 2) {
            return;
        }
        List<Map.Entry<Integer, Integer>> sortedPipelineSizes = LinearizerPipelineAware.getPipelinesSortedBySize(map);
        Map.Entry<Integer, Integer> sm0 = sortedPipelineSizes.get(0);
        Map.Entry<Integer, Integer> sm1 = sortedPipelineSizes.get(1);
        while (sm0 != null && sm1 != null && (sm0.getValue() < 4 || sm0.getValue() + sm1.getValue() < 10)) {
            int mergeIntoId = sm1.getKey();
            map.get(sm0.getKey()).forEach(l -> l.setPipelineID(mergeIntoId));
            map.get(mergeIntoId).addAll((Collection<Lop>)map.get(sm0.getKey()));
            map.remove(sm0.getKey());
            sortedPipelineSizes = LinearizerPipelineAware.getPipelinesSortedBySize(map);
            if (sortedPipelineSizes.size() < 2) {
                sm0 = null;
                sm1 = null;
                continue;
            }
            sm0 = sortedPipelineSizes.get(0);
            sm1 = sortedPipelineSizes.get(1);
        }
    }

    private static List<Map.Entry<Integer, Integer>> getPipelinesSortedBySize(Map<Integer, List<Lop>> map) {
        return map.entrySet().stream().sorted(Map.Entry.comparingByValue(Comparator.comparingInt(List::size))).map(e -> Map.entry((Integer)e.getKey(), ((List)e.getValue()).size())).collect(Collectors.toList());
    }
}

