/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.pnn;

import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.pnn.CalculationCriteria;
import org.encog.neural.networks.training.pnn.DeriveMinimum;
import org.encog.neural.networks.training.pnn.GlobalMinimumSearch;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.pnn.BasicPNN;
import org.encog.neural.pnn.PNNKernelType;
import org.encog.neural.pnn.PNNOutputMode;
import org.encog.util.EngineArray;

/**
 * One-pass trainer for a {@link BasicPNN}: chooses a kernel width (sigma) for every
 * input so as to minimise the network's leave-one-out error on its own sample set.
 * <p>
 * The first call to {@link #iteration()} copies the training data into the network as
 * its sample set. For an untrained network a {@link GlobalMinimumSearch} over
 * {@code [sigmaLow, sigmaHigh]} picks one starting sigma that is used for every input;
 * an already-trained network starts from its current sigmas instead. A
 * {@link DeriveMinimum} pass then refines the per-input sigmas (this trainer is passed
 * to it as its {@link CalculationCriteria}, so it presumably evaluates the error via
 * {@link #calcErrorWithMultipleSigma}).
 */
public class TrainBasicPNN
extends BasicTraining
implements CalculationCriteria {
    /** Default for {@link #getMaxError()}. */
    public static final double DEFAULT_MAX_ERROR = 0.05;
    /** Default for {@link #getMinImprovement()}. */
    public static final double DEFAULT_MIN_IMPROVEMENT = 1.0E-4;
    /** Default lower bound of the initial sigma search range. */
    public static final double DEFAULT_SIGMA_LOW = 1.0E-4;
    /** Default upper bound of the initial sigma search range. */
    public static final double DEFAULT_SIGMA_HIGH = 10.0;
    /** Default number of sigma values handed to the initial sigma search. */
    public static final int DEFAULT_NUM_SIGMAS = 10;

    /** Floor applied to kernel values and to the classification probability sum. */
    private static final double TINY = 1.0E-40;

    /**
     * Scratch: first-derivative accumulators, one row of {@code inputCount} entries per
     * output, plus (outside classification) one extra row for the kernel sum.
     */
    private double[] v;
    /** Scratch: second-derivative accumulators, laid out like {@link #v}. */
    private double[] w;
    /** Scratch: squared, sigma-scaled distance to the current sample, per input. */
    private double[] dsqr;
    private final BasicPNN network;
    private final MLDataSet training;
    private double maxError;
    private double minImprovement;
    private double sigmaLow;
    private double sigmaHigh;
    private int numSigmas;
    /** Whether the training data has already been copied into the network's sample set. */
    private boolean samplesLoaded;

    /**
     * Creates a trainer with default search settings.
     *
     * @param network  the network whose sigmas are trained
     * @param training the data copied into the network as its sample set on the first iteration
     */
    public TrainBasicPNN(BasicPNN network, MLDataSet training) {
        super(TrainingImplementationType.OnePass);
        this.network = network;
        this.training = training;
        this.maxError = DEFAULT_MAX_ERROR;
        this.minImprovement = DEFAULT_MIN_IMPROVEMENT;
        this.sigmaLow = DEFAULT_SIGMA_LOW;
        this.sigmaHigh = DEFAULT_SIGMA_HIGH;
        this.numSigmas = DEFAULT_NUM_SIGMAS;
        this.samplesLoaded = false;
    }

    /**
     * Installs one sigma per input and returns the leave-one-out error.
     *
     * @param x    the sigma to use for each input
     * @param der1 receives the first derivative per input (only written when {@code der})
     * @param der2 receives the second derivative per input (only written when {@code der})
     * @param der  whether derivatives are required
     * @return the error, as computed by {@link #calculateError(MLDataSet, boolean)}
     */
    @Override
    public double calcErrorWithMultipleSigma(double[] x, double[] der1, double[] der2, boolean der) {
        final int inputCount = this.network.getInputCount();
        System.arraycopy(x, 0, this.network.getSigma(), 0, inputCount);
        if (!der) {
            return this.calculateError(this.network.getSamples(), false);
        }
        double err = this.calculateError(this.network.getSamples(), true);
        System.arraycopy(this.network.getDeriv(), 0, der1, 0, inputCount);
        System.arraycopy(this.network.getDeriv2(), 0, der2, 0, inputCount);
        return err;
    }

    /**
     * Installs the same sigma on every input and returns the leave-one-out error.
     *
     * @param sig the sigma shared by all inputs
     * @return the error, as computed by {@link #calculateError(MLDataSet, boolean)}
     */
    @Override
    public double calcErrorWithSingleSigma(double sig) {
        final double[] sigma = this.network.getSigma();
        for (int ivar = 0; ivar < this.network.getInputCount(); ++ivar) {
            sigma[ivar] = sig;
        }
        return this.calculateError(this.network.getSamples(), false);
    }

    /**
     * Computes the squared error of the network over {@code training}. Record {@code r}
     * is evaluated while sample {@code r} is excluded from the kernel sums
     * (leave-one-out). The error is averaged over records, and for unsupervised and
     * regression output also over outputs. The result is also stored via
     * {@code network.setError}.
     *
     * @param training the records to score. The exclusion indexes the network's sample
     *                 set, so this is meant to be {@code network.getSamples()}.
     * @param deriv    when true, also accumulate the first and second derivatives of
     *                 the error with respect to each sigma into the network's
     *                 {@code getDeriv()} / {@code getDeriv2()} arrays
     * @return the computed error
     */
    public double calculateError(MLDataSet training, boolean deriv) {
        final int inputCount = this.network.getInputCount();
        final int outputCount = this.network.getOutputCount();
        final PNNOutputMode mode = this.network.getOutputMode();

        if (deriv) {
            final double[] d1 = this.network.getDeriv();
            final double[] d2 = this.network.getDeriv2();
            int num = this.network.isSeparateClass() ? inputCount * outputCount : inputCount;
            for (int i = 0; i < num; ++i) {
                d1[i] = 0.0;
                d2[i] = 0.0;
            }
        }

        MLDataPair pair = BasicMLDataPair.createPair(training.getInputSize(), training.getIdealSize());
        double[] out = new double[outputCount];
        double totErr = 0.0;
        try {
            for (int r = 0; r < training.getRecordCount(); ++r) {
                training.getRecord(r, pair);
                // Leave-one-out: the case being scored must not contribute to its own
                // prediction. Index r matches the convention computeDeriv uses when it
                // skips sample r. The previous decrement-from-N scheme excluded N-1-r.
                this.network.setExclude(r);
                MLData input = pair.getInput();
                MLData target = pair.getIdeal();
                double err = 0.0;
                if (mode == PNNOutputMode.Unsupervised) {
                    // The network reconstructs its input, so the input is also the
                    // target for the derivative (the record may carry no ideal at all).
                    MLData output = deriv ? this.computeDeriv(input, input) : this.network.compute(input);
                    for (int z = 0; z < outputCount; ++z) {
                        out[z] = output.getData(z);
                    }
                    for (int i = 0; i < outputCount; ++i) {
                        double diff = input.getData(i) - out[i];
                        err += diff * diff;
                    }
                } else if (mode == PNNOutputMode.Classification) {
                    int tclass = (int) target.getData(0);
                    MLData output = deriv ? this.computeDeriv(input, target) : this.network.compute(input);
                    EngineArray.arrayCopy(output.getData(), out);
                    // The ideal output is 1 for the true class and 0 for every other class.
                    for (int i = 0; i < out.length; ++i) {
                        double diff = (i == tclass) ? 1.0 - out[i] : out[i];
                        err += diff * diff;
                    }
                } else if (mode == PNNOutputMode.Regression) {
                    // Derivatives are only accumulated by computeDeriv, so it must be used
                    // here when they are requested.
                    MLData output = deriv ? this.computeDeriv(input, target) : this.network.compute(input);
                    for (int z = 0; z < outputCount; ++z) {
                        out[z] = output.getData(z);
                    }
                    for (int i = 0; i < outputCount; ++i) {
                        double diff = target.getData(i) - out[i];
                        err += diff * diff;
                    }
                }
                totErr += err;
            }
        } finally {
            // Always leave the network evaluating against the full sample set.
            this.network.setExclude(-1);
        }

        final double count = training.getRecordCount();
        this.network.setError(totErr / count);
        if (deriv) {
            final double[] d1 = this.network.getDeriv();
            final double[] d2 = this.network.getDeriv2();
            for (int i = 0; i < d1.length; ++i) {
                d1[i] /= count;
                d2[i] /= count;
            }
        }
        if (mode == PNNOutputMode.Unsupervised || mode == PNNOutputMode.Regression) {
            this.network.setError(this.network.getError() / (double) outputCount);
            if (deriv) {
                final double[] d1 = this.network.getDeriv();
                final double[] d2 = this.network.getDeriv2();
                for (int i = 0; i < inputCount; ++i) {
                    d1[i] /= (double) outputCount;
                    d2[i] /= (double) outputCount;
                }
            }
        }
        return this.network.getError();
    }

    /**
     * This trainer cannot be paused and resumed.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Computes the network output for {@code input}. At the same time it adds the first
     * and second derivatives of this case's squared error, with respect to each input's
     * sigma, into the network's {@code getDeriv()} / {@code getDeriv2()} arrays. The
     * sample whose index equals {@code network.getExclude()} is skipped.
     *
     * @param input  the case to evaluate
     * @param target what the output is compared against: the class index in element 0
     *               for classification, otherwise one value per output
     * @return the normalised network output
     */
    public MLData computeDeriv(MLData input, MLData target) {
        if (this.v == null || this.w == null || this.dsqr == null) {
            // Allow direct calls before iteration() has sized the scratch arrays.
            this.allocateWorkArrays();
        }
        final int inputCount = this.network.getInputCount();
        final int outputCount = this.network.getOutputCount();
        final PNNOutputMode mode = this.network.getOutputMode();
        final boolean classification = mode == PNNOutputMode.Classification;
        final double[] sigma = this.network.getSigma();
        final MLDataSet samples = this.network.getSamples();
        // Outside classification, the row after the per-output rows accumulates the
        // derivative of the kernel sum (the normalising denominator).
        final int sumRow = outputCount * inputCount;

        final int rows = classification ? outputCount : outputCount + 1;
        for (int i = 0; i < rows * inputCount; ++i) {
            this.v[i] = 0.0;
            this.w[i] = 0.0;
        }
        double[] out = new double[outputCount];
        double psum = 0.0;

        MLDataPair pair = BasicMLDataPair.createPair(samples.getInputSize(), samples.getIdealSize());
        for (int r = 0; r < samples.getRecordCount(); ++r) {
            samples.getRecord(r, pair);
            if (r == this.network.getExclude()) {
                continue;
            }
            MLData sampleInput = pair.getInput();
            double dist = 0.0;
            for (int ivar = 0; ivar < inputCount; ++ivar) {
                double diff = (input.getData(ivar) - sampleInput.getData(ivar)) / sigma[ivar];
                this.dsqr[ivar] = diff * diff;
                dist += this.dsqr[ivar];
            }
            if (this.network.getKernel() == PNNKernelType.Gaussian) {
                dist = Math.exp(-dist);
            } else if (this.network.getKernel() == PNNKernelType.Reciprocal) {
                dist = 1.0 / (1.0 + dist);
            }
            // Derivative terms use the exact kernel value; only the output sums use the
            // floored value.
            final double truedist = dist;
            if (dist < TINY) {
                dist = TINY;
            }

            if (classification) {
                int pop = (int) pair.getIdeal().getData(0);
                out[pop] += dist;
                int row = pop * inputCount;
                for (int ivar = 0; ivar < inputCount; ++ivar) {
                    double temp = truedist * this.dsqr[ivar];
                    this.v[row + ivar] += temp;
                    this.w[row + ivar] += temp * (2.0 * this.dsqr[ivar] - 3.0);
                }
            } else if (mode == PNNOutputMode.Unsupervised) {
                for (int ivar = 0; ivar < inputCount; ++ivar) {
                    out[ivar] += dist * sampleInput.getData(ivar);
                }
                this.accumulateKernelSum(sumRow, truedist, inputCount);
                // Output k is weighted by the sample's k-th input component.
                this.accumulateOutputTerms(truedist, sampleInput, inputCount, outputCount);
                psum += dist;
            } else if (mode == PNNOutputMode.Regression) {
                MLData sampleIdeal = pair.getIdeal();
                for (int outvar = 0; outvar < outputCount; ++outvar) {
                    out[outvar] += dist * sampleIdeal.getData(outvar);
                }
                this.accumulateOutputTerms(truedist, sampleIdeal, inputCount, outputCount);
                this.accumulateKernelSum(sumRow, truedist, inputCount);
                psum += dist;
            }
        }

        if (classification) {
            // Weight each class by its prior (when one is set) divided by its sample count,
            // then normalise by the weighted total.
            psum = 0.0;
            for (int pop = 0; pop < outputCount; ++pop) {
                if (this.network.getPriors()[pop] >= 0.0) {
                    out[pop] *= this.network.getPriors()[pop] / (double) this.network.getCountPer()[pop];
                }
                psum += out[pop];
            }
            if (psum < TINY) {
                psum = TINY;
            }
        }
        for (int pop = 0; pop < outputCount; ++pop) {
            out[pop] /= psum;
        }

        final double[] deriv1 = this.network.getDeriv();
        final double[] deriv2 = this.network.getDeriv2();
        for (int ivar = 0; ivar < inputCount; ++ivar) {
            final double s = sigma[ivar];
            double vtot;
            double wtot;
            if (classification) {
                vtot = 0.0;
                wtot = 0.0;
            } else {
                vtot = this.v[sumRow + ivar] * 2.0 / (psum * s);
                wtot = this.w[sumRow + ivar] * 2.0 / (psum * s * s);
            }
            for (int outvar = 0; outvar < outputCount; ++outvar) {
                final int k = outvar * inputCount + ivar;
                if (classification && this.network.getPriors()[outvar] >= 0.0) {
                    double scale = this.network.getPriors()[outvar] / (double) this.network.getCountPer()[outvar];
                    this.v[k] *= scale;
                    this.w[k] *= scale;
                }
                this.v[k] *= 2.0 / (psum * s);
                this.w[k] *= 2.0 / (psum * s * s);
                if (classification) {
                    // For classification, the denominator is the sum of the class terms.
                    vtot += this.v[k];
                    wtot += this.w[k];
                }
            }
            for (int outvar = 0; outvar < outputCount; ++outvar) {
                final int k = outvar * inputCount + ivar;
                final double o = out[outvar];
                double der1 = this.v[k] - o * vtot;
                double der2 = this.w[k] + 2.0 * o * vtot * vtot - 2.0 * this.v[k] * vtot - o * wtot;
                // temp = d(error)/d(out[outvar]).
                double temp;
                if (classification) {
                    temp = outvar == (int) target.getData(0) ? 2.0 * (o - 1.0) : 2.0 * o;
                } else {
                    temp = 2.0 * (o - target.getData(outvar));
                }
                deriv1[ivar] += temp * der1;
                deriv2[ivar] += temp * der2 + 2.0 * der1 * der1;
            }
        }
        return new BasicMLData(out);
    }

    /** Adds one sample's contribution to the kernel-sum derivative row (non-classification modes). */
    private void accumulateKernelSum(int sumRow, double truedist, int inputCount) {
        for (int ivar = 0; ivar < inputCount; ++ivar) {
            double temp = truedist * this.dsqr[ivar];
            this.v[sumRow + ivar] += temp;
            this.w[sumRow + ivar] += temp * (2.0 * this.dsqr[ivar] - 3.0);
        }
    }

    /**
     * Adds one sample's contribution to the per-output derivative rows. Row {@code outvar}
     * is weighted by {@code sampleTarget.getData(outvar)}.
     */
    private void accumulateOutputTerms(double truedist, MLData sampleTarget, int inputCount, int outputCount) {
        int k = 0;
        for (int outvar = 0; outvar < outputCount; ++outvar) {
            final double y = sampleTarget.getData(outvar);
            for (int ivar = 0; ivar < inputCount; ++ivar, ++k) {
                double temp = truedist * this.dsqr[ivar] * y;
                this.v[k] += temp;
                this.w[k] += temp * (2.0 * this.dsqr[ivar] - 3.0);
            }
        }
    }

    /** Sizes the derivative scratch arrays for the network's current output mode. */
    private void allocateWorkArrays() {
        final int inputCount = this.network.getInputCount();
        final int rows = this.network.getOutputMode() == PNNOutputMode.Classification
                ? this.network.getOutputCount()
                : this.network.getOutputCount() + 1;
        this.dsqr = new double[inputCount];
        this.v = new double[inputCount * rows];
        this.w = new double[inputCount * rows];
    }

    /** @return the error threshold passed to both sigma searches */
    public double getMaxError() {
        return this.maxError;
    }

    /** @return the network being trained */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /** @return the minimum-improvement setting passed to the derivative-based refinement */
    public double getMinImprovement() {
        return this.minImprovement;
    }

    /** @return the number of sigma values handed to the initial sigma search */
    public int getNumSigmas() {
        return this.numSigmas;
    }

    /** @return the upper bound of the initial sigma search range */
    public double getSigmaHigh() {
        return this.sigmaHigh;
    }

    /** @return the lower bound of the initial sigma search range */
    public double getSigmaLow() {
        return this.sigmaLow;
    }

    /**
     * Runs one complete training pass: loads the samples (first call only), finds
     * starting sigmas, refines them, and stores the sigmas and resulting error on the
     * network, which is then marked trained.
     */
    @Override
    public void iteration() {
        this.preIteration();
        if (!this.samplesLoaded) {
            this.network.setSamples(new BasicMLDataSet(this.training));
            this.samplesLoaded = true;
        }
        final int inputCount = this.network.getInputCount();
        GlobalMinimumSearch globalMinimum = new GlobalMinimumSearch();
        DeriveMinimum dermin = new DeriveMinimum();
        this.allocateWorkArrays();
        double[] x = new double[inputCount];
        double[] base = new double[inputCount];
        double[] direc = new double[inputCount];
        double[] g = new double[inputCount];
        double[] h = new double[inputCount];
        double[] dwk2 = new double[inputCount];
        if (this.network.isTrained()) {
            // Continue from the current sigmas. 1e30 is presumably a "no known error yet"
            // starting value for DeriveMinimum.
            System.arraycopy(this.network.getSigma(), 0, x, 0, inputCount);
            globalMinimum.setY2(1.0E30);
        } else {
            globalMinimum.findBestRange(this.sigmaLow, this.sigmaHigh, this.numSigmas, true, this.maxError, this);
            final double start = globalMinimum.getX2();
            for (int i = 0; i < inputCount; ++i) {
                x[i] = start;
            }
        }
        double d = dermin.calculate(Short.MAX_VALUE, this.maxError, 1.0E-8, this.minImprovement, this, inputCount, x, globalMinimum.getY2(), base, direc, g, h, dwk2);
        globalMinimum.setY2(d);
        System.arraycopy(x, 0, this.network.getSigma(), 0, inputCount);
        this.network.setError(Math.abs(globalMinimum.getY2()));
        this.network.setTrained(true);
        this.setError(this.network.getError());
        this.postIteration();
    }

    /**
     * Pausing is not supported.
     *
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Pausing is not supported, so this does nothing.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /** @param maxError the error threshold passed to both sigma searches */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /** @param minImprovement the minimum-improvement setting for the derivative-based refinement */
    public void setMinImprovement(double minImprovement) {
        this.minImprovement = minImprovement;
    }

    /** @param numSigmas the number of sigma values handed to the initial sigma search */
    public void setNumSigmas(int numSigmas) {
        this.numSigmas = numSigmas;
    }

    /** @param sigmaHigh the upper bound of the initial sigma search range */
    public void setSigmaHigh(double sigmaHigh) {
        this.sigmaHigh = sigmaHigh;
    }

    /** @param sigmaLow the lower bound of the initial sigma search range */
    public void setSigmaLow(double sigmaLow) {
        this.sigmaLow = sigmaLow;
    }
}

