/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.PikachuDetection;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.MultiBoxDetection;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.BoundingBoxError;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.evaluator.SingleShotDetectionAccuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.cli.ParseException;

/**
 * Trains a small Single Shot Detection (SSD) model on the Pikachu detection dataset.
 *
 * <p>Non-instantiable holder of static entry points: {@link #main(String[])} and
 * {@link #runExample(String[])} train and save the model, while {@link #getSsdTrainBlock()} and
 * {@link #getSsdPredictBlock(Block)} build the networks used for training and inference.
 */
public final class TrainPikachu {

    /** Name used both for training checkpoints and for the saved model parameters. */
    private static final String MODEL_NAME = "ssd";

    private TrainPikachu() {
    }

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments, parsed by {@code Arguments.parseArgs}
     * @throws IOException propagated from dataset preparation, training or model saving
     * @throws ParseException if the arguments cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Trains the SSD network on the Pikachu dataset and saves its parameters.
     *
     * <p>The training and test splits are prepared first. The model is then fitted for
     * {@code arguments.getEpoch()} epochs with an input shape of {@code (batchSize, 3, 256, 256)}.
     * Finally, the parameters are saved under {@code arguments.getOutputDir()} with the name
     * {@code "ssd"}.
     *
     * @param args command-line arguments, parsed by {@code Arguments.parseArgs}
     * @return the training result captured from the trainer after fitting
     * @throws IOException propagated from dataset preparation, training or model saving
     * @throws ParseException if the arguments cannot be parsed
     */
    public static ExampleTrainingResult runExample(String[] args)
            throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);
        try (Model model = Model.newInstance()) {
            model.setBlock(getSsdTrainBlock());
            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset testSet = getDataset(Dataset.Usage.TEST, arguments);

            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainPikachu.class.getSimpleName(),
                            arguments.getBatchSize(),
                            // getNumIterations() is narrowed to int for the logging listener.
                            (int) trainingSet.getNumIterations(),
                            (int) testSet.getNumIterations(),
                            arguments.getOutputDir()));

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                // Shape (batchSize, 3, 256, 256); presumably NCHW RGB images -- confirm.
                Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
                trainer.initialize(inputShape);
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        testSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
                result = new ExampleTrainingResult(trainer);
            }
            // Parameters are saved after the trainer has been closed.
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
            return result;
        }
    }

    /**
     * Placeholder for the prediction entry point.
     *
     * <p>CFR 0.152 could not decompile the original body of this method
     * ({@code ConfusedCFRException: Started 2 blocks at once}), so this placeholder always throws.
     * The signature, including its checked exceptions, is kept so that existing callers still
     * compile.
     *
     * <p>TODO(review): restore the original implementation from source control. The declared
     * exceptions suggest it loaded the parameters saved by {@link #runExample(String[])} and ran
     * inference through {@link #getSsdPredictBlock(Block)} -- unverified.
     *
     * @param outputDir unused by this placeholder
     * @param imageFile unused by this placeholder
     * @return never returns normally
     * @throws IllegalStateException always
     */
    public static int predict(String outputDir, String imageFile)
            throws IOException, MalformedModelException, TranslateException {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Builds and prepares one split of the Pikachu detection dataset.
     *
     * @param usage the split to load (for example {@code TRAIN} or {@code TEST})
     * @param arguments supplies the batch size and the maximum number of iterations
     * @return the prepared dataset
     * @throws IOException propagated from {@code prepare}
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        // Images are only converted to tensors; no augmentation is applied.
        Pipeline pipeline = new Pipeline(new ToTensor());
        PikachuDetection pikachuDetection =
                PikachuDetection.builder()
                        .optUsage(usage)
                        .optMaxIteration(arguments.getMaxIterations())
                        .optPipeline(pipeline)
                        // The boolean flag presumably enables random sampling -- confirm.
                        .setSampling(arguments.getBatchSize(), true)
                        .build();
        pikachuDetection.prepare(new ProgressBar());
        return pikachuDetection;
    }

    /**
     * Creates the training configuration.
     *
     * <p>The configuration uses the SSD loss and two evaluators: class accuracy and bounding-box
     * error. It also sets the batch size and limits the devices to at most
     * {@code arguments.getMaxGpus()} GPUs.
     *
     * @param arguments supplies the batch size and the GPU limit
     * @return the configured training config
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(new SingleShotDetectionLoss())
                .setBatchSize(arguments.getBatchSize())
                .addEvaluator(new SingleShotDetectionAccuracy("classAccuracy"))
                .addEvaluator(new BoundingBoxError("boundingBoxError"))
                .optDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Builds the SSD network used for training.
     *
     * <p>The base network is three SSD down-sampling blocks with 16, 32 and 64 filters. The
     * detector is configured with one class, {@code numFeatures = 3} and global pooling, plus five
     * anchor-size pairs and five identical aspect-ratio triples.
     *
     * @return the untrained SSD training block
     */
    public static Block getSsdTrainBlock() {
        SequentialBlock baseBlock = new SequentialBlock();
        for (int numFilter : new int[] {16, 32, 64}) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(numFilter));
        }

        // Five entries each; presumably one per SSD feature-map scale -- confirm.
        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        // A fresh ratio list is created per scale, mirroring the one-list-per-entry layout above.
        List<List<Float>> ratios = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            ratios.add(Arrays.asList(1.0f, 2.0f, 0.5f));
        }

        return SingleShotDetection.builder()
                .setNumClasses(1)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }

    /**
     * Wraps a (typically trained) SSD block with post-processing for inference.
     *
     * <p>The training block's output is read as {@code [anchors, classPredictions,
     * boundingBoxPredictions]}. The class predictions get a softmax and a transpose, and all three
     * are decoded by {@code MultiBoxDetection}. The single detection array is then split along axis
     * 2 at indices 1 and 2.
     *
     * @param ssdTrain the SSD training block, for example the one from {@link #getSsdTrainBlock()}
     * @return a sequential block that runs {@code ssdTrain} followed by detection decoding
     */
    public static Block getSsdPredictBlock(Block ssdTrain) {
        // Built once and reused by every forward pass instead of being rebuilt per call; assumes
        // MultiBoxDetection carries no per-call state after build() -- confirm.
        MultiBoxDetection multiBoxDetection = MultiBoxDetection.builder().build();

        SequentialBlock ssdPredict = new SequentialBlock();
        ssdPredict.add(ssdTrain);
        ssdPredict.add(
                new LambdaBlock(
                        output -> {
                            NDArray anchors = output.get(0);
                            // Softmax over the last axis, then swap axes 1 and 2.
                            NDArray classPredictions =
                                    output.get(1).softmax(-1).transpose(0, 2, 1);
                            NDArray boundingBoxPredictions = output.get(2);
                            NDList detections =
                                    multiBoxDetection.detection(
                                            new NDList(
                                                    classPredictions,
                                                    boundingBoxPredictions,
                                                    anchors));
                            // Presumably separates class id, score and box coordinates -- confirm.
                            return detections.singletonOrThrow().split(new long[] {1L, 2L}, 2);
                        }));
        return ssdPredict;
    }
}

