/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.evomodel.continuous.hmc.IntegratedLoadingsGradient;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.ReadableVector;
import dr.util.TaskPool;

public class IntegratedLoadingsAndPrecisionGradient
extends IntegratedLoadingsGradient {
    CompoundParameter jointParameter;

    public IntegratedLoadingsAndPrecisionGradient(CompoundParameter compoundParameter, TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, IntegratedFactorAnalysisLikelihood integratedFactorAnalysisLikelihood, ContinuousTraitPartialsProvider continuousTraitPartialsProvider, TaskPool taskPool, IntegratedLoadingsGradient.ThreadUseProvider threadUseProvider, IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider) {
        super(treeDataLikelihood, continuousDataLikelihoodDelegate, integratedFactorAnalysisLikelihood, continuousTraitPartialsProvider, taskPool, threadUseProvider, remainderCompProvider);
        this.jointParameter = compoundParameter;
    }

    @Override
    public Parameter getParameter() {
        return this.jointParameter;
    }

    @Override
    protected int getGradientDimension() {
        return this.dimFactors * this.dimTrait + this.dimTrait;
    }

    @Override
    public int getDimension() {
        return this.getGradientDimension();
    }

    private void computePrecisionGradientForOneTaxon(int n, int n2, IntegratedLoadingsGradient.GradientComponents gradientComponents, double[] dArray, double[] dArray2, double[][] dArray3, int n3) {
        double[] dArray4 = gradientComponents.fty;
        double[] dArray5 = gradientComponents.ftfl;
        for (int i = 0; i < this.dimTrait; ++i) {
            int n4 = n2 * this.dimTrait + i;
            if (this.factorAnalysisLikelihood.getDataMissingIndicators()[n4]) continue;
            double d = this.data[n4];
            double[] dArray6 = dArray3[n];
            int n5 = n3 + i;
            dArray6[n5] = dArray6[n5] + 0.5 * (1.0 / dArray2[i] - d * d);
            for (int j = 0; j < this.dimFactors; ++j) {
                int n6 = i * this.dimFactors + j;
                int n7 = j * this.dimTrait + i;
                double[] dArray7 = dArray3[n];
                int n8 = n3 + i;
                dArray7[n8] = dArray7[n8] + (dArray4[n7] - 0.5 * dArray5[n7]) * dArray[n6];
            }
        }
    }

    @Override
    protected void computeGradientForOneTaxon(int n, int n2, ReadableMatrix readableMatrix, double[] dArray, ReadableVector readableVector, double[] dArray2, WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics, double[][] dArray3) {
        IntegratedLoadingsGradient.MeanAndMoment meanAndMoment = this.getMeanAndMoment(n2, wrappedNormalSufficientStatistics);
        IntegratedLoadingsGradient.GradientComponents gradientComponents = this.computeGradientComponents(n2, dArray, meanAndMoment);
        this.computeLoadingsGradientForOneTaxon(n, gradientComponents, dArray2, dArray3);
        this.computePrecisionGradientForOneTaxon(n, n2, gradientComponents, dArray, dArray2, dArray3, this.dimFactors * this.dimTrait);
    }
}

