/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.tree.treemetrics;

import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.tree.treemetrics.TreeMetric;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
 * Tree metric of type {@link TreeMetric.Type#KENDALL_COLIJN}.
 *
 * <p>Each tree is summarised by two vectors indexed by pairs of tips:
 * <ul>
 *   <li><b>small m</b>: the number of edges between the root and the most recent
 *       common ancestor (MRCA) of the pair; for a tip paired with itself the value 1;</li>
 *   <li><b>large M</b>: the summed branch length between the root and the MRCA of the
 *       pair; for a tip paired with itself the length of the tip's own branch.</li>
 * </ul>
 * The two vectors are blended as {@code (1 - lambda) * m + lambda * M} and the metric is
 * the Euclidean distance between the blended vectors of the two trees.
 *
 * <p>The vectors of the first ("focal") tree are cached between calls. The cache is keyed
 * on reference identity of the {@link Tree} object, so a tree instance that is modified in
 * place between calls is not re-traversed. Instances hold mutable cache state and make no
 * attempt at synchronisation.
 *
 * <p>The traversal only visits children 0 and 1 of every internal node and indexes tips by
 * {@code NodeRef.getNumber()}; it therefore assumes strictly bifurcating trees whose
 * external nodes are numbered {@code 0 .. tipCount - 1} (NOTE(review): confirm against
 * the {@code Tree} implementations in use).
 */
public class KendallColijnPathDifferenceMetric
implements TreeMetric {
    public static TreeMetric.Type TYPE = TreeMetric.Type.KENDALL_COLIJN;

    /** Tree whose vectors are currently cached; {@code null} until one has been traversed. */
    private Tree focalTree;
    /** Length of the cached vectors: the square of the focal tree's external node count. */
    private int dim;
    /** Cached edge-count vector of the focal tree, laid out as {@code i * tipCount + j} for {@code i <= j}. */
    private double[] focalSmallM;
    /** Cached path-length vector of the focal tree, same layout as {@link #focalSmallM}. */
    private double[] focalLargeM;
    /** When {@code true} the focal tree was given at construction and may not be replaced. */
    private final boolean fixedFocalTree;
    /** Weight of the branch-length vector; {@code 1 - lambda} is the weight of the edge-count vector. */
    private final double lambda;

    /**
     * Creates a metric without a fixed focal tree; the first argument of every
     * {@link #getMetric(Tree, Tree)} call becomes the (cached) focal tree.
     *
     * @param d the lambda weight used to blend the edge-count and branch-length vectors
     */
    public KendallColijnPathDifferenceMetric(double d) {
        this.lambda = d;
        this.fixedFocalTree = false;
    }

    /**
     * Creates a metric bound to a fixed focal tree whose vectors are computed once, here.
     *
     * @param d    the lambda weight used to blend the edge-count and branch-length vectors
     * @param tree the focal tree; {@link #getMetric(Tree, Tree)} must always be called with
     *             this same object as its first argument
     */
    public KendallColijnPathDifferenceMetric(double d, Tree tree) {
        this.lambda = d;
        this.focalTree = tree;
        this.fixedFocalTree = true;
        this.dim = tree.getExternalNodeCount() * tree.getExternalNodeCount();
        this.focalSmallM = new double[this.dim];
        this.focalLargeM = new double[this.dim];
        this.traverse(tree, tree.getRoot(), 0.0, 0, this.focalLargeM, this.focalSmallM);
    }

    /**
     * Computes the distance between two trees for this instance's lambda.
     *
     * @param tree  the focal tree; its vectors are cached and reused while the same object
     *              is passed again
     * @param tree2 the tree compared against the focal tree
     * @return the Euclidean distance between the lambda-blended vectors of both trees
     * @throws RuntimeException if this instance was constructed with a focal tree and
     *                          {@code tree} is a different object
     */
    @Override
    public double getMetric(Tree tree, Tree tree2) {
        TreeMetric.Utils.checkTreeTaxa(tree, tree2);

        if (tree != this.focalTree) {
            if (this.fixedFocalTree) {
                throw new RuntimeException("Focal tree is different from that set in the constructor.");
            }
            // Invalidate the cache while it is rebuilt so that a failure during the
            // traversal cannot leave half-written vectors associated with a tree.
            this.focalTree = null;
            int tipCount = tree.getExternalNodeCount();
            int requiredDim = tipCount * tipCount;
            // Re-allocate not only on first use but whenever the new focal tree has a
            // different number of tips; otherwise dim and the buffers would be stale.
            if (this.focalSmallM == null || this.dim != requiredDim) {
                this.dim = requiredDim;
                this.focalSmallM = new double[requiredDim];
                this.focalLargeM = new double[requiredDim];
            }
            this.traverse(tree, tree.getRoot(), 0.0, 0, this.focalLargeM, this.focalSmallM);
            this.focalTree = tree;
        }

        double[] otherSmallM = new double[this.dim];
        double[] otherLargeM = new double[this.dim];
        this.traverse(tree2, tree2.getRoot(), 0.0, 0, otherLargeM, otherSmallM);

        return this.calculateMetric(this.focalSmallM, this.focalLargeM, otherSmallM, otherLargeM,
                tree.getExternalNodeCount(), this.lambda);
    }

    /**
     * Euclidean distance between the blended vectors of two trees.
     *
     * <p>Only entries with {@code row <= column} are read, which are exactly the entries
     * written by {@link #traverse}.
     *
     * @param smallOne edge-count vector of the first tree
     * @param largeOne path-length vector of the first tree
     * @param smallTwo edge-count vector of the second tree
     * @param largeTwo path-length vector of the second tree
     * @param tipCount number of tips, i.e. the row length of the vectors
     * @param weight   lambda; weight of the path-length vectors
     * @return square root of the summed squared differences
     */
    private double calculateMetric(double[] smallOne, double[] largeOne, double[] smallTwo, double[] largeTwo, int tipCount, double weight) {
        double sumOfSquares = 0.0;
        for (int row = 0; row < tipCount; ++row) {
            for (int column = row; column < tipCount; ++column) {
                int index = row * tipCount + column;
                double blendedOne = (1.0 - weight) * smallOne[index] + weight * largeOne[index];
                double blendedTwo = (1.0 - weight) * smallTwo[index] + weight * largeTwo[index];
                double difference = blendedOne - blendedTwo;
                sumOfSquares += difference * difference;
            }
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Recursively fills the pair vectors for the subtree below {@code node}.
     *
     * <p>Every tip below child 0 paired with every tip below child 1 has {@code node} as
     * its MRCA, so those pairs receive the root-to-{@code node} values passed in.
     *
     * @param tree       the tree being traversed
     * @param node       an internal node; only its children 0 and 1 are visited
     * @param pathLength summed branch length from the root down to {@code node}
     * @param edgeCount  number of edges from the root down to {@code node}
     * @param largeM     output vector receiving path lengths
     * @param smallM     output vector receiving edge counts
     * @return all tips found below {@code node}
     */
    private Set<NodeRef> traverse(Tree tree, NodeRef node, double pathLength, int edgeCount, double[] largeM, double[] smallM) {
        NodeRef firstChild = tree.getChild(node, 0);
        NodeRef secondChild = tree.getChild(node, 1);
        Set<NodeRef> firstTips = this.tipsBelow(tree, firstChild, pathLength, edgeCount, largeM, smallM);
        Set<NodeRef> secondTips = this.tipsBelow(tree, secondChild, pathLength, edgeCount, largeM, smallM);

        int tipCount = tree.getExternalNodeCount();
        for (NodeRef first : firstTips) {
            int firstNumber = first.getNumber();
            for (NodeRef second : secondTips) {
                int secondNumber = second.getNumber();
                // Always store under (smaller, larger) so only the upper triangle is used.
                int index = Math.min(firstNumber, secondNumber) * tipCount + Math.max(firstNumber, secondNumber);
                largeM[index] = pathLength;
                smallM[index] = edgeCount;
            }
        }

        Set<NodeRef> allTips = new HashSet<NodeRef>(firstTips);
        allTips.addAll(secondTips);
        return allTips;
    }

    /**
     * Handles one child of a node visited by {@link #traverse}: recurses into an internal
     * child, or records the diagonal (tip paired with itself) entries for an external one.
     *
     * @param parentPathLength summed branch length from the root to the child's parent
     * @param parentEdgeCount  number of edges from the root to the child's parent
     * @return the tips at or below {@code child}
     */
    private Set<NodeRef> tipsBelow(Tree tree, NodeRef child, double parentPathLength, int parentEdgeCount, double[] largeM, double[] smallM) {
        if (!tree.isExternal(child)) {
            return this.traverse(tree, child, parentPathLength + tree.getBranchLength(child), parentEdgeCount + 1, largeM, smallM);
        }
        int diagonal = child.getNumber() * tree.getExternalNodeCount() + child.getNumber();
        largeM[diagonal] = tree.getBranchLength(child);
        smallM[diagonal] = 1.0;
        return Collections.singleton(child);
    }

    /**
     * Legacy pairwise-MRCA implementation comparing the stored focal tree with {@code tree}.
     *
     * <p>The cached focal vectors use the layout of {@link #traverse}, which differs from
     * the sequential layout of the legacy algorithm, so the focal tree is re-processed by
     * delegating to {@link #getMetric_old(Tree, Tree, ArrayList)}.
     *
     * @param tree      the tree compared against the stored focal tree
     * @param arrayList lambda values to evaluate
     * @return one distance per lambda value, in the same order
     * @throws IllegalStateException if no focal tree has been set yet
     * @throws RuntimeException      if the trees differ in tip count or tip taxa
     * @deprecated use {@link #getMetric(Tree, Tree)}
     */
    @Deprecated
    public ArrayList<Double> getMetric_old(Tree tree, ArrayList<Double> arrayList) {
        if (this.focalTree == null) {
            throw new IllegalStateException("No focal tree has been set.");
        }
        return this.getMetric_old(this.focalTree, tree, arrayList);
    }

    /**
     * Legacy pairwise-MRCA implementation: for every pair of tips the MRCA is looked up and
     * walked to the root to count edges and sum branch lengths (node-height differences).
     *
     * @param tree      the first tree
     * @param tree2     the second tree
     * @param arrayList lambda values to evaluate
     * @return one distance per lambda value, in the same order
     * @throws RuntimeException if the trees differ in tip count, or if the taxon ids of the
     *                          external nodes at the same index differ
     * @deprecated use {@link #getMetric(Tree, Tree)}
     */
    @Deprecated
    public ArrayList<Double> getMetric_old(Tree tree, Tree tree2, ArrayList<Double> arrayList) {
        int tipCount = tree.getExternalNodeCount();
        if (tipCount != tree2.getExternalNodeCount()) {
            throw new RuntimeException("Different number of taxa in both trees.");
        }
        for (int i = 0; i < tipCount; ++i) {
            if (tree.getNodeTaxon(tree.getExternalNode(i)).getId().equals(tree2.getNodeTaxon(tree2.getExternalNode(i)).getId())) {
                continue;
            }
            throw new RuntimeException("Mismatch between taxa in both trees: " + tree.getNodeTaxon(tree.getExternalNode(i)).getId() + " vs. " + tree2.getNodeTaxon(tree2.getExternalNode(i)).getId());
        }

        // All unordered tip pairs followed by one entry per tip, packed without gaps so
        // that tip entries can never overwrite pair entries (which happened for < 4 tips).
        int length = tipCount * (tipCount - 1) / 2 + tipCount;
        double[] smallOne = new double[length];
        double[] largeOne = new double[length];
        double[] smallTwo = new double[length];
        double[] largeTwo = new double[length];
        fillLegacyVectors(tree, smallOne, largeOne);
        fillLegacyVectors(tree2, smallTwo, largeTwo);

        ArrayList<Double> distances = new ArrayList<Double>(arrayList.size());
        for (Double boxedLambda : arrayList) {
            double weight = boxedLambda;
            double sumOfSquares = 0.0;
            for (int i = 0; i < length; ++i) {
                double blendedOne = (1.0 - weight) * smallOne[i] + weight * largeOne[i];
                double blendedTwo = (1.0 - weight) * smallTwo[i] + weight * largeTwo[i];
                double difference = blendedOne - blendedTwo;
                sumOfSquares += difference * difference;
            }
            distances.add(Math.sqrt(sumOfSquares));
        }
        return distances;
    }

    /**
     * Fills the legacy vectors of one tree: first every tip pair {@code (i, j)} with
     * {@code i < j} in row-major order, then one entry per tip.
     *
     * @param smallM receives edge counts (1 for the per-tip entries)
     * @param largeM receives root-to-MRCA lengths (pendant edge length for the per-tip entries)
     */
    private static void fillLegacyVectors(Tree tree, double[] smallM, double[] largeM) {
        int tipCount = tree.getExternalNodeCount();
        NodeRef root = tree.getRoot();
        int index = 0;

        for (int i = 0; i < tipCount; ++i) {
            for (int j = i + 1; j < tipCount; ++j) {
                NodeRef ancestor = TreeUtils.getCommonAncestor(tree, tree.getExternalNode(i), tree.getExternalNode(j));
                int edgeCount = 0;
                double pathLength = 0.0;
                // Walk from the MRCA up to the root, accumulating edges and their lengths.
                while (ancestor != root) {
                    NodeRef parent = tree.getParent(ancestor);
                    ++edgeCount;
                    pathLength += tree.getNodeHeight(parent) - tree.getNodeHeight(ancestor);
                    ancestor = parent;
                }
                smallM[index] = edgeCount;
                largeM[index] = pathLength;
                ++index;
            }
        }

        for (int tip = 0; tip < tipCount; ++tip) {
            NodeRef node = tree.getExternalNode(tip);
            smallM[index] = 1.0;
            largeM[index] = tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node);
            ++index;
        }
    }

    /**
     * Small demonstration: prints the metric for a 4-taxon and a 5-taxon pair of trees at
     * lambda 0.0, 0.5 and 1.0, with and without a fixed focal tree, then times one million
     * evaluations.
     *
     * @param stringArray ignored
     */
    public static void main(String[] stringArray) {
        try {
            Tree tree = new NewickImporter("(('A':1.2,'B':0.8):0.5,('C':0.8,'D':1.0):1.1)").importNextTree();
            System.out.println("4-taxa tree 1: " + tree);
            Tree tree2 = new NewickImporter("((('A':0.8,'B':1.4):0.3,'C':0.7):0.9,'D':1.0)").importNextTree();
            System.out.println("4-taxa tree 2: " + tree2);
            System.out.println();
            printDemo("Paired trees:", tree, tree2, false);
            printDemo("Focal trees:", tree, tree2, true);
            System.out.println();

            tree = new NewickImporter("(((('A':0.6,'B':0.6):0.1,'C':0.5):0.4,'D':0.7):0.1,'E':1.3)").importNextTree();
            System.out.println("5-taxa tree 1: " + tree);
            tree2 = new NewickImporter("((('A':0.8,'B':1.4):0.1,'C':0.7):0.2,('D':1.0,'E':0.9):1.3)").importNextTree();
            System.out.println("5-taxa tree 2: " + tree2);
            System.out.println();
            // "Paired" uses the constructor without a focal tree, as in the 4-taxon section.
            printDemo("Paired trees:", tree, tree2, false);
            printDemo("Focal trees:", tree, tree2, true);

            long start = System.currentTimeMillis();
            for (int i = 0; i < 1000000; ++i) {
                new KendallColijnPathDifferenceMetric(0.5).getMetric(tree, tree2);
            }
            System.out.println("New algorithm, 1M reps: " + (System.currentTimeMillis() - start) + " ms");
        }
        catch (Importer.ImportException importException) {
            System.err.println(importException);
        }
        catch (IOException iOException) {
            System.err.println(iOException);
        }
    }

    /**
     * Prints a titled block with the metric at lambda 0.0, 0.5 and 1.0, followed by a blank line.
     *
     * @param fixFocalTree whether to construct the metric with {@code first} as fixed focal tree
     */
    private static void printDemo(String title, Tree first, Tree second, boolean fixFocalTree) {
        double[] lambdas = new double[]{0.0, 0.5, 1.0};
        double[] distances = new double[lambdas.length];
        for (int i = 0; i < lambdas.length; ++i) {
            KendallColijnPathDifferenceMetric metric = fixFocalTree
                    ? new KendallColijnPathDifferenceMetric(lambdas[i], first)
                    : new KendallColijnPathDifferenceMetric(lambdas[i]);
            distances[i] = metric.getMetric(first, second);
        }
        System.out.println(title);
        for (int i = 0; i < lambdas.length; ++i) {
            System.out.println("lambda (" + lambdas[i] + ") = " + distances[i]);
        }
        System.out.println();
    }

    /**
     * @return {@link #TYPE}
     */
    @Override
    public TreeMetric.Type getType() {
        return TYPE;
    }

    /**
     * @return the short name of the metric type followed by the lambda value in parentheses
     */
    @Override
    public String toString() {
        return this.getType().getShortName() + "(" + this.lambda + ")";
    }
}

