/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.inference.benchmark;

import ai.djl.examples.inference.benchmark.util.AbstractBenchmark;
import ai.djl.examples.inference.benchmark.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.listener.MemoryTrainingListener;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Benchmark that classifies one image repeatedly on a fixed-size thread pool.
 *
 * <p>Each call to {@link #predict} submits one task to the pool. Every task creates its own
 * {@link Predictor}, runs a single prediction on {@link #img}, and closes that predictor. The
 * number of created and successfully finished tasks is tracked so that {@link #clean()} can
 * report failures.
 */
public class MultithreadedBenchmark extends AbstractBenchmark<BufferedImage, Classifications> {
    private static final Logger logger = LoggerFactory.getLogger(MultithreadedBenchmark.class);

    /** Image that every prediction task classifies; loaded once in {@link #initialize}. */
    BufferedImage img;
    /** Size of the worker pool, as returned by {@code Arguments.getThreads()}. */
    int numOfThreads;
    /** Number of prediction tasks created so far; also the source of each task's worker id. */
    AtomicInteger callableNumber;
    /** Number of prediction tasks that finished without throwing. */
    AtomicInteger successThreads;
    /** Pool that runs the prediction tasks; created in initialize, shut down in clean. */
    ExecutorService executorService;

    public MultithreadedBenchmark() {
        super(BufferedImage.class, Classifications.class);
    }

    /**
     * Command-line entry point. Exits with status 0 when the benchmark succeeds, -1 otherwise.
     *
     * @param args benchmark command-line arguments, handed to {@code runBenchmark}
     */
    public static void main(String[] args) {
        boolean success = new MultithreadedBenchmark().runBenchmark(args);
        System.exit(success ? 0 : -1);
    }

    /**
     * Loads the input image, resets the task counters, records the thread count as a metric and
     * creates the worker pool.
     *
     * @param model the model being benchmarked (unused here)
     * @param arguments parsed benchmark arguments (image file and thread count are read)
     * @param metrics metrics sink; a {@code "thread"} metric is added
     * @throws IOException if the image file cannot be read
     */
    @Override
    protected void initialize(
            ZooModel<BufferedImage, Classifications> model, Arguments arguments, Metrics metrics)
            throws IOException {
        Path imageFile = arguments.getImageFile();
        img = BufferedImageUtils.fromFile(imageFile);
        numOfThreads = arguments.getThreads();
        callableNumber = new AtomicInteger();
        successThreads = new AtomicInteger();
        logger.info("Multithreaded inference with {} threads.", numOfThreads);
        metrics.addMetric("thread", (Number) numOfThreads);
        // Assigned last on purpose: clean() treats a non-null pool as proof that every field
        // above has been initialized.
        executorService = Executors.newFixedThreadPool(numOfThreads);
    }

    /**
     * Submits one prediction task to the worker pool.
     *
     * <p>The task's predictor is created (and its worker id assigned) on the calling thread; the
     * prediction itself runs on the pool.
     *
     * @return a future completed with the classification, or with {@code null} if the task failed
     */
    @Override
    protected CompletableFuture<Classifications> predict(
            ZooModel<BufferedImage, Classifications> model, Arguments arguments, Metrics metrics) {
        PredictorSupplier task = new PredictorSupplier(model, metrics);
        return CompletableFuture.supplyAsync(task, executorService);
    }

    /**
     * Shuts down the worker pool and logs an error if some tasks did not finish successfully.
     * Does nothing if {@link #initialize} never completed.
     */
    @Override
    protected void clean() {
        if (executorService == null) {
            return;
        }
        executorService.shutdown();
        // NOTE(review): shutdown() does not wait for running tasks, so these counts are only final
        // if the caller has already waited on every future returned by predict() -- confirm in
        // AbstractBenchmark.
        int finished = successThreads.get();
        int created = callableNumber.get();
        if (finished != created) {
            logger.error("Only {}/{} threads finished.", finished, created);
        }
    }

    /** Adds the {@code -t/--threads} option on top of the common benchmark options. */
    @Override
    protected Options getOptions() {
        Options options = super.getOptions();
        options.addOption(
                Option.builder("t")
                        .longOpt("threads")
                        .hasArg()
                        .argName("NUMBER_THREADS")
                        .desc("Number of inference threads.")
                        .build());
        return options;
    }

    /**
     * A single prediction task. It owns a dedicated {@link Predictor}, runs one prediction on the
     * benchmark image, and always closes the predictor afterwards.
     */
    private class PredictorSupplier implements Supplier<Classifications> {
        private final Predictor<BufferedImage, Classifications> predictor;
        private final Metrics metrics;
        private final String workerId;
        /** Only the first task created (worker id 00) collects memory information. */
        private final boolean collectMemory;

        /**
         * Creates a task with a fresh predictor and the next worker id.
         *
         * @param model model used to create this task's predictor
         * @param metrics metrics sink attached to the predictor
         */
        public PredictorSupplier(ZooModel<BufferedImage, Classifications> model, Metrics metrics) {
            this.predictor = model.newPredictor();
            this.metrics = metrics;
            int iteration = callableNumber.getAndIncrement();
            this.workerId = String.format("%02d", iteration);
            this.collectMemory = iteration == 0;
            this.predictor.setMetrics(metrics);
        }

        /**
         * Runs the prediction and closes the predictor.
         *
         * @return the classification result, or {@code null} if predicting or closing the
         *     predictor threw (the failure is logged)
         */
        @Override
        public Classifications get() {
            Classifications result;
            // try-with-resources closes the predictor on the failure path too, so a throwing
            // predict() no longer leaks it.
            try (Predictor<BufferedImage, Classifications> p = predictor) {
                result = p.predict(img);
                if (collectMemory) {
                    MemoryTrainingListener.collectMemoryInfo(metrics);
                }
                logger.debug("Worker-{}: finished.", workerId);
            } catch (Exception e) {
                logger.error("Failed to classify with worker {}", workerId, e);
                return null;
            }
            // Counted only once the predictor has closed cleanly.
            successThreads.incrementAndGet();
            return result;
        }
    }
}

