/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import java.io.IOException;
import java.nio.file.Paths;
import org.apache.commons.cli.ParseException;

/**
 * Example that trains a multilayer perceptron ({@link Mlp}) on the MNIST dataset and saves the
 * resulting model to the configured output directory.
 */
public final class TrainMnist {
    /** Length of one flattened MNIST input image (28 x 28 pixels). */
    private static final int INPUT_SIZE = 28 * 28;

    /** Number of output classes produced by the network. */
    private static final int NUM_CLASSES = 10;

    /** Name used both for the training checkpoints and for the saved model. */
    private static final String MODEL_NAME = "mlp";

    private TrainMnist() {
    }

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments
     * @throws IOException if the dataset cannot be prepared or the model cannot be saved
     * @throws ParseException if {@code args} cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Trains the MLP on the MNIST training split and validates it on the test split. The
     * trained model is then saved to {@link Arguments#getOutputDir()} under the name
     * {@value #MODEL_NAME}.
     *
     * @param args command-line arguments, parsed with {@link Arguments#parseArgs}
     * @return a result object constructed from the trainer after fitting
     * @throws IOException if the dataset cannot be prepared or the model cannot be saved
     * @throws ParseException if {@code args} cannot be parsed
     */
    public static ExampleTrainingResult runExample(String[] args) throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);
        // Third argument: sizes of the intermediate layers (assumed hidden-layer widths).
        Block block = new Mlp(INPUT_SIZE, NUM_CLASSES, new int[] {128, 64});

        try (Model model = Model.newInstance()) {
            model.setBlock(block);

            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainMnist.class.getSimpleName(),
                            arguments.getBatchSize(),
                            // Narrowing kept from the original: the listener takes int sizes.
                            (int) trainingSet.getNumIterations(),
                            (int) validateSet.getNumIterations(),
                            arguments.getOutputDir()));

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                // Shares INPUT_SIZE with the Mlp above so the two cannot drift apart;
                // the leading 1 is a batch dimension of one sample.
                trainer.initialize(new Shape(1, INPUT_SIZE));
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
                result = new ExampleTrainingResult(trainer);
            }

            // Saved after the trainer is closed but while the model is still open.
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
            return result;
        }
    }

    /**
     * Builds the training configuration: softmax cross-entropy loss, an accuracy evaluator,
     * the requested batch size, and up to {@link Arguments#getMaxGpus()} devices.
     *
     * @param arguments parsed command-line arguments
     * @return the training configuration
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .setBatchSize(arguments.getBatchSize())
                .optDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Creates one split of the MNIST dataset and prepares it, reporting progress on a
     * {@link ProgressBar}. Sampling is shuffled, with the configured batch size and iteration
     * cap.
     *
     * @param usage which split to load
     * @param arguments parsed command-line arguments
     * @return the prepared dataset
     * @throws IOException if preparing the dataset fails
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}

