package org.encog.ensemble.aggregator;

import java.util.ArrayList;
import java.util.Iterator;
import org.encog.ensemble.EnsembleAggregator;
import org.encog.ensemble.EnsembleML;
import org.encog.ensemble.EnsembleMLMethodFactory;
import org.encog.ensemble.EnsembleTrainFactory;
import org.encog.ensemble.GenericEnsembleML;
import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;

/* loaded from: input_file:org/encog/ensemble/aggregator/MetaClassifier.class */
/**
 * Ensemble aggregator that feeds the outputs of all ensemble members into a second-level
 * ("meta") model and returns that model's output as the ensemble's answer.
 *
 * <p>The meta model is built by an {@link EnsembleMLMethodFactory} and trained with a trainer
 * from an {@link EnsembleTrainFactory} once {@link #setTrainingSet(EnsembleDataSet)} has been
 * called.
 */
public class MetaClassifier implements EnsembleAggregator {
    /** The meta model; {@code null} until {@link #setTrainingSet(EnsembleDataSet)} is called. */
    EnsembleML classifier;
    /** Factory used to create the meta model. */
    EnsembleMLMethodFactory mlFact;
    /** Factory used to create the trainer for the meta model. */
    EnsembleTrainFactory etFact;
    /** Target training error handed to the meta model's {@code train} call. */
    double trainError;
    /** Number of ensemble members feeding this aggregator (1 until set). */
    int members;
    /** When {@code true}, the target training error is divided by {@link #members}. */
    boolean adaptiveError;

    /**
     * Creates a meta classifier.
     *
     * @param trainError    target training error for the meta model
     * @param mlFact        factory that builds the meta model
     * @param etFact        factory that builds the meta model's trainer
     * @param adaptiveError if {@code true}, the target error is divided by the member count
     */
    public MetaClassifier(double trainError, EnsembleMLMethodFactory mlFact, EnsembleTrainFactory etFact, boolean adaptiveError) {
        this.trainError = trainError;
        this.mlFact = mlFact;
        this.etFact = etFact;
        this.adaptiveError = adaptiveError;
        this.members = 1;
    }

    /**
     * Creates a meta classifier with a fixed (non-adaptive) target training error.
     *
     * @param trainError target training error for the meta model
     * @param mlFact     factory that builds the meta model
     * @param etFact     factory that builds the meta model's trainer
     */
    public MetaClassifier(double trainError, EnsembleMLMethodFactory mlFact, EnsembleTrainFactory etFact) {
        this(trainError, mlFact, etFact, false);
    }

    /**
     * Returns the configured target training error. This is the raw value, before any
     * adaptive division.
     *
     * @return the target training error
     */
    public double getTrainingError() {
        return this.trainError;
    }

    /**
     * Sets the target training error.
     *
     * @param trainError the new target training error
     */
    public void setTrainingError(double trainError) {
        this.trainError = trainError;
    }

    /**
     * Records how many members feed this aggregator. The value is used as the factory size
     * multiplier and as the adaptive error divisor.
     *
     * @param i number of ensemble members
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public void setNumberOfMembers(int i) {
        this.members = i;
    }

    /**
     * Concatenates the members' outputs, in list order, into a single input vector and
     * returns the meta model's output for it.
     *
     * @param arrayList one output per ensemble member
     * @return the meta model's output
     * @throws IllegalStateException if {@link #setTrainingSet(EnsembleDataSet)} was never called
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public MLData evaluate(ArrayList<MLData> arrayList) {
        if (this.classifier == null) {
            throw new IllegalStateException("MetaClassifier has no classifier; call setTrainingSet first");
        }
        BasicMLData input = new BasicMLData(this.classifier.getInputCount());
        // The write index must run across ALL members so their outputs are laid out side by
        // side (member 0's values, then member 1's, ...). Resetting it per member made every
        // member write into the same leading slots and left the remaining inputs unfilled.
        // NOTE(review): assumes the aggregator training set is built with the same
        // concatenated layout — verify against the caller that builds it.
        int index = 0;
        for (MLData memberOutput : arrayList) {
            for (double value : memberOutput.getData()) {
                input.add(index++, value);
            }
        }
        return this.classifier.compute(input);
    }

    /**
     * Builds a descriptive label from the model factory label, the target error, the trainer
     * factory label and, when enabled, an {@code -adaptive} suffix.
     *
     * @return the label string
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public String getLabel() {
        String str = "metaclassifier-" + this.mlFact.getLabel() + "-" + this.trainError + "-" + this.etFact.getLabel();
        if (this.adaptiveError) {
            str = str + "-adaptive";
        }
        return str;
    }

    /**
     * Trains the meta model down to the target error. In adaptive mode the target is divided
     * by the member count. If no training set has been supplied yet, a message is written to
     * {@code System.err} and nothing is trained.
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public void train() {
        if (this.classifier == null) {
            System.err.println("Trying to train a null classifier in MetaClassifier");
            return;
        }
        double targetError = this.adaptiveError ? this.trainError / this.members : this.trainError;
        this.classifier.train(targetError);
    }

    /**
     * Creates a fresh meta model sized for the given data set and attaches a trainer and the
     * data set to it. Any previously created meta model is discarded.
     *
     * <p>Side effect: sets the model factory's size multiplier to the current member count.
     *
     * @param ensembleDataSet training data for the meta model
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public void setTrainingSet(EnsembleDataSet ensembleDataSet) {
        this.mlFact.setSizeMultiplier(this.members);
        this.classifier = new GenericEnsembleML(this.mlFact.createML(ensembleDataSet.getInputSize(), ensembleDataSet.getIdealSize()), this.mlFact.getLabel());
        this.classifier.setTraining(this.etFact.getTraining(this.classifier.getMl(), ensembleDataSet));
        this.classifier.setTrainingSet(ensembleDataSet);
    }

    /**
     * Always {@code true}: the meta model has to be trained before it can aggregate.
     *
     * @return {@code true}
     */
    @Override // org.encog.ensemble.EnsembleAggregator
    public boolean needsTraining() {
        return true;
    }
}
