package org.encog.neural.networks.training.propagation.back;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.strategy.SmartLearningRate;
import org.encog.neural.networks.training.strategy.SmartMomentum;
import org.encog.util.validate.ValidateNetwork;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/back/Backpropagation.class */
/**
 * Backpropagation trainer with a fixed learning rate and momentum.
 *
 * <p>Each weight change is {@code learningRate * gradient + momentum * previousChange}, where
 * the gradients handed in by {@link Propagation} already point in the direction the weights
 * should move (the result is returned to the caller as the change to apply). Optionally the
 * four-argument update uses Nesterov momentum instead (see {@link #setNesterovUpdate(boolean)}).
 *
 * <p>In both modes {@code lastDelta} holds the per-weight velocity, which is the state saved by
 * {@link #pause()} and restored by {@link #resume(TrainingContinuation)}.
 */
public class Backpropagation extends Propagation implements Momentum, LearningRate {

    /** Key under which {@link #pause()} stores the per-weight velocity array. */
    public static final String LAST_DELTA = "LAST_DELTA";

    /** Multiplier applied to each gradient. */
    private double learningRate;

    /** Fraction of the previous velocity carried into the next update. */
    private double momentum;

    /** Per-weight velocity, one entry per weight of the flat network. */
    private double[] lastDelta;

    /** When {@code true}, the four-argument {@code updateWeight} uses Nesterov momentum. */
    private boolean nesterovUpdate;

    /**
     * Creates a trainer that starts with a learning rate and momentum of 0 and registers the
     * {@link SmartLearningRate} and {@link SmartMomentum} strategies (presumably these choose the
     * actual values during training — confirm against those classes).
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data
     */
    public Backpropagation(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        this(containsFlat, mLDataSet, 0.0d, 0.0d);
        addStrategy(new SmartLearningRate());
        addStrategy(new SmartMomentum());
    }

    /**
     * Creates a trainer with an explicit learning rate and momentum.
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data; validated against the network
     * @param d the learning rate
     * @param d2 the momentum
     */
    public Backpropagation(ContainsFlat containsFlat, MLDataSet mLDataSet, double d, double d2) {
        super(containsFlat, mLDataSet);
        ValidateNetwork.validateMethodToData(containsFlat, mLDataSet);
        this.momentum = d2;
        this.learningRate = d;
        this.lastDelta = new double[containsFlat.getFlat().getWeights().length];
    }

    /** @return always {@code false}; kept as in the original implementation. */
    @Override
    public boolean canContinue() {
        return false;
    }

    /** @return the live per-weight velocity array (not a copy). */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    @Override
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Checks whether a continuation can be resumed by this trainer.
     *
     * @param trainingContinuation the saved state to inspect
     * @return {@code true} if it contains a {@code double[]} under {@link #LAST_DELTA}, was
     *     produced by this exact trainer class, and matches the network's weight count
     */
    public boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_DELTA)) {
            return false;
        }
        // Compare from the known-non-null side so a missing training type is simply "invalid".
        if (!getClass().getSimpleName().equals(trainingContinuation.getTrainingType())) {
            return false;
        }
        final Object saved = trainingContinuation.get(LAST_DELTA);
        if (!(saved instanceof double[])) {
            return false;
        }
        final int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Snapshots the momentum state.
     *
     * @return a continuation holding a copy of the velocity array, so further training does not
     *     alter the snapshot
     */
    @Override
    public TrainingContinuation pause() {
        final TrainingContinuation trainingContinuation = new TrainingContinuation();
        trainingContinuation.setTrainingType(getClass().getSimpleName());
        trainingContinuation.set(LAST_DELTA, this.lastDelta.clone());
        return trainingContinuation;
    }

    /**
     * Restores momentum state previously produced by {@link #pause()}.
     *
     * @param trainingContinuation the saved state
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects it
     */
    @Override
    public void resume(TrainingContinuation trainingContinuation) {
        if (!isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        // Copy so the continuation can be reused without sharing state with this trainer.
        this.lastDelta = ((double[]) trainingContinuation.get(LAST_DELTA)).clone();
    }

    @Override
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    @Override
    public void setMomentum(double d) {
        this.momentum = d;
    }

    /**
     * Computes the change for one weight using classic momentum, with no dropout and no
     * Nesterov correction.
     *
     * @param gradients per-weight gradients (descent direction)
     * @param lastGradient unused by this trainer
     * @param index the weight index
     * @return the amount to add to the weight
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index) {
        return classicMomentumStep(gradients, index);
    }

    /**
     * Computes the change for one weight, honouring dropout and the Nesterov flag.
     *
     * @param gradients per-weight gradients (descent direction)
     * @param lastGradient unused by this trainer
     * @param index the weight index
     * @param dropoutRate probability in [0, 1] of skipping this weight; 0 disables dropout
     * @return the amount to add to the weight, or 0 if the weight was dropped
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index, double dropoutRate) {
        return this.nesterovUpdate
                ? updateWeightNesterov(gradients, lastGradient, index, dropoutRate)
                : updateWeightNormal(gradients, lastGradient, index, dropoutRate);
    }

    /** Classic-momentum update with optional dropout. */
    private double updateWeightNormal(double[] gradients, double[] lastGradient, int index, double dropoutRate) {
        if (isDroppedOut(dropoutRate)) {
            return 0.0d;
        }
        return classicMomentumStep(gradients, index);
    }

    /**
     * Nesterov-momentum update with optional dropout.
     *
     * <p>With {@code v' = momentum * v + learningRate * g}, the returned change is
     * {@code (1 + momentum) * v' - momentum * v}, and {@code v'} is stored as the new velocity.
     * With momentum 0 this reduces to {@code learningRate * g}, the same step the classic update
     * takes.
     */
    private double updateWeightNesterov(double[] gradients, double[] lastGradient, int index, double dropoutRate) {
        if (isDroppedOut(dropoutRate)) {
            return 0.0d;
        }
        final double previousVelocity = this.lastDelta[index];
        final double velocity = (this.momentum * previousVelocity) + (gradients[index] * this.learningRate);
        // Keep the velocity itself (not the corrected step) so momentum accumulates correctly.
        this.lastDelta[index] = velocity;
        return ((1.0d + this.momentum) * velocity) - (this.momentum * previousVelocity);
    }

    /** Computes {@code learningRate * g + momentum * v}, stores it as the new velocity, and returns it. */
    private double classicMomentumStep(double[] gradients, int index) {
        final double delta = (gradients[index] * this.learningRate) + (this.lastDelta[index] * this.momentum);
        this.lastDelta[index] = delta;
        return delta;
    }

    /** Returns true if this weight should be skipped; draws a random number only when dropout is enabled. */
    private boolean isDroppedOut(double dropoutRate) {
        return dropoutRate > 0.0d && this.dropoutRandomSource.nextDouble() < dropoutRate;
    }

    /** No extra initialisation is needed beyond the constructor. */
    @Override
    public void initOthers() {
    }

    /** @return whether Nesterov momentum is used by the four-argument update */
    public boolean isNesterovUpdate() {
        return this.nesterovUpdate;
    }

    /** @param z {@code true} to use Nesterov momentum in the four-argument update */
    public void setNesterovUpdate(boolean z) {
        this.nesterovUpdate = z;
    }
}
