/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.ocr.model.common.direction;

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.common.utils.OpenCVUtils;
import cn.smartjavaai.ocr.config.DirectionModelConfig;
import cn.smartjavaai.ocr.entity.DirectionInfo;
import cn.smartjavaai.ocr.entity.OcrBox;
import cn.smartjavaai.ocr.entity.OcrItem;
import cn.smartjavaai.ocr.enums.AngleEnum;
import cn.smartjavaai.ocr.exception.OcrException;
import cn.smartjavaai.ocr.model.common.detect.OcrCommonDetModel;
import cn.smartjavaai.ocr.model.common.direction.OcrDirectionModel;
import cn.smartjavaai.ocr.model.common.direction.criteria.DirectionCriteriaFactory;
import cn.smartjavaai.ocr.utils.OcrUtils;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import javax.imageio.ImageIO;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Text-direction classifier backed by a DJL {@link ZooModel}.
 *
 * <p>Each text box is cropped out of its source image. Crops that are much taller than wide are
 * rotated first. All crops are then classified in a single batch, and each classifier label is
 * mapped to an {@link AngleEnum} on the resulting {@link OcrItem}.
 *
 * <p>Inference goes through a {@link GenericObjectPool} of {@link Predictor}s. When only an image
 * is supplied, text boxes come from the configured {@link OcrCommonDetModel}.
 */
public class PPOCRMobileV2ClsModel
implements OcrDirectionModel {
    private static final Logger log = LoggerFactory.getLogger(PPOCRMobileV2ClsModel.class);

    /** Height/width ratio above which a crop is rotated before it is classified. */
    private static final double ROTATE_ASPECT_RATIO = 1.5;

    private GenericObjectPool<Predictor<Image, DirectionInfo>> predictorPool;
    private DirectionModelConfig config;
    private ZooModel<Image, DirectionInfo> model;
    private OcrCommonDetModel textDetModel;

    /**
     * Loads the direction model described by {@code config} and builds the predictor pool.
     *
     * @param config model configuration. {@code modelPath} must be non-blank. Its
     *     {@code textDetModel} (possibly {@code null}) is adopted for image-level detection.
     * @throws OcrException if {@code modelPath} is blank or the model cannot be loaded
     */
    @Override
    public void loadModel(DirectionModelConfig config) {
        if (StringUtils.isBlank(config.getModelPath())) {
            throw new OcrException("modelPath is null");
        }
        this.config = config;
        this.textDetModel = config.getTextDetModel();
        Criteria<Image, DirectionInfo> criteria = DirectionCriteriaFactory.createCriteria(config);
        try {
            this.model = ModelZoo.loadModel(criteria);
            this.predictorPool = new GenericObjectPool((PooledObjectFactory)new PredictorFactory(this.model));
            int predictorPoolSize = config.getPredictorPoolSize();
            if (predictorPoolSize <= 0) {
                // A non-positive configured size falls back to one predictor per available processor.
                predictorPoolSize = Runtime.getRuntime().availableProcessors();
            }
            this.predictorPool.setMaxTotal(predictorPoolSize);
            log.debug("\u5f53\u524d\u8bbe\u5907: {}", this.model.getNDManager().getDevice());
            log.debug("\u5f53\u524d\u5f15\u64ce: {}", Engine.getInstance().getEngineName());
            log.debug("\u6a21\u578b\u63a8\u7406\u5668\u7ebf\u7a0b\u6c60\u6700\u5927\u6570\u91cf: {}", predictorPoolSize);
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            throw new OcrException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
    }

    /**
     * Detects text boxes in the image file at {@code imagePath} and classifies their direction.
     *
     * @param imagePath path of an existing image file
     * @return one {@link OcrItem} per detected text box
     * @throws OcrException if the file is missing or unreadable, or detection fails
     */
    @Override
    public List<OcrItem> detect(String imagePath) {
        if (!FileUtils.isFileExists(imagePath)) {
            throw new OcrException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
            return this.detect(img);
        }
        catch (IOException e) {
            throw new OcrException("\u65e0\u6548\u7684\u56fe\u7247", e);
        }
        finally {
            releaseMat(img);
        }
    }

    /**
     * Runs the configured text detector on {@code image}, then classifies each box's direction.
     *
     * <p>Assumes {@code image} wraps an OpenCV {@link Mat}. The caller keeps ownership of the image.
     *
     * @throws OcrException if no text detector is set or no text is found
     */
    @Override
    public List<OcrItem> detect(Image image) {
        if (Objects.isNull(this.textDetModel)) {
            throw new OcrException("textDetModel is null");
        }
        List<OcrBox> boxes = this.textDetModel.detect(image);
        if (Objects.isNull(boxes) || boxes.isEmpty()) {
            throw new OcrException("\u672a\u68c0\u6d4b\u5230\u6587\u672c");
        }
        return this.detect(boxes, (Mat) image.getWrappedImage());
    }

    /**
     * Classifies the direction of pre-computed text boxes on a single source image.
     *
     * @param boxList non-empty list of text boxes located on {@code srcMat}
     * @param srcMat  source image the boxes refer to
     * @return one {@link OcrItem} per box, in input order
     */
    @Override
    public List<OcrItem> detect(List<OcrBox> boxList, Mat srcMat) {
        if (Objects.isNull(boxList) || boxList.isEmpty()) {
            throw new OcrException("boxList\u4e3a\u7a7a");
        }
        List<List<OcrItem>> ocrItemList = this.batchDetect(Collections.singletonList(boxList), Collections.singletonList(srcMat));
        if (Objects.isNull(ocrItemList) || ocrItemList.isEmpty()) {
            throw new OcrException("\u65b9\u5411\u68c0\u6d4b\u5931\u8d25");
        }
        return ocrItemList.get(0);
    }

    /**
     * Detects and classifies text in {@code imagePath}, draws the results, and writes a PNG.
     *
     * @param imagePath  path of an existing input image
     * @param outputPath destination file; always encoded as PNG regardless of extension
     * @throws OcrException on missing input, no detected text, or I/O failure
     */
    @Override
    public void detectAndDraw(String imagePath, String outputPath) {
        if (!FileUtils.isFileExists(imagePath)) {
            throw new OcrException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
            List<OcrItem> itemList = this.detect(img);
            if (Objects.isNull(itemList) || itemList.isEmpty()) {
                throw new OcrException("\u672a\u68c0\u6d4b\u5230\u6587\u5b57");
            }
            OcrUtils.drawRectWithText((Mat) img.getWrappedImage(), itemList);
            Path output = Paths.get(outputPath);
            log.debug("Saving to {}", output.toAbsolutePath().toString());
            // Close the stream so the file handle is released and all bytes are flushed.
            try (OutputStream out = Files.newOutputStream(output)) {
                img.save(out, "png");
            }
        }
        catch (IOException e) {
            throw new OcrException(e);
        }
        finally {
            releaseMat(img);
        }
    }

    /**
     * Detects and classifies text in an AWT image.
     *
     * @throws OcrException if {@code image} is not valid per {@code ImageUtils.isImageValid}
     */
    @Override
    public List<OcrItem> detect(BufferedImage image) {
        if (!ImageUtils.isImageValid(image)) {
            throw new OcrException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image));
        try {
            return this.detect(img);
        }
        finally {
            // Release even when detection throws, so the native buffer is not leaked.
            releaseMat(img);
        }
    }

    /**
     * Decodes {@code imageData} with {@link ImageIO} and detects/classifies text in it.
     *
     * @throws OcrException if the bytes are {@code null}, cannot be decoded, or detection fails
     */
    @Override
    public List<OcrItem> detect(byte[] imageData) {
        if (Objects.isNull(imageData)) {
            throw new OcrException("\u56fe\u50cf\u65e0\u6548");
        }
        BufferedImage image;
        try {
            image = ImageIO.read(new ByteArrayInputStream(imageData));
        }
        catch (IOException e) {
            throw new OcrException("\u9519\u8bef\u7684\u56fe\u50cf", e);
        }
        // ImageIO.read returns null instead of throwing when no registered reader accepts the bytes.
        if (image == null) {
            throw new OcrException("\u56fe\u50cf\u65e0\u6548");
        }
        return this.detect(image);
    }

    /**
     * Detects and classifies text in {@code sourceImage} and returns a copy with results drawn.
     *
     * @return a new image decoded from the annotated PNG rendering
     * @throws OcrException if the input is invalid, no text is detected, or export fails
     */
    @Override
    public BufferedImage detectAndDraw(BufferedImage sourceImage) {
        if (!ImageUtils.isImageValid(sourceImage)) {
            throw new OcrException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(sourceImage));
        try {
            List<OcrItem> ocrItemList = this.detect(img);
            if (Objects.isNull(ocrItemList) || ocrItemList.isEmpty()) {
                throw new OcrException("\u672a\u68c0\u6d4b\u5230\u6587\u5b57");
            }
            OcrUtils.drawRectWithText((Mat) img.getWrappedImage(), ocrItemList);
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            img.save(outputStream, "png");
            return ImageIO.read(new ByteArrayInputStream(outputStream.toByteArray()));
        }
        catch (IOException e) {
            throw new OcrException("\u5bfc\u51fa\u56fe\u7247\u5931\u8d25", e);
        }
        finally {
            // Covers every exit path, including detection failures before export.
            releaseMat(img);
        }
    }

    /**
     * Classifies the direction of text boxes across several source images in one model batch.
     *
     * @param boxList    per-image text boxes; {@code boxList.get(i)} belongs to {@code srcMatList.get(i)}
     * @param srcMatList source images; must be the same length as {@code boxList}
     * @return per-image lists of {@link OcrItem}, parallel to the inputs and in box order
     * @throws OcrException on invalid input or if the classifier returns missing/short results
     */
    @Override
    public List<List<OcrItem>> batchDetect(List<List<OcrBox>> boxList, List<Mat> srcMatList) {
        validateBatchInput(boxList, srcMatList);
        List<Image> imageList = new ArrayList<>();
        List<Boolean> isRotatedList = new ArrayList<>();
        try (NDManager manager = this.model.getNDManager().newSubManager()) {
            // Crop every box and flatten across images, remembering which crops were rotated.
            // NOTE(review): the cropped/rotated images are not explicitly released here;
            // confirm whether OcrUtils.transformAndCrop allocates a new Mat that should be freed.
            for (int i = 0; i < srcMatList.size(); ++i) {
                Mat srcMat = srcMatList.get(i);
                for (OcrBox box : boxList.get(i)) {
                    Image subImg = OcrUtils.transformAndCrop(srcMat, box);
                    boolean rotated = (double) subImg.getHeight() / subImg.getWidth() > ROTATE_ASPECT_RATIO;
                    if (rotated) {
                        subImg = OcrUtils.rotateImg(manager, subImg);
                    }
                    isRotatedList.add(rotated);
                    imageList.add(subImg);
                }
            }
            List<DirectionInfo> directionInfos = this.batchDetect(imageList);
            // Every crop needs exactly one prediction; a short result would otherwise surface
            // as an IndexOutOfBoundsException below.
            if (CollectionUtils.isEmpty(directionInfos) || directionInfos.size() != imageList.size()) {
                throw new OcrException("\u65b9\u5411\u68c0\u6d4b\u5931\u8d25");
            }
            List<List<OcrItem>> result = new ArrayList<>(srcMatList.size());
            int index = 0;
            for (int i = 0; i < srcMatList.size(); ++i) {
                List<OcrBox> boxes = boxList.get(i);
                List<OcrItem> ocrItemList = new ArrayList<>(boxes.size());
                for (int j = 0; j < boxes.size(); ++j, ++index) {
                    DirectionInfo directionInfo = directionInfos.get(index);
                    if (Objects.isNull(directionInfo)) {
                        throw new OcrException("\u65b9\u5411\u68c0\u6d4b\u5931\u8d25: \u7b2c" + i + "\u5f20\u56fe\u7247, \u7b2c" + j + "\u4e2a\u6587\u672c\u5757\uff0c\u672a\u68c0\u6d4b\u5230\u65b9\u5411");
                    }
                    String angle = resolveAngle(isRotatedList.get(index), directionInfo.getName());
                    ocrItemList.add(new OcrItem(boxes.get(j), AngleEnum.fromValue(angle), directionInfo.getProb().floatValue()));
                }
                result.add(ocrItemList);
            }
            return result;
        }
    }

    /** Validates the batch inputs: non-empty, same length, and every element usable. */
    private static void validateBatchInput(List<List<OcrBox>> boxList, List<Mat> srcMatList) {
        if (CollectionUtils.isEmpty(boxList)) {
            throw new OcrException("boxList \u4e0d\u80fd\u4e3a\u7a7a");
        }
        if (CollectionUtils.isEmpty(srcMatList)) {
            throw new OcrException("srcMatList \u4e0d\u80fd\u4e3a\u7a7a");
        }
        if (boxList.size() != srcMatList.size()) {
            throw new OcrException("boxList \u4e0e srcMatList \u6570\u91cf\u4e0d\u4e00\u81f4: " + boxList.size() + " != " + srcMatList.size());
        }
        for (int i = 0; i < srcMatList.size(); ++i) {
            List<OcrBox> ocrBoxes = boxList.get(i);
            if (ocrBoxes == null) {
                throw new OcrException("\u7b2c " + i + " \u4e2a boxList \u4e3a null");
            }
            if (ocrBoxes.isEmpty()) {
                throw new OcrException("\u7b2c " + i + " \u4e2a boxList \u6ca1\u6709\u68c0\u6d4b\u7ed3\u679c");
            }
            Mat mat = srcMatList.get(i);
            if (mat == null || mat.empty()) {
                throw new OcrException("\u7b2c " + i + " \u5f20\u56fe\u7247\u4e3a\u7a7a Mat");
            }
        }
    }

    /**
     * Maps a classifier label to the angle string passed to {@link AngleEnum#fromValue}.
     *
     * @param rotated whether the crop was rotated before classification
     * @param label   classifier label, compared case-insensitively to "Rotate" / "No Rotate"
     * @return "270"/"90" for rotated crops, "0"/"180" otherwise
     */
    private static String resolveAngle(boolean rotated, String label) {
        if (rotated) {
            return label.equalsIgnoreCase("Rotate") ? "270" : "90";
        }
        return label.equalsIgnoreCase("No Rotate") ? "0" : "180";
    }

    /**
     * Runs one batched prediction with a predictor borrowed from the pool.
     *
     * <p>The predictor is always returned to the pool. If returning fails, it is closed instead.
     */
    private List<DirectionInfo> batchDetect(List<Image> imageList) {
        Predictor<Image, DirectionInfo> predictor = null;
        try {
            predictor = this.predictorPool.borrowObject();
            return predictor.batchPredict(imageList);
        }
        catch (Exception e) {
            throw new OcrException("OCR\u68c0\u6d4b\u9519\u8bef", e);
        }
        finally {
            if (predictor != null) {
                try {
                    this.predictorPool.returnObject(predictor);
                }
                catch (Exception e) {
                    log.warn("\u5f52\u8fd8Predictor\u5931\u8d25", e);
                    try {
                        predictor.close();
                    }
                    catch (Exception ex) {
                        log.error("\u5173\u95edPredictor\u5931\u8d25", ex);
                    }
                }
            }
        }
    }

    /** Releases the OpenCV {@link Mat} wrapped by {@code img}; does nothing for {@code null}. */
    private static void releaseMat(Image img) {
        if (img != null) {
            ((Mat) img.getWrappedImage()).release();
        }
    }

    @Override
    public void setTextDetModel(OcrCommonDetModel detModel) {
        this.textDetModel = detModel;
    }

    @Override
    public OcrCommonDetModel getTextDetModel() {
        return this.textDetModel;
    }

    @Override
    public GenericObjectPool<Predictor<Image, DirectionInfo>> getPool() {
        return this.predictorPool;
    }

    /**
     * Closes the predictor pool, then the model.
     *
     * <p>A failure in either step is logged and does not stop the other step.
     */
    @Override
    public void close() throws Exception {
        try {
            if (this.predictorPool != null) {
                this.predictorPool.close();
            }
        }
        catch (Exception e) {
            log.warn("\u5173\u95ed predictorPool \u5931\u8d25", e);
        }
        try {
            if (this.model != null) {
                this.model.close();
            }
        }
        catch (Exception e) {
            log.warn("\u5173\u95ed model \u5931\u8d25", e);
        }
    }
}

