/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxSymbolBlock;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MxModel
extends BaseModel {
    private static final Logger logger = LoggerFactory.getLogger(MxModel.class);

    MxModel(String name, Device device) {
        super(name);
        this.dataType = DataType.FLOAT32;
        this.properties = new ConcurrentHashMap();
        this.manager = MxNDManager.getSystemManager().newSubManager(device);
        this.manager.setName("mxModel");
    }

    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        boolean trainParam;
        this.setModelDir(modelPath);
        this.wasLoaded = true;
        if (prefix == null) {
            prefix = this.modelName;
        }
        boolean hasParameter = true;
        String optimization = null;
        if (options != null) {
            String paramOption = (String)options.get("hasParameter");
            if (paramOption != null) {
                hasParameter = Boolean.parseBoolean(paramOption);
            }
            optimization = (String)options.get("MxOptimizeFor");
        }
        Path paramFile = this.paramPathResolver(prefix, options);
        if (hasParameter && paramFile == null && (paramFile = this.paramPathResolver(prefix = this.modelDir.toFile().getName(), options)) == null && this.block == null) {
            throw new FileNotFoundException("Parameter file with prefix: " + prefix + " not found in: " + this.modelDir + " or not readable by the engine.");
        }
        if (this.block == null) {
            Path symbolFile = this.modelDir.resolve(prefix + "-symbol.json");
            if (Files.notExists(symbolFile, new LinkOption[0])) {
                throw new FileNotFoundException("Symbol file not found: " + symbolFile + ", please set block manually for imperative model.");
            }
            Symbol symbol = Symbol.load((MxNDManager)this.manager, symbolFile.toAbsolutePath().toString());
            this.block = new MxSymbolBlock(this.manager, symbol);
        }
        if (hasParameter) {
            this.loadParameters(paramFile, options);
        }
        if (optimization != null) {
            ((MxSymbolBlock)this.block).optimizeFor(optimization);
        }
        boolean bl = trainParam = options != null && Boolean.parseBoolean((String)options.get("trainParam"));
        if (!trainParam) {
            // empty if block
        }
    }

    public Trainer newTrainer(TrainingConfig trainingConfig) {
        PairList initializer = trainingConfig.getInitializers();
        if (this.block == null) {
            throw new IllegalStateException("You must set a block for the model before creating a new trainer");
        }
        if (this.wasLoaded) {
            // empty if block
        }
        for (Pair pair : initializer) {
            if (pair.getKey() == null || pair.getValue() == null) continue;
            this.block.setInitializer((Initializer)pair.getKey(), (Predicate)pair.getValue());
        }
        return new Trainer((Model)this, trainingConfig);
    }

    public String[] getArtifactNames() {
        String[] stringArray;
        block9: {
            Stream<Path> stream = Files.walk(this.modelDir, new FileVisitOption[0]);
            try {
                List files = stream.filter(x$0 -> Files.isRegularFile(x$0, new LinkOption[0])).collect(Collectors.toList());
                ArrayList<String> ret = new ArrayList<String>(files.size());
                for (Path path : files) {
                    String fileName = path.toFile().getName();
                    if (fileName.endsWith(".params") || fileName.endsWith("-symbol.json")) continue;
                    Path relative = this.modelDir.relativize(path);
                    ret.add(relative.toString());
                }
                stringArray = ret.toArray(Utils.EMPTY_ARRAY);
                if (stream == null) break block9;
            }
            catch (Throwable throwable) {
                try {
                    if (stream != null) {
                        try {
                            stream.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (IOException e) {
                    throw new AssertionError("Failed list files", e);
                }
            }
            stream.close();
        }
        return stringArray;
    }

    public void close() {
        JnaUtils.waitAll();
        super.close();
    }

    private void loadParameters(Path paramFile, Map<String, ?> options) throws IOException, MalformedModelException {
        if (this.readParameters(paramFile, options)) {
            return;
        }
        logger.debug("DJL formatted model not found, try to find MXNet model");
        NDList paramNDlist = this.manager.load(paramFile);
        MxSymbolBlock symbolBlock = (MxSymbolBlock)this.block;
        List<Parameter> parameters = symbolBlock.getAllParameters();
        LinkedHashMap map = new LinkedHashMap();
        parameters.forEach(p -> map.put(p.getName(), p));
        for (NDArray nd : paramNDlist) {
            String key = nd.getName();
            if (key == null) {
                throw new IllegalArgumentException("Array names must be present in parameter file");
            }
            String paramName = key.split(":", 2)[1];
            Parameter parameter = (Parameter)map.remove(paramName);
            parameter.setArray(nd);
        }
        symbolBlock.setInputNames(new ArrayList<String>(map.keySet()));
        this.dataType = paramNDlist.head().getDataType();
        logger.debug("MXNet Model {} ({}) loaded successfully.", (Object)paramFile, (Object)this.dataType);
    }
}

