/*
 * Decompiled with CFR 0.152.
 */
package org.apache.dolphinscheduler.plugin.task.mlflow;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.apache.dolphinscheduler.common.thread.ThreadUtils;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.common.utils.PropertyUtils;
import org.apache.dolphinscheduler.plugin.task.api.AbstractTask;
import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.Property;
import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse;
import org.apache.dolphinscheduler.plugin.task.api.parser.ParamUtils;
import org.apache.dolphinscheduler.plugin.task.api.parser.ParameterUtils;
import org.apache.dolphinscheduler.plugin.task.api.utils.OSUtils;
import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowParameters;

public class MlflowTask
extends AbstractTask {
    private static final Pattern GIT_CHECK_PATTERN = Pattern.compile("^(git@|https?://)");
    private final ShellCommandExecutor shellCommandExecutor;
    private final TaskExecutionContext taskExecutionContext;
    private MlflowParameters mlflowParameters;

    public MlflowTask(TaskExecutionContext taskExecutionContext) {
        super(taskExecutionContext);
        this.taskExecutionContext = taskExecutionContext;
        this.shellCommandExecutor = new ShellCommandExecutor(arg_0 -> ((MlflowTask)this).logHandle(arg_0), taskExecutionContext, this.logger);
    }

    public static String getPresetRepository() {
        String presetRepository = PropertyUtils.getString((String)"ml.mlflow.preset_repository");
        if (StringUtils.isEmpty((CharSequence)presetRepository)) {
            presetRepository = "https://github.com/apache/dolphinscheduler-mlflow";
        }
        return presetRepository;
    }

    public static String getPresetRepositoryVersion() {
        String version = PropertyUtils.getString((String)"ml.mlflow.preset_repository_version");
        if (StringUtils.isEmpty((CharSequence)version)) {
            version = "main";
        }
        return version;
    }

    public static String getVersionString(String version, String repository) {
        String versionString = StringUtils.isEmpty((CharSequence)version) ? "" : (GIT_CHECK_PATTERN.matcher(repository).find() ? String.format("--version=%s", version) : "");
        return versionString;
    }

    public void init() {
        this.logger.info("shell task params {}", (Object)this.taskExecutionContext.getTaskParams());
        this.mlflowParameters = (MlflowParameters)((Object)JSONUtils.parseObject((String)this.taskExecutionContext.getTaskParams(), MlflowParameters.class));
        if (!this.mlflowParameters.checkParameters()) {
            throw new RuntimeException("shell task params is not valid");
        }
    }

    public void handle(TaskCallBack taskCallBack) throws TaskException {
        try {
            String command = this.buildCommand();
            TaskResponse commandExecuteResult = this.shellCommandExecutor.run(command);
            int exitCode = this.mlflowParameters.getIsDeployDocker() ? this.checkDockerHealth() : commandExecuteResult.getExitStatusCode();
            this.setExitStatusCode(exitCode);
            this.setProcessId(commandExecuteResult.getProcessId());
            this.mlflowParameters.dealOutParam(this.shellCommandExecutor.getVarPool());
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            this.logger.error("The current Mlflow task has been interrupted", (Throwable)e);
            this.setExitStatusCode(-1);
            throw new TaskException("The current Mlflow task has been interrupted", (Throwable)e);
        }
        catch (Exception e) {
            this.logger.error("Mlflow task error", (Throwable)e);
            this.setExitStatusCode(-1);
            throw new TaskException("Execute Mlflow task failed", (Throwable)e);
        }
    }

    public void cancel() throws TaskException {
        try {
            this.shellCommandExecutor.cancelApplication();
        }
        catch (Exception e) {
            throw new TaskException("cancel application error", (Throwable)e);
        }
    }

    public String buildCommand() {
        String command = "";
        if (this.mlflowParameters.getMlflowTaskType().equals("MLflow Projects")) {
            command = this.buildCommandForMlflowProjects();
        } else if (this.mlflowParameters.getMlflowTaskType().equals("MLflow Models")) {
            command = this.buildCommandForMlflowModels();
        }
        this.logger.info("mlflow task command: \n{}", (Object)command);
        return command;
    }

    private String buildCommandForMlflowProjects() {
        String runCommand;
        Map<String, Property> paramsMap = this.getParamsMap();
        ArrayList<String> args = new ArrayList<String>();
        args.add(String.format("export MLFLOW_TRACKING_URI=%s", this.mlflowParameters.getMlflowTrackingUri()));
        String versionString = this.mlflowParameters.isCustomProject() != false ? MlflowTask.getVersionString(this.mlflowParameters.getMlflowProjectVersion(), this.mlflowParameters.getMlflowProjectRepository()) : MlflowTask.getVersionString(MlflowTask.getPresetRepositoryVersion(), MlflowTask.getPresetRepository());
        switch (this.mlflowParameters.getMlflowJobType()) {
            case "BasicAlgorithm": {
                args.add(String.format("data_path=%s", this.mlflowParameters.getDataPath()));
                String repoBasicAlgorithm = MlflowTask.getPresetRepository() + "#Project-BasicAlgorithm";
                args.add(String.format("repo=%s", repoBasicAlgorithm));
                runCommand = "mlflow run $repo -P algorithm=%s -P data_path=$data_path -P params=\"%s\" -P search_params=\"%s\" -P model_name=\"%s\" --experiment-name=\"%s\"";
                runCommand = String.format(runCommand, this.mlflowParameters.getAlgorithm(), this.mlflowParameters.getParams(), this.mlflowParameters.getSearchParams(), this.mlflowParameters.getModelName(), this.mlflowParameters.getExperimentName());
                break;
            }
            case "AutoML": {
                args.add(String.format("data_path=%s", this.mlflowParameters.getDataPath()));
                String repoAutoML = MlflowTask.getPresetRepository() + "#Project-AutoML";
                args.add(String.format("repo=%s", repoAutoML));
                runCommand = "mlflow run $repo -P tool=%s -P data_path=$data_path -P params=\"%s\" -P model_name=\"%s\" --experiment-name=\"%s\"";
                runCommand = String.format(runCommand, this.mlflowParameters.getAutomlTool(), this.mlflowParameters.getParams(), this.mlflowParameters.getModelName(), this.mlflowParameters.getExperimentName());
                break;
            }
            case "CustomProject": {
                args.add(String.format("repo=%s", this.mlflowParameters.getMlflowProjectRepository()));
                runCommand = "mlflow run $repo %s --experiment-name=\"%s\"";
                runCommand = String.format(runCommand, this.mlflowParameters.getParams(), this.mlflowParameters.getExperimentName());
                break;
            }
            default: {
                throw new TaskException("Unsupported mlflow job type: " + this.mlflowParameters.getMlflowJobType());
            }
        }
        if (StringUtils.isNotEmpty((CharSequence)versionString)) {
            runCommand = runCommand + " " + versionString;
        }
        args.add(runCommand);
        return ParameterUtils.convertParameterPlaceholders((String)String.join((CharSequence)"\n", args), (Map)ParamUtils.convert(paramsMap));
    }

    protected String buildCommandForMlflowModels() {
        Map<String, Property> paramsMap = this.getParamsMap();
        ArrayList<String> args = new ArrayList<String>();
        args.add(String.format("export MLFLOW_TRACKING_URI=%s", this.mlflowParameters.getMlflowTrackingUri()));
        String deployModelKey = this.mlflowParameters.getDeployModelKey();
        if (this.mlflowParameters.getDeployType().equals("MLFLOW")) {
            args.add(String.format("mlflow models serve -m %s --port %s -h 0.0.0.0", deployModelKey, this.mlflowParameters.getDeployPort()));
        } else if (this.mlflowParameters.getDeployType().equals("DOCKER")) {
            String imageName = "mlflow/" + this.mlflowParameters.getModelKeyName(":");
            String containerName = this.mlflowParameters.getContainerName();
            args.add(String.format("mlflow models build-docker -m %s -n %s --enable-mlserver", deployModelKey, imageName));
            args.add(String.format("docker rm -f %s", containerName));
            args.add(String.format("docker run -d --name=%s -p=%s:8080 --health-cmd \"curl --fail http://127.0.0.1:8080/ping || exit 1\" --health-interval 5s --health-retries 20 %s", containerName, this.mlflowParameters.getDeployPort(), imageName));
        }
        return ParameterUtils.convertParameterPlaceholders((String)String.join((CharSequence)"\n", args), (Map)ParamUtils.convert(paramsMap));
    }

    private Map<String, Property> getParamsMap() {
        return this.taskExecutionContext.getPrepareParamsMap();
    }

    public int checkDockerHealth() {
        this.logger.info("checking container healthy ... ");
        int exitCode = -1;
        String[] command = new String[]{"sh", "-c", String.format("docker inspect --format \"{{json .State.Health.Status }}\" %s", this.mlflowParameters.getContainerName())};
        for (int x = 0; x < 20; ++x) {
            String status;
            try {
                status = OSUtils.exeShell((String[])command).replace("\n", "").replace("\"", "");
            }
            catch (Exception e) {
                status = String.format("error --- %s", e.getMessage());
            }
            this.logger.info("container healthy status: {}", (Object)status);
            if (status.equals("healthy")) {
                exitCode = 0;
                this.logger.info("container is healthy");
                return exitCode;
            }
            this.logger.info("The health check has been running for {} seconds", (Object)(x * 5000 / 1000));
            ThreadUtils.sleep((long)5000L);
        }
        this.logger.info("health check fail");
        return exitCode;
    }

    public MlflowParameters getParameters() {
        return this.mlflowParameters;
    }
}

