/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.LatentFactorModelInterface;
import dr.inference.distribution.MomentDistributionModel;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;

/**
 * Gibbs operator that redraws loadings of a latent factor model from a normal conditional.
 *
 * <p>When the prior is a {@link MomentDistributionModel}, each loading is drawn from the normal
 * conditional restricted to lie outside {@code (-sqrt(cutoff), sqrt(cutoff))}. The cutoff is read
 * from the model's cutoff matrix. Otherwise the loading is drawn directly from the normal
 * conditional.
 */
public class LoadingsGibbsTruncatedOperator
extends SimpleMCMCOperator
implements PathDependent,
GibbsOperator {
    /** Maximum number of rejection-sampling attempts per truncated draw. */
    private static final int MAX_TRUNCATED_ATTEMPTS = 10;

    Likelihood prior;
    LatentFactorModelInterface LFM;
    double[][] precisionArray;
    double[] meanMidArray;
    double[] meanArray;
    boolean randomScan;
    double pathParameter = 1.0;
    final Parameter missingIndicator;
    double priorPrecision;
    double priorMeanPrecision;
    MatrixParameterInterface loadings;
    DistributionLikelihood cutoffPrior;

    /**
     * Creates the operator.
     *
     * @param latentFactorModelInterface the factor model supplying data, factors and precisions
     * @param likelihood prior on the loadings; its precision and precision-weighted mean are
     *     extracted when it is a {@link MomentDistributionModel} or a {@link DistributionLikelihood}
     *     (otherwise both stay 0)
     * @param d operator weight
     * @param bl stored nowhere by this constructor (kept for signature compatibility)
     * @param matrixParameterInterface the loadings matrix that is updated
     * @param distributionLikelihood prior on the cutoff, used only by {@link #getCutoffDraw}
     */
    public LoadingsGibbsTruncatedOperator(LatentFactorModelInterface latentFactorModelInterface, Likelihood likelihood, double d, boolean bl, MatrixParameterInterface matrixParameterInterface, DistributionLikelihood distributionLikelihood) {
        setWeight(d);
        loadings = matrixParameterInterface;
        prior = likelihood;
        LFM = latentFactorModelInterface;
        if (likelihood instanceof MomentDistributionModel) {
            MomentDistributionModel moment = (MomentDistributionModel) likelihood;
            priorPrecision = moment.getScaleMatrix()[0][0];
            priorMeanPrecision = moment.getMean()[0] * priorPrecision;
        } else if (likelihood instanceof DistributionLikelihood) {
            DistributionLikelihood distribution = (DistributionLikelihood) likelihood;
            priorPrecision = 1.0 / distribution.getDistribution().variance();
            priorMeanPrecision = distribution.getDistribution().mean() * priorPrecision;
        }
        cutoffPrior = distributionLikelihood;
        missingIndicator = latentFactorModelInterface.getMissingIndicator();
    }

    /** Maps an index of a vector with entry {@code skipped} removed back to the full vector. */
    private static int skipIndex(int reducedIndex, int skipped) {
        return reducedIndex < skipped ? reducedIndex : reducedIndex + 1;
    }

    /**
     * Returns true when observation ({@code row}, {@code taxon}) is flagged missing.
     *
     * <p>The indicator is addressed as {@code taxon * dataRows + row} and a value of exactly 1.0
     * means the observation is skipped.
     */
    private boolean isMissing(int taxon, int row, int dataRows) {
        return missingIndicator != null
                && missingIndicator.getParameterValue(taxon * dataRows + row) == 1.0;
    }

    /**
     * Fills {@code precision} with {@code pathParameter * colPrec(row) * F F^T + priorPrecision * I}.
     *
     * <p>Here F is {@code factors} restricted to columns whose observation in {@code row} is not
     * missing.
     */
    private void getPrecisionOfTruncated(MatrixParameterInterface factors, int size, int row, double[][] precision) {
        int taxa = factors.getColumnDimension();
        int dataRows = missingIndicator == null ? 0 : LFM.getScaledData().getRowDimension();
        double columnPrecision = LFM.getColumnPrecision().getParameterValue(row, row);
        for (int i = 0; i < size; ++i) {
            for (int j = i; j < size; ++j) {
                double sum = 0.0;
                for (int t = 0; t < taxa; ++t) {
                    if (isMissing(t, row, dataRows)) {
                        continue;
                    }
                    sum += factors.getParameterValue(i, t) * factors.getParameterValue(j, t);
                }
                double value = sum * columnPrecision * pathParameter;
                if (i == j) {
                    // Only the diagonal carries the (independent) prior precision.
                    precision[i][i] = value + priorPrecision;
                } else {
                    precision[i][j] = value;
                    precision[j][i] = value;
                }
            }
        }
    }

    /**
     * Computes the unscaled conditional mean for one row of the loadings.
     *
     * <p>First {@code midMean[f] = colPrec(row) * sum_t F[f][t] * Y[row][t] + priorMeanPrecision}
     * is accumulated over non-missing observations. Then {@code mean = variance * midMean}.
     */
    private void getTruncatedMean(int size, int row, double[][] variance, double[] midMean, double[] mean) {
        MatrixParameterInterface data = LFM.getScaledData();
        MatrixParameterInterface factors = LFM.getFactors();
        int taxa = data.getColumnDimension();
        int dataRows = data.getRowDimension();
        double columnPrecision = LFM.getColumnPrecision().getParameterValue(row, row);
        for (int f = 0; f < size; ++f) {
            double sum = 0.0;
            for (int t = 0; t < taxa; ++t) {
                if (isMissing(t, row, dataRows)) {
                    continue;
                }
                sum += factors.getParameterValue(f, t) * data.getParameterValue(row, t);
            }
            midMean[f] = sum * columnPrecision + priorMeanPrecision;
        }
        for (int f = 0; f < size; ++f) {
            double sum = 0.0;
            for (int g = 0; g < size; ++g) {
                sum += variance[f][g] * midMean[g];
            }
            mean[f] = sum;
        }
    }

    /** Fills {@code precision} for loadings row {@code row} (size = number of loading columns). */
    private void getPrecision(int row, double[][] precision) {
        getPrecisionOfTruncated(LFM.getFactors(), loadings.getColumnDimension(), row, precision);
    }

    /** Fills {@code mean} for loadings row {@code row}, then scales every entry by pathParameter. */
    private void getMean(int row, double[][] variance, double[] midMean, double[] mean) {
        getTruncatedMean(loadings.getColumnDimension(), row, variance, midMean, mean);
        for (int f = 0; f < mean.length; ++f) {
            mean[f] *= pathParameter;
        }
    }

    /** Writes {@code values} into row {@code row} of the loadings without firing change events. */
    private void copy(int row, double[] values) {
        for (int f = 0; f < values.length; ++f) {
            loadings.setParameterValueQuietly(row, f, values[f]);
        }
    }

    /**
     * Draws (or re-evaluates) loading ({@code row}, {@code column}) outside the cutoff interval.
     *
     * <p>With {@code draw} set, up to {@value #MAX_TRUNCATED_ATTEMPTS} inverse-CDF proposals are
     * made from the two tails of {@code conditional} outside {@code (-sqrt(cutoff), sqrt(cutoff))}.
     * The first one landing outside that open interval is stored. Otherwise the current loading is
     * only evaluated.
     *
     * @return log density of the value under the tail-renormalised conditional, or negative
     *     infinity when the value or the normalising log mass is NaN
     */
    private double getTruncatedDraw(int row, int column, NormalDistribution conditional, boolean draw) {
        MatrixParameterInterface cutoff = (MatrixParameterInterface) ((MomentDistributionModel) prior).getCutoff();
        double lower = -Math.sqrt(cutoff.getParameterValue(row, column));
        double upper = -lower;
        double lowerTailMass = conditional.cdf(lower);
        double upperCdf = conditional.cdf(upper);
        double lowerTailProbability = lowerTailMass / (lowerTailMass + (1.0 - upperCdf));
        double value = 0.0;
        if (draw) {
            for (int attempt = 0;
                    ((value < upper && value > lower) || Double.isNaN(value)) && attempt < MAX_TRUNCATED_ATTEMPTS;
                    ++attempt) {
                if (MathUtils.nextDouble() < lowerTailProbability) {
                    value = conditional.quantile(MathUtils.nextDouble() * lowerTailMass);
                } else {
                    value = conditional.quantile(MathUtils.nextDouble() * (1.0 - upperCdf) + upperCdf);
                }
            }
            // Decide acceptance from the value itself. The previous attempt-count test discarded a
            // valid value that was found on the final attempt.
            boolean accepted = !Double.isNaN(value) && !(value < upper && value > lower);
            if (accepted) {
                loadings.setParameterValue(row, column, value);
            }
        } else {
            value = loadings.getParameterValue(row, column);
        }
        double logTailMass = Math.log(1.0 - (upperCdf - lowerTailMass));
        if (Double.isNaN(value) || Double.isNaN(logTailMass)) {
            return Double.NEGATIVE_INFINITY;
        }
        return conditional.logPdf(value) - logTailMass;
    }

    /**
     * Updates loading ({@code row}, {@code column}) from its normal conditional.
     *
     * <p>A non-zero current loading uses the data-informed conditional. A loading that is exactly
     * 0 uses {@code N(0, 1 / priorPrecision)}.
     *
     * @param row loadings row
     * @param column loadings column (factor index)
     * @param draw whether to draw a new value (only consulted for the truncated prior)
     * @return the value from {@link #getTruncatedDraw} for a {@link MomentDistributionModel} prior,
     *     0 otherwise
     */
    public double drawI(int row, int column, boolean draw) {
        int factorDimension = loadings.getColumnDimension();
        precisionArray = new double[factorDimension][factorDimension];
        meanMidArray = new double[factorDimension];
        meanArray = new double[factorDimension];
        getPrecision(row, precisionArray);

        NormalDistribution conditional;
        if (LFM.getLoadings().getParameterValue(row, column) != 0.0) {
            double[][] covariance = new SymmetricMatrix(precisionArray).inverse().toComponents();
            getMean(row, covariance, meanMidArray, meanArray);
            if (LFM.getFactorDimension() != 1) {
                conditional = getConditionalDistribution(meanArray, covariance, column, row);
            } else {
                conditional = new NormalDistribution(meanArray[0], Math.sqrt(covariance[0][0]));
            }
        } else {
            conditional = new NormalDistribution(0.0, Math.sqrt(1.0 / priorPrecision));
        }

        double hastings = 0.0;
        if (prior instanceof MomentDistributionModel) {
            // Both branches of the former `nextDouble() < 0.5` coin flip made this identical call.
            // The uniform is still consumed so seeded runs use the random stream exactly as before.
            MathUtils.nextDouble();
            hastings = getTruncatedDraw(row, column, conditional, draw);
        } else {
            loadings.setParameterValue(row, column, conditional.quantile(MathUtils.nextDouble()));
        }
        return hastings;
    }

    /**
     * Conditions the joint normal {@code N(mean, covariance)} of one loadings row on all entries
     * except {@code column}.
     *
     * <p>Uses the current loadings as the conditioning values and the standard Gaussian identities
     * {@code mu_c + S_{c,-c} S_{-c,-c}^{-1} (x_{-c} - mu_{-c})} and
     * {@code S_cc - S_{c,-c} S_{-c,-c}^{-1} S_{-c,c}}.
     */
    private NormalDistribution getConditionalDistribution(double[] mean, double[][] covariance, int column, int row) {
        int reduced = mean.length - 1;

        // S_{-c,-c}: the covariance with row and column `column` removed.
        double[][] reducedCovariance = new double[reduced][reduced];
        for (int r = 0; r < reduced; ++r) {
            int fullR = skipIndex(r, column);
            for (int s = 0; s < reduced; ++s) {
                reducedCovariance[r][s] = covariance[fullR][skipIndex(s, column)];
            }
        }
        double[][] reducedInverse = new SymmetricMatrix(reducedCovariance).inverse().toComponents();

        // x_{-c} - mu_{-c} and S_{-c,c}, both indexed in the reduced space.
        double[] residual = new double[reduced];
        double[] crossCovariance = new double[reduced];
        for (int r = 0; r < reduced; ++r) {
            int fullR = skipIndex(r, column);
            residual[r] = LFM.getLoadings().getParameterValue(row, fullR) - mean[fullR];
            crossCovariance[r] = covariance[fullR][column];
        }

        // S_{-c,-c}^{-1} (x - mu) and S_{-c,-c}^{-1} S_{-c,c}. Every reduced index is mapped via
        // skipIndex on its own value; the old code mapped the inner index using the outer one.
        double[] weightedResidual = new double[reduced];
        double[] regression = new double[reduced];
        for (int r = 0; r < reduced; ++r) {
            for (int s = 0; s < reduced; ++s) {
                weightedResidual[r] += reducedInverse[r][s] * residual[s];
                regression[r] += reducedInverse[r][s] * crossCovariance[s];
            }
        }

        double conditionalMean = mean[column];
        for (int r = 0; r < reduced; ++r) {
            conditionalMean += weightedResidual[r] * crossCovariance[r];
        }
        double conditionalVariance = covariance[column][column];
        for (int r = 0; r < reduced; ++r) {
            conditionalVariance -= regression[r] * crossCovariance[r];
        }
        return new NormalDistribution(conditionalMean, Math.sqrt(conditionalVariance));
    }

    /**
     * Metropolis-style update of the cutoff entry ({@code n}, {@code n2}).
     *
     * <p>A proposal {@code U * |loading|} is drawn. The move is accepted with probability
     * {@code ratio(proposal) / ratio(current)}, where
     * {@code ratio(x) = cutoffPrior.pdf(x^2) / (1 - (cdf(x) - cdf(-x)))}. The squared proposal
     * becomes the new cutoff.
     */
    void getCutoffDraw(int n, int n2, NormalDistribution normalDistribution) {
        MatrixParameterInterface cutoff = (MatrixParameterInterface) ((MomentDistributionModel) prior).getCutoff();
        double magnitude = Math.abs(loadings.getParameterValue(n, n2));
        double proposal = MathUtils.nextDouble() * magnitude;
        double current = Math.sqrt(cutoff.getParameterValue(n, n2));
        double proposalRatio = cutoffPrior.getDistribution().pdf(Math.pow(proposal, 2.0))
                / (1.0 - (normalDistribution.cdf(proposal) - normalDistribution.cdf(-proposal)));
        double currentRatio = cutoffPrior.getDistribution().pdf(Math.pow(current, 2.0))
                / (1.0 - (normalDistribution.cdf(current) - normalDistribution.cdf(-current)));
        if (MathUtils.nextDouble() < proposalRatio / currentRatio) {
            cutoff.setParameterValue(n, n2, Math.pow(proposal, 2.0));
        }
    }

    public int getStepCount() {
        return 0;
    }

    @Override
    public String getOperatorName() {
        return "loadingsGibbsTruncatedOperator";
    }

    /**
     * Picks one loadings column uniformly at random and redraws every row in it.
     *
     * <p>Fires a single change event at the end.
     *
     * @return 0 (Gibbs move)
     */
    @Override
    public double doOperation() {
        int rows = LFM.getLoadings().getRowDimension();
        int column = MathUtils.nextInt(LFM.getLoadings().getColumnDimension());
        for (int row = 0; row < rows; ++row) {
            drawI(row, column, true);
        }
        loadings.fireParameterChangedEvent();
        return 0.0;
    }

    @Override
    public void setPathParameter(double d) {
        pathParameter = d;
    }
}

