package org.encog.neural.cpn.training;

import java.util.Iterator;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.cpn.CPN;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/* loaded from: input_file:org/encog/neural/cpn/training/TrainOutstar.class */
/**
 * Trains the outstar (output) layer of a counter-propagation network ({@link CPN}).
 *
 * <p>On the first call to {@link #iteration()}, the instar-to-outstar weights are seeded
 * from the training data. Row {@code r} of the weight matrix receives the ideal vector of
 * the {@code r}-th training pair.
 *
 * <p>Each iteration then works through every training pair:
 * <ul>
 *   <li>It locates the instar neuron with the largest output.</li>
 *   <li>It moves only that neuron's outstar weights a fraction {@link #getLearningRate()}
 *       of the way toward the pair's ideal vector.</li>
 * </ul>
 *
 * <p>NOTE(review): seeding writes one weight row per training pair. This presumably
 * requires the network to have at least as many instar neurons as there are training
 * pairs — confirm against {@link CPN}.
 *
 * <p>This trainer cannot be paused or resumed.
 */
public class TrainOutstar extends BasicTraining implements LearningRate {
    /** Fraction of the remaining error applied to the winning weights per update. */
    private double learningRate;
    /** The network whose instar-to-outstar weights are trained. */
    private final CPN network;
    /** Training pairs; their ideal vectors are the outstar targets. */
    private final MLDataSet training;
    /** True until the outstar weights have been seeded from the training data. */
    private boolean mustInit;

    /**
     * Creates an outstar trainer.
     *
     * @param cpn       the network to train
     * @param mLDataSet the training data; ideal vectors supply the outstar targets
     * @param d         the initial learning rate
     */
    public TrainOutstar(CPN cpn, MLDataSet mLDataSet, double d) {
        super(TrainingImplementationType.Iterative);
        this.mustInit = true;
        this.network = cpn;
        this.training = mLDataSet;
        this.learningRate = d;
    }

    /**
     * Reports whether training can be paused and continued later.
     *
     * @return always {@code false}
     */
    @Override // org.encog.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /**
     * @return the current learning rate
     */
    @Override // org.encog.neural.networks.training.LearningRate
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return the network being trained
     */
    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * Seeds the instar-to-outstar weights from the training data in a single pass.
     *
     * <p>For the r-th training pair, weight (r, i) is set to component i of that pair's
     * ideal vector, for every outstar neuron i. Clears {@link #mustInit} when done.
     */
    private void initWeight() {
        final int outstarCount = this.network.getOutstarCount();
        int row = 0;
        // One pass over the data (rather than one pass per outstar neuron) writes the
        // same cells, because each (row, column) cell is written exactly once.
        for (final MLDataPair pair : this.training) {
            final MLData ideal = pair.getIdeal();
            for (int i = 0; i < outstarCount; i++) {
                this.network.getWeightsInstarToOutstar().set(row, i, ideal.getData(i));
            }
            row++;
        }
        this.mustInit = false;
    }

    /**
     * Performs one training epoch over the whole training set.
     *
     * <p>The weights are seeded first if this is the first iteration. For each pair:
     * <ul>
     *   <li>Only the weight row of the instar neuron with the largest output is moved
     *       toward the pair's ideal vector, scaled by the learning rate.</li>
     *   <li>The outstar output is then recomputed from the same instar output.</li>
     *   <li>That output is accumulated into the epoch error, weighted by the pair's
     *       significance.</li>
     * </ul>
     * The resulting error is stored via {@code setError}.
     */
    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        if (this.mustInit) {
            initWeight();
        }
        final int outstarCount = this.network.getOutstarCount();
        final ErrorCalculation errorCalculation = new ErrorCalculation();
        for (final MLDataPair pair : this.training) {
            final MLData instarOutput = this.network.computeInstar(pair.getInput());
            final MLData ideal = pair.getIdeal();
            // Only the most strongly activated instar neuron's outgoing weights are updated.
            final int winner = EngineArray.indexOfLargest(instarOutput.getData());
            for (int i = 0; i < outstarCount; i++) {
                final double current = this.network.getWeightsInstarToOutstar().get(winner, i);
                final double delta = this.learningRate * (ideal.getData(i) - current);
                this.network.getWeightsInstarToOutstar().add(winner, i, delta);
            }
            // Error is measured after this pair's update has been applied.
            final MLData outstarOutput = this.network.computeOutstar(instarOutput);
            errorCalculation.updateError(outstarOutput.getData(), ideal.getData(), pair.getSignificance());
        }
        setError(errorCalculation.calculate());
    }

    /**
     * Pausing is not supported (see {@link #canContinue()}).
     *
     * @return always {@code null}
     */
    @Override // org.encog.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported; this method does nothing.
     *
     * @param trainingContinuation ignored
     */
    @Override // org.encog.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
    }

    /**
     * @param d the new learning rate, used from the next weight update on
     */
    @Override // org.encog.neural.networks.training.LearningRate
    public void setLearningRate(double d) {
        this.learningRate = d;
    }
}
