/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training.transferlearning;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.Cifar10;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.util.Progress;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Map;
import org.apache.commons.cli.ParseException;

public final class TrainResnetWithCifar10 {
    private TrainResnetWithCifar10() {
    }

    public static void main(String[] args) throws ParseException, ModelNotFoundException, IOException, MalformedModelException {
        TrainResnetWithCifar10.runExample(args);
    }

    public static ExampleTrainingResult runExample(String[] args) throws IOException, ParseException, ModelNotFoundException, MalformedModelException {
        Arguments arguments = Arguments.parseArgs(args);
        try (Model model = TrainResnetWithCifar10.getModel(arguments);){
            ExampleTrainingResult result;
            RandomAccessDataset trainDataset = TrainResnetWithCifar10.getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validationDataset = TrainResnetWithCifar10.getDataset(Dataset.Usage.TEST, arguments);
            DefaultTrainingConfig config = TrainResnetWithCifar10.setupTrainingConfig(arguments);
            config.addTrainingListeners(TrainingListener.Defaults.logging((String)TrainResnetWithCifar10.class.getSimpleName(), (int)arguments.getBatchSize(), (int)((int)trainDataset.getNumIterations()), (int)((int)validationDataset.getNumIterations()), (String)arguments.getOutputDir()));
            try (Trainer trainer = model.newTrainer((TrainingConfig)config);){
                trainer.setMetrics(new Metrics());
                Shape inputShape = new Shape(new long[]{1L, 3L, 32L, 32L});
                trainer.initialize(new Shape[]{inputShape});
                TrainingUtils.fit(trainer, arguments.getEpoch(), (Dataset)trainDataset, (Dataset)validationDataset, arguments.getOutputDir(), "resnetv1");
                result = new ExampleTrainingResult(trainer);
            }
            model.save(Paths.get("build/model", new String[0]), "resnetv1");
            ExampleTrainingResult exampleTrainingResult = result;
            return exampleTrainingResult;
        }
    }

    private static Model getModel(Arguments arguments) throws IOException, ModelNotFoundException, MalformedModelException {
        boolean isSymbolic = arguments.isSymbolic();
        boolean preTrained = arguments.isPreTrained();
        Map<String, String> options = arguments.getCriteria();
        Criteria.Builder builder = Criteria.builder().optApplication(Application.CV.IMAGE_CLASSIFICATION).setTypes(BufferedImage.class, Classifications.class).optProgress((Progress)new ProgressBar()).optModelLoaderName("resnet");
        if (isSymbolic) {
            builder.optEngine("MXNet").optModelZooName("MXNet");
            if (options == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
            } else {
                builder.optFilters(options);
            }
            ZooModel model = ModelZoo.loadModel((Criteria)builder.build());
            SequentialBlock newBlock = new SequentialBlock();
            SymbolBlock block = (SymbolBlock)model.getBlock();
            block.removeLastBlock();
            newBlock.add((Block)block);
            newBlock.add(x -> new NDList(new NDArray[]{x.singletonOrThrow().squeeze()}));
            newBlock.add((Block)Linear.builder().setOutChannels(10L).build());
            newBlock.add(Blocks.batchFlattenBlock());
            model.setBlock((Block)newBlock);
            if (!preTrained) {
                model.getBlock().clear();
            }
            return model;
        }
        if (preTrained) {
            builder.optModelZooName("Basic");
            if (options == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
                builder.optFilter("dataset", "cifar10");
            } else {
                builder.optFilters(options);
            }
            return ModelZoo.loadModel((Criteria)builder.build());
        }
        Model model = Model.newInstance();
        Block resNet50 = ResNetV1.builder().setImageShape(new Shape(new long[]{3L, 32L, 32L})).setNumLayers(50).setOutSize(10L).build();
        model.setBlock(resNet50);
        return model;
    }

    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig((Loss)Loss.softmaxCrossEntropyLoss()).addEvaluator((Evaluator)new Accuracy()).setBatchSize(arguments.getBatchSize()).optDevices(Device.getDevices((int)arguments.getMaxGpus()));
    }

    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments) throws IOException {
        Pipeline pipeline = new Pipeline(new Transform[]{new ToTensor(), new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD)});
        Cifar10 cifar10 = ((Cifar10.Builder)((Cifar10.Builder)((Cifar10.Builder)Cifar10.builder().optUsage(usage).setSampling(arguments.getBatchSize(), true)).optMaxIteration(arguments.getMaxIterations())).optPipeline(pipeline)).build();
        cifar10.prepare((Progress)new ProgressBar());
        return cifar10;
    }
}

