/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.translation.model;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Translator;
import cn.smartjavaai.common.entity.R;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.common.pool.CommonPredictorFactory;
import cn.smartjavaai.translation.config.NllbSearchConfig;
import cn.smartjavaai.translation.config.TranslationModelConfig;
import cn.smartjavaai.translation.entity.TranslateParam;
import cn.smartjavaai.translation.exception.TranslationException;
import cn.smartjavaai.translation.model.TranslationModel;
import cn.smartjavaai.translation.model.translator.NllbDecoder2Translator;
import cn.smartjavaai.translation.model.translator.NllbDecoderTranslator;
import cn.smartjavaai.translation.model.translator.NllbEncoderTranslator;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * NLLB translation model served through DJL with the {@code "PyTorch"} engine.
 *
 * <p>One {@link ZooModel} is loaded from {@code config.getModelPath()}. Three predictor pools are
 * built on top of it, one each for the encoder, decoder and second decoder translators. A
 * HuggingFace tokenizer is read from {@code tokenizer.json} in the same directory as the model path.
 *
 * <p>NOTE(review): {@link #translate(TranslateParam)} writes per-request language ids into the
 * shared {@code searchConfig} field, so concurrent calls on one instance may interfere. Confirm
 * how {@code translateLanguage} consumes it before sharing an instance across threads.
 */
public class NllbModel
implements TranslationModel {
    private static final Logger log = LoggerFactory.getLogger(NllbModel.class);
    /** Tokenizer file name, resolved against the directory that contains the model path. */
    private static final String TOKENIZER_FILE = "tokenizer.json";
    private GenericObjectPool<Predictor<?, ?>> encodePredictorPool;
    private GenericObjectPool<Predictor<?, ?>> decodePredictorPool;
    private GenericObjectPool<Predictor<?, ?>> decode2PredictorPool;
    private ZooModel<NDList, NDList> nllbModel;
    private HuggingFaceTokenizer tokenizer;
    private NllbSearchConfig searchConfig;
    private TranslationModelConfig config;

    /**
     * Loads the model, the three predictor pools and the tokenizer.
     *
     * <p>Fields are only assigned after every resource has loaded. If anything fails, the
     * resources created so far are closed before the exception propagates.
     *
     * @param config model configuration; {@code modelPath} must be non-blank
     * @throws TranslationException if the config or model path is missing, or loading fails
     */
    @Override
    public void loadModel(TranslationModelConfig config) {
        if (config == null || StringUtils.isBlank(config.getModelPath())) {
            throw new TranslationException("modelPath is null");
        }
        Device device = null;
        if (config.getDevice() != null) {
            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu((int) config.getGpuId());
        }
        this.config = config;
        Path modelPath = Paths.get(config.getModelPath());
        // toAbsolutePath() guarantees a parent directory even for a bare file name such as
        // "model.pt", where Path.getParent() would return null.
        Path tokenizerPath = modelPath.toAbsolutePath().getParent().resolve(TOKENIZER_FILE);
        int configuredPoolSize = config.getPredictorPoolSize();
        int predictorPoolSize = configuredPoolSize > 0
                ? configuredPoolSize
                : Runtime.getRuntime().availableProcessors();
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optModelPath(modelPath)
                .optEngine("PyTorch")
                .optDevice(device)
                .optTranslator(new NoopTranslator())
                .build();

        ZooModel<NDList, NDList> model = null;
        GenericObjectPool<Predictor<?, ?>> encodePool = null;
        GenericObjectPool<Predictor<?, ?>> decodePool = null;
        GenericObjectPool<Predictor<?, ?>> decode2Pool = null;
        HuggingFaceTokenizer loadedTokenizer = null;
        boolean loaded = false;
        try {
            model = ModelZoo.loadModel(criteria);
            encodePool = createPredictorPool(model, (NoBatchifyTranslator) new NllbEncoderTranslator(), predictorPoolSize);
            decodePool = createPredictorPool(model, (NoBatchifyTranslator) new NllbDecoderTranslator(), predictorPoolSize);
            decode2Pool = createPredictorPool(model, (NoBatchifyTranslator) new NllbDecoder2Translator(), predictorPoolSize);
            loadedTokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath);
            loaded = true;
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            throw new TranslationException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
        finally {
            if (!loaded) {
                // Partial failure: release what was created. Pools go first because their
                // factories were built from the model.
                closeQuietly("encodePredictorPool", closerOf(encodePool));
                closeQuietly("decodePredictorPool", closerOf(decodePool));
                closeQuietly("decode2PredictorPool", closerOf(decode2Pool));
                closeQuietly("tokenizer", loadedTokenizer);
                closeQuietly("model", model);
            }
        }
        this.nllbModel = model;
        this.encodePredictorPool = encodePool;
        this.decodePredictorPool = decodePool;
        this.decode2PredictorPool = decode2Pool;
        this.tokenizer = loadedTokenizer;
        this.searchConfig = new NllbSearchConfig();
        log.debug("\u5f53\u524d\u8bbe\u5907: {}", this.nllbModel.getNDManager().getDevice());
        log.debug("\u5f53\u524d\u5f15\u64ce: {}", Engine.getInstance().getEngineName());
        log.debug("\u6a21\u578b\u63a8\u7406\u5668\u7ebf\u7a0b\u6c60\u6700\u5927\u6570\u91cf: {}", predictorPoolSize);
    }

    /**
     * Creates a predictor pool backed by {@link CommonPredictorFactory} for the given translator.
     *
     * @param model the loaded model shared by all pools
     * @param translator translator handed to the factory
     * @param maxTotal maximum number of predictors the pool may hold
     * @return the configured pool
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static GenericObjectPool<Predictor<?, ?>> createPredictorPool(
            ZooModel<NDList, NDList> model, NoBatchifyTranslator translator, int maxTotal) {
        GenericObjectPool<Predictor<?, ?>> pool =
                new GenericObjectPool((PooledObjectFactory) new CommonPredictorFactory(model, translator));
        pool.setMaxTotal(maxTotal);
        return pool;
    }

    /**
     * Validates the request, sets the source/target language ids on the shared search config and
     * delegates to {@code translateLanguage}.
     *
     * @param translateParam translation request
     * @return {@code PARAM_ERROR} for a null request, the validation failure if validation fails,
     *     otherwise an ok result wrapping the translated text
     * @throws TranslationException if {@link #loadModel} has not completed successfully
     */
    @Override
    public R<String> translate(TranslateParam translateParam) {
        if (translateParam == null) {
            return R.fail(R.Status.PARAM_ERROR);
        }
        R<String> validateResult = translateParam.validate();
        if (!validateResult.isSuccess()) {
            return validateResult;
        }
        if (this.searchConfig == null) {
            throw new TranslationException("model is not loaded, call loadModel first");
        }
        // NOTE(review): shared mutable state. Concurrent translate() calls overwrite each other's
        // language ids here. Confirm whether translateLanguage reads this.searchConfig.
        this.searchConfig.setSrcLangId(translateParam.getSourceLanguage().getId());
        this.searchConfig.setForcedBosTokenId(translateParam.getTargetLanguage().getId());
        return R.ok(this.translateLanguage(translateParam));
    }

    /*
     * CFR 0.152 failed to decompile this method (ConfusedCFRException: "Started 2 blocks at once").
     * The implementation must be restored from the original source or bytecode. Until then every
     * call throws IllegalStateException, so translate() fails for every valid request.
     */
    private String translateLanguage(TranslateParam translateParam) {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Greedily selects the next token from the scores at the last position.
     *
     * <p>When {@code pastOutputIds} has a last dimension of 1 (the first decoding step), every
     * token except {@code config.getForcedBosTokenId()} is masked to negative infinity, so that
     * token is always chosen. The mask is built with a single {@code full} call rather than one
     * native {@code set} per vocabulary entry.
     *
     * @param config supplies the forced BOS token id
     * @param pastOutputIds ids generated so far; only its last dimension is read
     * @param next_token_scores 3-D scores indexed as {@code [:, -1, :]}; presumably
     *     (batch, seq, vocab) — TODO confirm
     * @param manager manager that owns the intermediate and returned arrays
     * @return argmax over the last axis of the softmaxed scores, with a leading axis added
     */
    public NDArray greedyStepGen(NllbSearchConfig config, NDArray pastOutputIds, NDArray next_token_scores, NDManager manager) {
        NDArray lastScores = next_token_scores.get(":, -1, :");
        long curLen = pastOutputIds.getShape().getLastDimension();
        NDArray scores;
        if (curLen == 1L) {
            scores = manager.full(lastScores.getShape(), Float.NEGATIVE_INFINITY, lastScores.getDataType());
            scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
        } else {
            // Copy into `manager` so the result's lifetime is tied to the caller's manager.
            scores = manager.create(lastScores.getShape(), lastScores.getDataType());
            lastScores.copyTo(scores);
        }
        NDArray probs = scores.softmax(-1);
        NDArray nextTokens = probs.argMax(-1);
        return nextTokens.expandDims(0);
    }

    public GenericObjectPool<Predictor<?, ?>> getEncodePredictorPool() {
        return this.encodePredictorPool;
    }

    public GenericObjectPool<Predictor<?, ?>> getDecodePredictorPool() {
        return this.decodePredictorPool;
    }

    public GenericObjectPool<Predictor<?, ?>> getDecode2PredictorPool() {
        return this.decode2PredictorPool;
    }

    /**
     * Releases the predictor pools, the tokenizer and the model.
     *
     * <p>Pools are closed before the model because their factories were built from it. Each
     * failure is logged and does not prevent the remaining resources from being closed.
     */
    @Override
    public void close() throws Exception {
        closeQuietly("encodePredictorPool", closerOf(this.encodePredictorPool));
        closeQuietly("decodePredictorPool", closerOf(this.decodePredictorPool));
        closeQuietly("decode2PredictorPool", closerOf(this.decode2PredictorPool));
        closeQuietly("tokenizer", this.tokenizer);
        closeQuietly("model", this.nllbModel);
    }

    /** Adapts a nullable pool to an {@link AutoCloseable}, or returns null if the pool is null. */
    private static AutoCloseable closerOf(GenericObjectPool<?> pool) {
        return pool == null ? null : pool::close;
    }

    /** Closes {@code resource} if it is non-null and logs, rather than rethrows, any failure. */
    private static void closeQuietly(String name, AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception e) {
            log.warn("\u5173\u95ed {} \u5931\u8d25", name, e);
        }
    }
}

