/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.objectdetection.model;

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import cn.smartjavaai.common.entity.DetectionResponse;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.common.utils.OpenCVUtils;
import cn.smartjavaai.objectdetection.config.DetectorModelConfig;
import cn.smartjavaai.objectdetection.criteria.CriteriaBuilderFactory;
import cn.smartjavaai.objectdetection.exception.DetectionException;
import cn.smartjavaai.objectdetection.utils.DetectorUtils;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.imageio.ImageIO;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Object-detection model wrapper around a DJL {@link ZooModel}.
 *
 * <p>Call {@link #loadModel(DetectorModelConfig)} first. The {@code detect} / {@code detectAndDraw}
 * overloads then borrow a {@link Predictor} from an internal commons-pool
 * {@link GenericObjectPool}, run inference, and post-process the result according to the loaded
 * {@link DetectorModelConfig}. Post-processing means allow-list filtering by class name, sorting by
 * descending probability, and top-K truncation. Call {@link #close()} to release the predictor pool
 * and the model.
 */
public class DetectorModel implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(DetectorModel.class);
    private ZooModel<Image, DetectedObjects> model;
    private GenericObjectPool<Predictor<Image, DetectedObjects>> predictorPool;
    private DetectorModelConfig config;

    /**
     * Loads the model described by {@code config} and creates the predictor pool.
     *
     * <p>The new model is loaded before anything already held by this instance is touched. Only after
     * loading succeeds are a previously loaded pool and model closed and replaced. Calling this
     * method a second time therefore does not leak the earlier model.
     *
     * @param config model configuration; must not be {@code null} and must have a model enum set
     * @throws NullPointerException if {@code config} is {@code null}
     * @throws DetectionException if no model enum is configured or the model cannot be loaded
     */
    public void loadModel(DetectorModelConfig config) {
        Objects.requireNonNull(config, "config");
        if (Objects.isNull(config.getModelEnum())) {
            // "model enum not configured"
            throw new DetectionException("\u672a\u914d\u7f6e\u6a21\u578b\u679a\u4e3e");
        }
        Criteria<Image, DetectedObjects> criteria = CriteriaBuilderFactory.createCriteria(config);
        ZooModel<Image, DetectedObjects> loadedModel;
        try {
            loadedModel = criteria.loadModel();
        } catch (MalformedModelException | ModelNotFoundException | IOException e) {
            // "model load failed"
            throw new DetectionException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
        // PredictorFactory's generic signature is not visible here, so the raw construction
        // from the original code is kept and the unchecked assignment is suppressed locally.
        @SuppressWarnings({"unchecked", "rawtypes"})
        GenericObjectPool<Predictor<Image, DetectedObjects>> loadedPool =
                new GenericObjectPool((PooledObjectFactory) new PredictorFactory(loadedModel));
        int predictorPoolSize = config.getPredictorPoolSize();
        if (predictorPoolSize <= 0) {
            // A non-positive configured size falls back to the number of available processors.
            predictorPoolSize = Runtime.getRuntime().availableProcessors();
        }
        loadedPool.setMaxTotal(predictorPoolSize);

        // Swap in the new resources only now that they were created successfully.
        releaseResources();
        this.config = config;
        this.model = loadedModel;
        this.predictorPool = loadedPool;

        log.debug("\u5f53\u524d\u8bbe\u5907: {}", loadedModel.getNDManager().getDevice());
        log.debug("\u5f53\u524d\u5f15\u64ce: {}", Engine.getInstance().getEngineName());
        log.debug("\u6a21\u578b\u63a8\u7406\u5668\u7ebf\u7a0b\u6c60\u6700\u5927\u6570\u91cf: {}", predictorPoolSize);
    }

    /**
     * Runs detection on the image file at {@code imagePath}.
     *
     * @param imagePath path of an existing image file
     * @return the filtered detections converted by {@link DetectorUtils#convertToDetectionResponse}
     * @throws DetectionException if the file does not exist, cannot be read, or detection fails
     */
    public DetectionResponse detect(String imagePath) {
        if (!FileUtils.isFileExists(imagePath)) {
            // "image file does not exist"
            throw new DetectionException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        Image image = null;
        try {
            image = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
            DetectedObjects detectedObjects = this.detect(image);
            return DetectorUtils.convertToDetectionResponse(detectedObjects, image);
        } catch (DetectionException e) {
            // Already carries a meaningful message; do not wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new DetectionException(e);
        } finally {
            releaseImage(image);
        }
    }

    /**
     * Runs detection on the image file at {@code imagePath}, draws the bounding boxes, and writes
     * the result to {@code outputPath}.
     *
     * <p>The output is always encoded as PNG, regardless of the extension of {@code outputPath}.
     *
     * @param imagePath path of an existing image file
     * @param outputPath destination file; it is created or overwritten
     * @throws DetectionException if the input does not exist, I/O fails, or detection fails
     */
    public void detectAndDraw(String imagePath, String outputPath) {
        if (!FileUtils.isFileExists(imagePath)) {
            throw new DetectionException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
            DetectedObjects detectedObjects = this.detect(img);
            img.drawBoundingBoxes(detectedObjects);
            // try-with-resources so the output file handle is always closed.
            try (OutputStream out =
                    new FileOutputStream(Paths.get(outputPath).toAbsolutePath().toString())) {
                img.save(out, "png");
            }
        } catch (IOException e) {
            throw new DetectionException(e);
        } finally {
            releaseImage(img);
        }
    }

    /**
     * Decodes {@code imageData} with {@link ImageIO} and runs detection on it.
     *
     * @param imageData encoded image bytes (any format readable by {@link ImageIO})
     * @return the detection response
     * @throws DetectionException if {@code imageData} is {@code null}, cannot be decoded, or
     *     detection fails
     */
    public DetectionResponse detect(byte[] imageData) {
        if (Objects.isNull(imageData)) {
            // "invalid image"
            throw new DetectionException("\u56fe\u50cf\u65e0\u6548");
        }
        BufferedImage image;
        try {
            image = ImageIO.read(new ByteArrayInputStream(imageData));
        } catch (IOException e) {
            // "bad image"
            throw new DetectionException("\u9519\u8bef\u7684\u56fe\u50cf", e);
        }
        if (image == null) {
            // ImageIO.read returns null when no registered reader recognises the data.
            throw new DetectionException("\u9519\u8bef\u7684\u56fe\u50cf");
        }
        return this.detect(image);
    }

    /**
     * Runs detection on an in-memory {@link BufferedImage}.
     *
     * @param image the source image; must pass {@link ImageUtils#isImageValid}
     * @return the detection response
     * @throws DetectionException if the image is invalid or detection fails
     */
    public DetectionResponse detect(BufferedImage image) {
        if (!ImageUtils.isImageValid(image)) {
            throw new DetectionException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image));
            DetectedObjects detectedObjects = this.detect(img);
            return DetectorUtils.convertToDetectionResponse(detectedObjects, img);
        } catch (DetectionException e) {
            throw e;
        } catch (Exception e) {
            throw new DetectionException(e);
        } finally {
            releaseImage(img);
        }
    }

    /**
     * Runs detection on {@code sourceImage} and returns a copy with the bounding boxes drawn.
     *
     * @param sourceImage the source image; must pass {@link ImageUtils#isImageValid}
     * @return a new image, decoded from a PNG encoding of the annotated image
     * @throws DetectionException if the image is invalid, detection fails, or re-encoding fails
     */
    public BufferedImage detectAndDraw(BufferedImage sourceImage) {
        if (!ImageUtils.isImageValid(sourceImage)) {
            throw new DetectionException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = null;
        try {
            // Conversion and detection happen inside the try so the Mat is released on any failure.
            img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(sourceImage));
            DetectedObjects detectedObjects = this.detect(img);
            img.drawBoundingBoxes(detectedObjects);
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            img.save(outputStream, "png");
            return ImageIO.read(new ByteArrayInputStream(outputStream.toByteArray()));
        } catch (IOException e) {
            // "failed to export image"
            throw new DetectionException("\u5bfc\u51fa\u56fe\u7247\u5931\u8d25", e);
        } finally {
            releaseImage(img);
        }
    }

    /**
     * Runs the model on a DJL {@link Image} and applies the configured post-processing.
     *
     * <p>A predictor is borrowed from the pool for the duration of the call and then returned to
     * the same pool.
     *
     * @param image the image to run inference on
     * @return the filtered, probability-sorted, top-K-truncated detections
     * @throws DetectionException if no model is loaded or inference fails
     */
    public DetectedObjects detect(Image image) {
        GenericObjectPool<Predictor<Image, DetectedObjects>> pool = this.predictorPool;
        if (pool == null) {
            // "model not loaded" - loadModel() was never called, or close() was.
            throw new DetectionException("\u6a21\u578b\u672a\u52a0\u8f7d");
        }
        Predictor<Image, DetectedObjects> predictor = null;
        try {
            predictor = pool.borrowObject();
            return this.filterDetections(predictor.predict(image));
        } catch (Exception e) {
            // "object detection error"
            throw new DetectionException("\u76ee\u6807\u68c0\u6d4b\u9519\u8bef", e);
        } finally {
            if (predictor != null) {
                returnPredictor(pool, predictor);
            }
        }
    }

    /**
     * Returns {@code predictor} to {@code pool}. If that fails, the predictor is closed instead,
     * so it is not leaked.
     */
    private static void returnPredictor(GenericObjectPool<Predictor<Image, DetectedObjects>> pool,
                                        Predictor<Image, DetectedObjects> predictor) {
        try {
            pool.returnObject(predictor);
            log.debug("\u91ca\u653e\u8d44\u6e90");
        } catch (Exception e) {
            log.warn("\u5f52\u8fd8Predictor\u5931\u8d25", e);
            try {
                predictor.close();
            } catch (Exception ex) {
                log.error("\u5173\u95edPredictor\u5931\u8d25", ex);
            }
        }
    }

    /**
     * Applies the configured post-processing to raw detections.
     *
     * <p>Steps:
     * <ol>
     *   <li>Keep only allowed class names, when an allow-list is configured.</li>
     *   <li>Sort by descending probability.</li>
     *   <li>Truncate to top-K, when top-K is positive.</li>
     * </ol>
     *
     * @param detectedObjects raw model output; may be {@code null}
     * @return the input unchanged if it is {@code null} or empty, otherwise a new
     *     {@link DetectedObjects}
     */
    private DetectedObjects filterDetections(DetectedObjects detectedObjects) {
        if (Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0) {
            return detectedObjects;
        }
        List<DetectedObjects.DetectedObject> items = detectedObjects.items();
        // Work on a copy so sorting / removal never touches the list handed out by items().
        List<DetectedObjects.DetectedObject> filtered = new ArrayList<>(items);
        if (!CollectionUtils.isEmpty(this.config.getAllowedClasses())) {
            filtered.removeIf(obj -> !this.config.getAllowedClasses().contains(obj.getClassName()));
        }
        filtered.sort((o1, o2) -> Double.compare(o2.getProbability(), o1.getProbability()));
        if (this.config.getTopK() > 0 && filtered.size() > this.config.getTopK()) {
            filtered = filtered.subList(0, this.config.getTopK());
        }
        List<String> names = new ArrayList<>(filtered.size());
        List<Double> probs = new ArrayList<>(filtered.size());
        List<BoundingBox> boxes = new ArrayList<>(filtered.size());
        for (DetectedObjects.DetectedObject obj : filtered) {
            names.add(obj.getClassName());
            probs.add(obj.getProbability());
            boxes.add(obj.getBoundingBox());
        }
        return new DetectedObjects(names, probs, boxes);
    }

    /**
     * Releases the native OpenCV {@link Mat} backing {@code image}, if there is one.
     *
     * <p>Images not wrapped around a {@link Mat} hold no native buffer here and are skipped rather
     * than failing with a {@link ClassCastException}.
     */
    private static void releaseImage(Image image) {
        if (image != null) {
            Object wrapped = image.getWrappedImage();
            if (wrapped instanceof Mat) {
                ((Mat) wrapped).release();
            }
        }
    }

    /**
     * Returns the predictor pool.
     *
     * @return the predictor pool, or {@code null} if no model is loaded (or after {@link #close()})
     */
    public GenericObjectPool<Predictor<Image, DetectedObjects>> getPool() {
        return this.predictorPool;
    }

    /**
     * Closes the predictor pool and the model.
     *
     * <p>Failures are logged, not thrown. Calling this more than once is harmless.
     */
    @Override
    public void close() {
        releaseResources();
    }

    /**
     * Closes and clears the current pool and model.
     *
     * <p>This is shared by {@link #close()} and {@link #loadModel}, so {@code loadModel} never
     * calls an overridable method.
     */
    private void releaseResources() {
        try {
            if (this.predictorPool != null) {
                this.predictorPool.close();
            }
        } catch (Exception e) {
            log.warn("\u5173\u95ed predictorPool \u5931\u8d25", e);
        } finally {
            this.predictorPool = null;
        }
        try {
            if (this.model != null) {
                this.model.close();
            }
        } catch (Exception e) {
            log.warn("\u5173\u95ed model \u5931\u8d25", e);
        } finally {
            this.model = null;
        }
    }
}

