/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.hyperparameter.optimizer.HpORandom;
import ai.djl.training.hyperparameter.param.HpInt;
import ai.djl.training.hyperparameter.param.HpSet;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example that trains a multilayer perceptron on MNIST, picking the hidden-layer shape by random
 * hyperparameter search ({@link HpORandom}) and saving the model retrained with the best
 * configuration found.
 */
public final class TrainWithHpo {
    private static final Logger logger = LoggerFactory.getLogger(TrainWithHpo.class);

    /** Number of random hyperparameter configurations evaluated before the final training run. */
    private static final int HYPERPARAMETER_TESTS = 50;

    /** Flattened MNIST input size (28 x 28 = 784 values per sample). */
    private static final int INPUT_SIZE = 28 * 28;

    /** Number of output units of the MLP (one per MNIST digit class). */
    private static final int NUM_CLASSES = 10;

    /** Model name used both for training checkpoints and for the saved final model. */
    private static final String MODEL_NAME = "mlp";

    private TrainWithHpo() {
    }

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments parsed by {@code Arguments.parseArgs}
     * @throws IOException if dataset preparation, training or saving fails
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Runs the hyperparameter search and trains the final model.
     *
     * <p>Each of {@link #HYPERPARAMETER_TESTS} trials samples a hidden-layer size (10..100) and
     * hidden-layer count (2..10) from the optimizer (bound inclusivity is defined by {@code HpInt}),
     * trains a model, and reports that run's loss back to the optimizer. The configuration the
     * optimizer reports as best is then retrained and saved to the output directory as
     * {@value #MODEL_NAME}.
     *
     * @param args command-line arguments parsed by {@code Arguments.parseArgs}
     * @return the training result of the final run with the best hyperparameters
     * @throws IOException if dataset preparation, training or saving fails
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static ExampleTrainingResult runExample(String[] args) throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);
        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

        HpSet hyperParams =
                new HpSet(
                        "hp",
                        Arrays.asList(
                                new HpInt("hiddenLayersSize", 10, 100),
                                new HpInt("hiddenLayersCount", 2, 10)));
        HpORandom hpOptimizer = new HpORandom(hyperParams);

        for (int i = 0; i < HYPERPARAMETER_TESTS; ++i) {
            HpSet hpVals = hpOptimizer.nextConfig();
            Pair<Model, ExampleTrainingResult> trained =
                    train(arguments, hpVals, trainingSet, validateSet);
            // A trial run is only needed for its score; release the model immediately.
            trained.getKey().close();
            float loss = trained.getValue().getLoss();
            hpOptimizer.update(hpVals, loss);
            logger.info(
                    "--------- hp test {}/{} - Loss {} - {}", i, HYPERPARAMETER_TESTS, loss, hpVals);
        }

        HpSet bestHpVals = hpOptimizer.getBest().getKey();
        Pair<Model, ExampleTrainingResult> trained =
                train(arguments, bestHpVals, trainingSet, validateSet);
        ExampleTrainingResult result = trained.getValue();
        try (Model model = trained.getKey()) {
            logger.info("--------- FINAL_HP - Loss {} - {}", result.getLoss(), bestHpVals);
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
        }
        return result;
    }

    /**
     * Builds an MLP from the given hyperparameter values and trains it.
     *
     * <p>On success the caller owns the returned model and must close it. If any step throws, the
     * model is closed here before the exception propagates, so it never leaks.
     *
     * <p>NOTE(review): every call uses the same output directory and model name, so checkpoints
     * written by {@code TrainingUtils.fit} for different trials may overwrite each other. Confirm
     * whether that is intended.
     *
     * @param arguments parsed command-line arguments (batch size, epochs, devices, output dir)
     * @param hpVals concrete values for "hiddenLayersCount" and "hiddenLayersSize"
     * @param trainingSet dataset used for training
     * @param validateSet dataset used for validation
     * @return the trained model paired with its training result
     * @throws IOException if training fails
     */
    private static Pair<Model, ExampleTrainingResult> train(
            Arguments arguments,
            HpSet hpVals,
            RandomAccessDataset trainingSet,
            RandomAccessDataset validateSet)
            throws IOException {
        int layerCount = (Integer) hpVals.getHParam("hiddenLayersCount").random();
        int layerSize = (Integer) hpVals.getHParam("hiddenLayersSize").random();
        int[] hidden = new int[layerCount];
        Arrays.fill(hidden, layerSize);
        Block block = new Mlp(INPUT_SIZE, NUM_CLASSES, hidden);

        Model model = Model.newInstance();
        // Tracks whether ownership of the model passes to the caller; otherwise it is closed below.
        boolean succeeded = false;
        try {
            model.setBlock(block);
            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainWithHpo.class.getSimpleName(),
                            arguments.getBatchSize(),
                            (int) trainingSet.getNumIterations(),
                            (int) validateSet.getNumIterations(),
                            arguments.getOutputDir()));

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                trainer.initialize(new Shape(1, INPUT_SIZE));
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
                result = new ExampleTrainingResult(trainer);
            }
            succeeded = true;
            return new Pair<>(model, result);
        } finally {
            if (!succeeded) {
                model.close();
            }
        }
    }

    /**
     * Creates the training configuration: softmax cross-entropy loss, an accuracy evaluator, the
     * configured batch size, and up to {@code arguments.getMaxGpus()} devices.
     *
     * @param arguments parsed command-line arguments
     * @return a new training configuration
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .setBatchSize(arguments.getBatchSize())
                .optDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Builds and prepares the MNIST dataset for the given usage, shuffled and batched with the
     * configured batch size and capped at the configured maximum number of iterations.
     *
     * @param usage which split to load (e.g. TRAIN or TEST)
     * @param arguments parsed command-line arguments
     * @return the prepared dataset
     * @throws IOException if the dataset cannot be prepared
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}

