/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.SimpleCompositeLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import java.io.IOException;
import java.nio.file.Paths;
import org.apache.commons.cli.ParseException;

/**
 * Trains a ResNet-50 classifier on {@link CaptchaDataset}.
 *
 * <p>The network produces {@value #CAPTCHA_LENGTH} independent predictions per image, one per
 * captcha character position, each over {@value #CAPTCHA_OPTIONS} classes. Every position gets
 * its own softmax cross-entropy loss term and its own accuracy evaluator.
 */
public final class TrainCaptcha {

    /** Number of character positions per captcha; one output head per position. */
    private static final int CAPTCHA_LENGTH = 6;

    /** Number of classes each character position is predicted over. */
    private static final int CAPTCHA_OPTIONS = 11;

    /** Height of the single-channel input image. */
    private static final int IMAGE_HEIGHT = 60;

    /** Width of the single-channel input image. */
    private static final int IMAGE_WIDTH = 160;

    /** Name passed to {@code TrainingUtils.fit} and used when saving the trained model. */
    private static final String MODEL_NAME = "captcha";

    private TrainCaptcha() {
    }

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments understood by {@code Arguments.parseArgs}
     * @throws IOException if preparing the datasets or saving the model fails
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        TrainCaptcha.runExample(args);
    }

    /**
     * Builds the model, trains it on the captcha training split while validating on the
     * validation split, and saves it to the configured output directory.
     *
     * @param args command-line arguments understood by {@code Arguments.parseArgs}
     * @return the result captured from the trainer after fitting completes
     * @throws ParseException if the command-line arguments cannot be parsed
     * @throws IOException if preparing the datasets or saving the model fails
     */
    public static ExampleTrainingResult runExample(String[] args)
            throws ParseException, IOException {
        Arguments arguments = Arguments.parseArgs(args);

        try (Model model = Model.newInstance()) {
            model.setBlock(getBlock());

            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.VALIDATION, arguments);

            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainCaptcha.class.getSimpleName(),
                            arguments.getBatchSize(),
                            (int) trainingSet.getNumIterations(),
                            (int) validateSet.getNumIterations(),
                            arguments.getOutputDir()));

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                // Batch of one single-channel image, used only to initialize parameter shapes.
                Shape inputShape = new Shape(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH);
                trainer.initialize(inputShape);

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);

                // Capture the result while the trainer is still open.
                result = new ExampleTrainingResult(trainer);
            }

            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
            return result;
        }
    }

    /**
     * Creates a training config with one loss term and one accuracy evaluator per character.
     *
     * <p>Loss/evaluator {@code i} is bound to index {@code i} of the model output and labels.
     *
     * @param arguments parsed arguments supplying batch size and maximum GPU count
     * @return the configured training config
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        SimpleCompositeLoss loss = new SimpleCompositeLoss();
        for (int i = 0; i < CAPTCHA_LENGTH; ++i) {
            loss.addLoss(new SoftmaxCrossEntropyLoss("loss_digit_" + i), i);
        }

        DefaultTrainingConfig config =
                new DefaultTrainingConfig(loss)
                        .setBatchSize(arguments.getBatchSize())
                        .optDevices(Device.getDevices(arguments.getMaxGpus()));

        for (int i = 0; i < CAPTCHA_LENGTH; ++i) {
            config.addEvaluator(new Accuracy("acc_digit_" + i, i));
        }
        return config;
    }

    /**
     * Builds and prepares the captcha dataset for the given split.
     *
     * @param usage which split to load
     * @param arguments parsed arguments supplying batch size and iteration cap
     * @return the prepared dataset, sampled in batches of {@code arguments.getBatchSize()}
     * @throws IOException if preparing the dataset fails
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        CaptchaDataset dataset =
                CaptchaDataset.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        dataset.prepare(new ProgressBar());
        return dataset;
    }

    /**
     * Builds the network: a ResNet-50 whose flat output is split into one head per character.
     *
     * @return the sequential block producing {@value #CAPTCHA_LENGTH} outputs
     */
    private static Block getBlock() {
        Block resnet =
                ResNetV1.builder()
                        .setNumLayers(50)
                        .setImageShape(new Shape(1, IMAGE_HEIGHT, IMAGE_WIDTH))
                        // One logit per (position, class) pair, split apart below.
                        .setOutSize(CAPTCHA_LENGTH * CAPTCHA_OPTIONS)
                        .build();
        return new SequentialBlock().add(resnet).add(TrainCaptcha::splitPerCharacter);
    }

    /**
     * Splits the single ResNet output into one array per character position.
     *
     * <p>The output is reshaped to {@code (-1, CAPTCHA_LENGTH, CAPTCHA_OPTIONS)} and split along
     * axis 1. {@code squeeze(1)} then drops the size-1 axis left by the split, so each element of
     * the result has shape {@code (-1, CAPTCHA_OPTIONS)}.
     *
     * @param resnetOutputList the ResNet output; must contain exactly one array
     * @return a list of {@value #CAPTCHA_LENGTH} arrays, in character-position order
     */
    private static NDList splitPerCharacter(NDList resnetOutputList) {
        NDArray resnetOutput = resnetOutputList.singletonOrThrow();
        NDList splitOutput =
                resnetOutput.reshape(-1, CAPTCHA_LENGTH, CAPTCHA_OPTIONS).split(CAPTCHA_LENGTH, 1);
        NDList output = new NDList(CAPTCHA_LENGTH);
        for (NDArray outputDigit : splitOutput) {
            output.add(outputDigit.squeeze(1));
        }
        return output;
    }
}

