/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.inference.benchmark.util;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.inference.benchmark.MultithreadedBenchmark;
import ai.djl.examples.inference.benchmark.util.Arguments;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for inference benchmarks.
 *
 * <p>{@link #runBenchmark} parses the command line, loads the model, repeatedly calls {@link
 * #predict} either for a fixed number of iterations or for a fixed wall-clock duration, waits for
 * every submitted prediction and finally logs the collected metrics. Subclasses supply {@link
 * #initialize}, {@link #predict} and {@link #clean}.
 *
 * <p>Per-run state is kept in instance fields that are mutated without synchronization, so a
 * single instance should not run several benchmarks concurrently.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public abstract class AbstractBenchmark<I, O> {
    private static final Logger logger = LoggerFactory.getLogger(AbstractBenchmark.class);

    /** Nanoseconds per millisecond, as a float so divisions yield fractional milliseconds. */
    private static final float NANOS_PER_MILLI = 1_000_000f;

    /** Value of {@link #maxIterations} meaning "run for a duration instead of a count". */
    private static final int RUN_BY_DURATION = -1;

    private final Class<I> input;
    private final Class<O> output;
    private O lastResult;
    protected ProgressBar progressBar;
    protected int maxIterations;
    protected int iterationCount;

    /**
     * Creates a benchmark for a model with the given input and output classes.
     *
     * @param input the model input class, forwarded to {@link #loadModel}
     * @param output the model output class, forwarded to {@link #loadModel}
     */
    public AbstractBenchmark(Class<I> input, Class<O> output) {
        this.input = input;
        this.output = output;
    }

    /**
     * Returns the result of the last prediction awaited by the most recent run.
     *
     * @return the last prediction result, or {@code null} if no prediction has completed
     */
    public O getPredictResult() {
        return lastResult;
    }

    /**
     * Runs the benchmark described by the given command-line arguments.
     *
     * @param args the command-line arguments
     * @return {@code true} if the benchmark completed and its results were logged; {@code false}
     *     if the arguments could not be parsed (usage is printed) or the run failed (error logged)
     */
    public final boolean runBenchmark(String[] args) {
        Options options = getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            Arguments arguments = parseArguments(cmd);

            long init = System.nanoTime();
            String version = Engine.getInstance().getVersion();
            long loaded = System.nanoTime();
            logger.info(
                    String.format("Load library %s in %.3f ms.", version, toMillis(loaded - init)));

            // Reset per-run state so that repeated calls on one instance start from scratch.
            iterationCount = 0;
            lastResult = null;
            maxIterations = arguments.getIteration();
            if (this instanceof MultithreadedBenchmark) {
                // NOTE(review): this also replaces the RUN_BY_DURATION sentinel (-1) with a
                // positive count, forcing multithreaded benchmarks into iteration mode — confirm.
                maxIterations = Math.max(maxIterations, arguments.getThreads() * 2);
            }
            Duration duration = Duration.ofMinutes(arguments.getDuration());
            if (runByIterations()) {
                logger.info(
                        "Running {} on: {}, iterations: {}.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        maxIterations);
                progressBar = new ProgressBar("Iteration", maxIterations);
            } else {
                logger.info(
                        "Running {} on: {}, duration: {} minutes.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        duration.toMinutes());
                // In duration mode the progress bar tracks elapsed milliseconds.
                progressBar = new ProgressBar("Iteration", duration.toMillis());
            }

            Metrics metrics = new Metrics();
            long begin = System.currentTimeMillis();
            long totalTime;
            try (ZooModel<I, O> model = loadModel(arguments, metrics, input, output)) {
                initialize(model, arguments, metrics);
                runPredictions(model, arguments, metrics, duration, begin);
                totalTime = System.currentTimeMillis() - begin;
            } catch (Exception e) {
                logger.error("Failed to run benchmark", e);
                throw e;
            } finally {
                // Runs after the model has been closed by try-with-resources.
                clean();
            }
            recordResults(arguments, metrics, totalTime);
            return true;
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        } catch (Throwable t) {
            if (t instanceof InterruptedException) {
                // Interrupted while awaiting a prediction: preserve the interrupt for the caller.
                Thread.currentThread().interrupt();
            }
            logger.error("Unexpected error", t);
        }
        return false;
    }

    /**
     * Prepares the benchmark after the model is loaded and before the first {@link #predict}.
     *
     * @param model the loaded model
     * @param arguments the parsed benchmark arguments
     * @param metrics the metrics collected for this run
     * @throws IOException if initialization fails
     */
    protected abstract void initialize(ZooModel<I, O> model, Arguments arguments, Metrics metrics)
            throws IOException;

    /**
     * Submits one prediction; called once per iteration. All returned futures are awaited after
     * the prediction loop ends.
     *
     * @param model the loaded model
     * @param arguments the parsed benchmark arguments
     * @param metrics the metrics collected for this run
     * @return a future holding the prediction result
     * @throws TranslateException if the prediction cannot be submitted
     */
    protected abstract CompletableFuture<O> predict(
            ZooModel<I, O> model, Arguments arguments, Metrics metrics) throws TranslateException;

    /** Releases per-run resources; always called at the end of a run, successful or not. */
    protected abstract void clean();

    /**
     * Returns the command-line options this benchmark accepts.
     *
     * @return the options; by default {@link Arguments#getOptions()}
     */
    protected Options getOptions() {
        return Arguments.getOptions();
    }

    /**
     * Builds the benchmark arguments from the parsed command line.
     *
     * @param cmd the parsed command line
     * @return the benchmark arguments
     */
    protected Arguments parseArguments(CommandLine cmd) {
        return new Arguments(cmd);
    }

    /**
     * Loads the image-classification model selected by the arguments and records the elapsed
     * load time, in nanoseconds, under the {@code "LoadModel"} metric.
     *
     * @param arguments the parsed benchmark arguments; the model loader defaults to {@code
     *     "resnet"} when no model name is given
     * @param metrics the metrics that receive the load time
     * @param input the model input class
     * @param output the model output class
     * @return the loaded model; the caller is responsible for closing it
     * @throws ModelException if the model cannot be found or loaded
     * @throws IOException if reading the model fails
     */
    protected ZooModel<I, O> loadModel(
            Arguments arguments, Metrics metrics, Class<I> input, Class<O> output)
            throws ModelException, IOException {
        long begin = System.nanoTime();
        Criteria.Builder<I, O> builder =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(input, output)
                        .optFilters(arguments.getCriteria())
                        .optProgress(new ProgressBar());
        String modelName = arguments.getModelName();
        if (modelName == null) {
            modelName = "resnet";
        }
        builder.optModelLoaderName(modelName);
        ZooModel<I, O> model = ModelZoo.loadModel(builder.build());
        long delta = System.nanoTime() - begin;
        logger.info(
                "Model {} loaded in: {} ms.",
                model.getName(),
                String.format("%.3f", toMillis(delta)));
        metrics.addMetric("LoadModel", delta);
        return model;
    }

    /**
     * Submits predictions until {@link #keepPredicting} says to stop, then awaits every future in
     * submission order, storing each result in {@link #lastResult}.
     */
    private void runPredictions(
            ZooModel<I, O> model,
            Arguments arguments,
            Metrics metrics,
            Duration duration,
            long begin)
            throws TranslateException, InterruptedException, ExecutionException {
        List<CompletableFuture<O>> predictResults = new ArrayList<>();
        while (keepPredicting(duration, begin)) {
            ++iterationCount;
            predictResults.add(predict(model, arguments, metrics));
            updateProgress(progressBar, begin);
        }
        for (CompletableFuture<O> future : predictResults) {
            lastResult = future.get();
        }
    }

    /** Returns whether the run is bounded by an iteration count rather than a duration. */
    private boolean runByIterations() {
        return maxIterations != RUN_BY_DURATION;
    }

    /** Returns whether another prediction should be submitted. */
    private boolean keepPredicting(Duration duration, long startTime) {
        if (runByIterations()) {
            return iterationCount < maxIterations;
        }
        return System.currentTimeMillis() - startTime < duration.toMillis();
    }

    /** Advances the progress bar by iterations or by elapsed milliseconds, matching the mode. */
    private void updateProgress(ProgressBar progressBar, long startTime) {
        if (runByIterations()) {
            progressBar.update(iterationCount);
        } else {
            progressBar.update(System.currentTimeMillis() - startTime);
        }
    }

    /** Logs the run summary, latency percentiles and, optionally, memory statistics. */
    private void recordResults(Arguments arguments, Metrics metrics, long totalTime) {
        logger.info("Last inference result: {}", lastResult);
        logger.info(
                String.format(
                        "total time: %d ms, total runs: %d iterations",
                        totalTime, iterationCount));
        if (metrics.hasMetric("LoadModel")) {
            long loadModelTime = metrics.getMetric("LoadModel").get(0).getValue().longValue();
            logger.info(
                    "Model loading time: {} ms.", String.format("%.3f", toMillis(loadModelTime)));
        }
        // Gate on iterations actually executed: maxIterations is the -1 sentinel in duration
        // mode and would otherwise suppress the report for every duration-bounded run.
        if (metrics.hasMetric("Inference") && iterationCount > 1) {
            logLatency(metrics, "Inference", "inference");
            logLatency(metrics, "Preprocess", "preprocess");
            logLatency(metrics, "Postprocess", "postprocess");
            if (Boolean.getBoolean("collect-memory")) {
                logP90(metrics, "Heap", "heap");
                logP90(metrics, "NonHeap", "nonHeap");
                logP90(metrics, "cpu", "cpu");
                logP90(metrics, "rss", "rss");
            }
            MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());
        }
    }

    /**
     * Logs the P50/P90/P99 of a metric, converted from nanoseconds to milliseconds. Skips metrics
     * that were not recorded instead of failing the whole report.
     */
    private static void logLatency(Metrics metrics, String metricName, String label) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = toMillis(metrics.percentile(metricName, 50).getValue().longValue());
        float p90 = toMillis(metrics.percentile(metricName, 90).getValue().longValue());
        float p99 = toMillis(metrics.percentile(metricName, 99).getValue().longValue());
        logger.info(
                String.format(
                        "%s P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", label, p50, p90, p99));
    }

    /** Logs the raw P90 of a metric, skipping metrics that were not recorded. */
    private static void logP90(Metrics metrics, String metricName, String label) {
        if (metrics.hasMetric(metricName)) {
            float value = metrics.percentile(metricName, 90).getValue().longValue();
            logger.info(String.format("%s P90: %.3f", label, value));
        }
    }

    /** Converts a nanosecond count to fractional milliseconds. */
    private static float toMillis(long nanos) {
        return nanos / NANOS_PER_MILLI;
    }
}

