/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.ocr.model.table.criteria;

import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.ocr.config.TableStructureConfig;
import cn.smartjavaai.ocr.entity.TableStructureResult;
import cn.smartjavaai.ocr.enums.TableStructureModelEnum;
import cn.smartjavaai.ocr.model.table.translator.TableStructTranslator;
import java.nio.file.Paths;
import java.util.Objects;

public class StructureCriteriaFactory {

    /**
     * Builds the DJL {@link Criteria} used to load a table-structure recognition model.
     *
     * <p>{@link TableStructureModelEnum#SLANET} and {@link TableStructureModelEnum#SLANET_PLUS}
     * share identical loading settings. Both use:
     * <ul>
     *   <li>the {@code "OnnxRuntime"} engine;</li>
     *   <li>the model file at {@code config.getModelPath()};</li>
     *   <li>a {@link TableStructTranslator};</li>
     *   <li>a {@link ProgressBar}.</li>
     * </ul>
     *
     * @param config table-structure configuration. Its model type, model path, device and GPU id
     *     are read.
     * @return the loading criteria. Returns {@code null} when the configured model type is neither
     *     {@code SLANET} nor {@code SLANET_PLUS}, which includes a {@code null} model type.
     */
    public static Criteria<Image, TableStructureResult> createCriteria(TableStructureConfig config) {
        // Resolved up front, matching the original evaluation order.
        Device device = resolveDevice(config);
        TableStructureModelEnum modelEnum = config.getModelEnum();
        if (modelEnum != TableStructureModelEnum.SLANET
                && modelEnum != TableStructureModelEnum.SLANET_PLUS) {
            // Kept for backward compatibility: unsupported model types yield null, not an exception.
            return null;
        }
        return buildSlanetCriteria(config, device);
    }

    /**
     * Maps the configured device to a DJL {@link Device}.
     *
     * @param config configuration whose device and GPU id are read
     * @return {@link Device#cpu()} for {@link DeviceEnum#CPU}. For any other non-null device
     *     value, returns {@link Device#gpu(int)} with the configured GPU id. Returns {@code null}
     *     when no device is configured; that null is passed to {@code optDevice} unchanged.
     */
    private static Device resolveDevice(TableStructureConfig config) {
        if (config.getDevice() == null) {
            return null;
        }
        return config.getDevice() == DeviceEnum.CPU
                ? Device.cpu()
                : Device.gpu((int) config.getGpuId());
    }

    /**
     * Builds the criteria shared by the SLANet model variants.
     *
     * @param config configuration supplying the model path
     * @param device target device, possibly {@code null}
     * @return the built criteria
     */
    private static Criteria<Image, TableStructureResult> buildSlanetCriteria(
            TableStructureConfig config, Device device) {
        return Criteria.builder()
                .optEngine("OnnxRuntime")
                .setTypes(Image.class, TableStructureResult.class)
                .optModelPath(Paths.get(config.getModelPath()))
                // NOTE(review): "removePass" / "repeated_fc_relu_fuse_pass" looks like a Paddle
                // inference option; confirm the OnnxRuntime engine actually consumes it.
                .optOption("removePass", "repeated_fc_relu_fuse_pass")
                .optDevice(device)
                // NOTE(review): raw cast inherited from the decompiled source. If
                // TableStructTranslator implements Translator<Image, TableStructureResult>,
                // this cast can be dropped to remove the unchecked conversion.
                .optTranslator((Translator) new TableStructTranslator())
                .optProgress(new ProgressBar())
                .build();
    }
}

