/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph.failover;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.concurrent.FutureUtils;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.GlobalModVersionMismatch;
import org.apache.flink.runtime.executiongraph.SchedulingUtils;
import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
import org.apache.flink.runtime.executiongraph.failover.adapter.DefaultFailoverTopology;
import org.apache.flink.runtime.executiongraph.failover.flip1.RestartPipelinedRegionStrategy;
import org.apache.flink.runtime.executiongraph.restart.RestartCallback;
import org.apache.flink.runtime.executiongraph.restart.RestartStrategy;
import org.apache.flink.runtime.jobgraph.JobStatus;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersion;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersioner;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link FailoverStrategy} that restarts only the tasks selected by a
 * {@link RestartPipelinedRegionStrategy} when a task fails.
 *
 * <p>On a task failure the strategy asks the wrapped {@link RestartPipelinedRegionStrategy}
 * which vertices need a restart, cancels them, and then - via the execution graph's
 * {@link RestartStrategy} - resets and reschedules the vertices that are still unmodified.
 * Whenever the local failover cannot proceed because of an error, it falls back to
 * {@link ExecutionGraph#failGlobal(Throwable)}.
 *
 * <p>{@link #notifyNewVertices(List)} must be called exactly once before
 * {@link #onTaskFailure(Execution, Throwable)} is used, because it creates the wrapped
 * region strategy.
 */
public class AdaptedRestartPipelinedRegionStrategyNG extends FailoverStrategy {
    private static final Logger LOG = LoggerFactory.getLogger(AdaptedRestartPipelinedRegionStrategyNG.class);

    /** The execution graph on which this strategy performs failover. */
    private final ExecutionGraph executionGraph;

    /** Records vertex versions so that vertices modified after a failover started can be detected. */
    private final ExecutionVertexVersioner executionVertexVersioner;

    /** Computes the set of tasks to restart; created lazily in {@link #notifyNewVertices(List)}. */
    private RestartPipelinedRegionStrategy restartPipelinedRegionStrategy;

    /**
     * Creates the strategy for the given execution graph.
     *
     * @param executionGraph the graph to operate on; must not be {@code null}
     */
    public AdaptedRestartPipelinedRegionStrategyNG(ExecutionGraph executionGraph) {
        this.executionGraph = Preconditions.checkNotNull(executionGraph);
        this.executionVertexVersioner = new ExecutionVertexVersioner();
    }

    /**
     * Handles the failure of a single task execution.
     *
     * <p>Falls back to a global failure if the restart strategy does not allow a restart,
     * does nothing if the local failover is no longer valid (see
     * {@link #isLocalFailoverValid(long)}), and otherwise restarts the tasks reported by the
     * wrapped region strategy.
     *
     * @param taskExecution the failed execution
     * @param cause the failure cause
     */
    @Override
    public void onTaskFailure(Execution taskExecution, Throwable cause) {
        if (!executionGraph.getRestartStrategy().canRestart()) {
            LOG.info("Fail to pass the restart strategy validation in region failover. Fallback to fail global.");
            failGlobal(cause);
            return;
        }
        if (!isLocalFailoverValid(executionGraph.getGlobalModVersion())) {
            LOG.info("Skip current region failover as a global failover is ongoing.");
            return;
        }
        ExecutionVertexID vertexID = getExecutionVertexID(taskExecution.getVertex());
        Set<ExecutionVertexID> tasksToRestart = restartPipelinedRegionStrategy.getTasksNeedingRestart(vertexID, cause);
        restartTasks(tasksToRestart);
    }

    /**
     * Cancels the given vertices and, once cancellation has completed, resets and reschedules
     * them on the job master main thread executor. Any error in that chain triggers a global
     * failure.
     *
     * @param verticesToRestart ids of the vertices to restart
     */
    @VisibleForTesting
    protected void restartTasks(Set<ExecutionVertexID> verticesToRestart) {
        // Snapshot the global mod version and the vertex versions up front so that the
        // asynchronous continuation can tell whether anything changed in the meantime.
        long globalModVersion = executionGraph.getGlobalModVersion();
        Set<ExecutionVertexVersion> vertexVersions =
            new HashSet<>(executionVertexVersioner.recordVertexModifications(verticesToRestart).values());

        FutureUtils.assertNoException(
            cancelTasks(verticesToRestart)
                .thenComposeAsync(
                    resetAndRescheduleTasks(globalModVersion, vertexVersions),
                    executionGraph.getJobMasterMainThreadExecutor())
                .handle(failGlobalOnError()));
    }

    /**
     * Builds the continuation that hands the reset-and-reschedule callback to the execution
     * graph's {@link RestartStrategy}.
     *
     * @param globalModVersion global mod version captured when the failover started
     * @param vertexVersions vertex versions captured when the failover started
     * @return a function ignoring its input and returning the restart strategy's future
     */
    private Function<Object, CompletableFuture<Void>> resetAndRescheduleTasks(
            long globalModVersion, Set<ExecutionVertexVersion> vertexVersions) {
        return ignored -> {
            RestartStrategy restartStrategy = executionGraph.getRestartStrategy();
            return restartStrategy.restart(
                createResetAndRescheduleTasksCallback(globalModVersion, vertexVersions),
                executionGraph.getJobMasterMainThreadExecutor());
        };
    }

    /**
     * Creates the callback that performs the actual reset and rescheduling.
     *
     * <p>The callback is a no-op if the local failover is no longer valid. A
     * {@link GlobalModVersionMismatch} is rethrown as {@link IllegalStateException}; any other
     * exception is wrapped in a {@link CompletionException}.
     *
     * @param globalModVersion global mod version captured when the failover started
     * @param vertexVersions vertex versions captured when the failover started
     * @return the restart callback
     */
    private RestartCallback createResetAndRescheduleTasksCallback(
            long globalModVersion, Set<ExecutionVertexVersion> vertexVersions) {
        return () -> {
            if (!isLocalFailoverValid(globalModVersion)) {
                LOG.info("Skip current region failover as a global failover is ongoing.");
                return;
            }

            // Only restart the vertices the versioner still reports as unmodified relative to
            // the versions recorded in restartTasks().
            Set<ExecutionVertex> unmodifiedVertices = executionVertexVersioner
                .getUnmodifiedExecutionVertices(vertexVersions)
                .stream()
                .map(this::getExecutionVertex)
                .collect(Collectors.toSet());

            try {
                LOG.info("Finally restart {} tasks to recover from task failure.", unmodifiedVertices.size());
                resetTasks(unmodifiedVertices, globalModVersion);
                rescheduleTasks(unmodifiedVertices, globalModVersion);
            } catch (GlobalModVersionMismatch e) {
                throw new IllegalStateException(
                    "Bug: ExecutionGraph was concurrently modified outside of main thread", e);
            } catch (Exception e) {
                throw new CompletionException(e);
            }
        };
    }

    /**
     * Returns a handler that triggers a global failure if the preceding stage completed
     * exceptionally. The handler always yields {@code null}.
     */
    private BiFunction<Object, Throwable, Object> failGlobalOnError() {
        return (ignored, t) -> {
            if (t != null) {
                LOG.info("Unexpected error happens in region failover. Fail globally.", t);
                failGlobal(t);
            }
            return null;
        };
    }

    /**
     * Cancels all given vertices.
     *
     * @param vertices ids of the vertices to cancel
     * @return a future combining the cancellation futures of all vertices
     */
    @VisibleForTesting
    protected CompletableFuture<?> cancelTasks(Set<ExecutionVertexID> vertices) {
        List<CompletableFuture<?>> cancelFutures = vertices.stream()
            .map(this::cancelExecutionVertex)
            .collect(Collectors.toList());
        return FutureUtils.combineAll(cancelFutures);
    }

    /**
     * Resets the given vertices for a new execution and, if a checkpoint coordinator exists,
     * aborts pending checkpoints and restores the latest checkpointed state for the job
     * vertices involved.
     *
     * @param vertices the vertices to reset
     * @param globalModVersion global mod version passed on to each vertex reset
     * @throws Exception if resetting a vertex or restoring state fails
     */
    private void resetTasks(Set<ExecutionVertex> vertices, long globalModVersion) throws Exception {
        Set<CoLocationGroup> colGroups = new HashSet<>();
        // All vertices of one failover share the same restart timestamp.
        long restartTimestamp = System.currentTimeMillis();

        for (ExecutionVertex ev : vertices) {
            CoLocationGroup cgroup = ev.getJobVertex().getCoLocationGroup();
            // Set.add returns true only the first time, so each group is reset exactly once.
            if (cgroup != null && colGroups.add(cgroup)) {
                cgroup.resetConstraints();
            }
            ev.resetForNewExecution(restartTimestamp, globalModVersion);
        }

        if (executionGraph.getCheckpointCoordinator() != null) {
            executionGraph.getCheckpointCoordinator().abortPendingCheckpoints(
                new CheckpointException(CheckpointFailureReason.JOB_FAILOVER_REGION));
            Map<JobVertexID, ExecutionJobVertex> involvedExecutionJobVertices =
                getInvolvedExecutionJobVertices(vertices);
            // NOTE(review): the two boolean flags are passed through unchanged; consult
            // CheckpointCoordinator#restoreLatestCheckpointedState for their meaning.
            executionGraph.getCheckpointCoordinator().restoreLatestCheckpointedState(
                involvedExecutionJobVertices, false, true);
        }
    }

    /**
     * Schedules the given vertices in the topological order of their job vertices. If the local
     * failover is still valid after scheduling was triggered, a scheduling failure that is not a
     * {@link CancellationException} leads to a global failure.
     *
     * @param vertices the vertices to schedule
     * @param globalModVersion global mod version captured when the failover started
     * @throws Exception declared for callers; propagates whatever the scheduling call throws
     */
    private void rescheduleTasks(Set<ExecutionVertex> vertices, long globalModVersion) throws Exception {
        List<ExecutionVertex> sortedVertices = sortVerticesTopologically(vertices);
        CompletableFuture<Void> newSchedulingFuture =
            SchedulingUtils.schedule(executionGraph.getScheduleMode(), sortedVertices, executionGraph);

        if (isLocalFailoverValid(globalModVersion)) {
            newSchedulingFuture.whenComplete((ignored, throwable) -> {
                if (throwable != null) {
                    Throwable strippedThrowable = ExceptionUtils.stripCompletionException(throwable);
                    // A cancelled scheduling future is not treated as a failure.
                    if (!(strippedThrowable instanceof CancellationException)) {
                        failGlobal(strippedThrowable);
                    }
                }
            });
        }
    }

    /**
     * Checks whether a local failover started at the given global mod version may still proceed:
     * the job must be {@link JobStatus#RUNNING} and the global mod version must be unchanged.
     *
     * @param globalModVersion the global mod version captured earlier
     * @return {@code true} if both conditions hold
     */
    private boolean isLocalFailoverValid(long globalModVersion) {
        return executionGraph.getState() == JobStatus.RUNNING
            && executionGraph.getGlobalModVersion() == globalModVersion;
    }

    /** Cancels the vertex with the given id and returns its cancellation future. */
    private CompletableFuture<?> cancelExecutionVertex(ExecutionVertexID executionVertexId) {
        return getExecutionVertex(executionVertexId).cancel();
    }

    /**
     * Collects the job vertices that the given execution vertices belong to, keyed by job
     * vertex id.
     */
    private Map<JobVertexID, ExecutionJobVertex> getInvolvedExecutionJobVertices(Set<ExecutionVertex> executionVertices) {
        Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
        for (ExecutionVertex executionVertex : executionVertices) {
            tasks.putIfAbsent(executionVertex.getJobvertexId(), executionVertex.getJobVertex());
        }
        return tasks;
    }

    /** Delegates to {@link ExecutionGraph#failGlobal(Throwable)}. */
    private void failGlobal(Throwable cause) {
        executionGraph.failGlobal(cause);
    }

    /** Looks up the execution vertex by job vertex id and subtask index. */
    private ExecutionVertex getExecutionVertex(ExecutionVertexID vertexID) {
        return executionGraph.getAllVertices()
            .get(vertexID.getJobVertexId())
            .getTaskVertices()[vertexID.getSubtaskIndex()];
    }

    /** Builds the id of the given execution vertex from its job vertex id and subtask index. */
    private ExecutionVertexID getExecutionVertexID(ExecutionVertex vertex) {
        return new ExecutionVertexID(vertex.getJobvertexId(), vertex.getParallelSubtaskIndex());
    }

    /**
     * Orders the given vertices by the topological order of their job vertices, as reported by
     * {@link ExecutionGraph#getVerticesTopologically()}.
     *
     * @param vertices the vertices to sort
     * @return a new list containing the vertices grouped by job vertex in topological order
     */
    private List<ExecutionVertex> sortVerticesTopologically(Set<ExecutionVertex> vertices) {
        // Group the execution vertices by the job vertex they belong to.
        Map<JobVertexID, List<ExecutionVertex>> verticesMap = new HashMap<>();
        for (ExecutionVertex vertex : vertices) {
            verticesMap.computeIfAbsent(vertex.getJobvertexId(), id -> new ArrayList<>()).add(vertex);
        }

        // Emit the groups in the job vertices' topological order.
        List<ExecutionVertex> sortedVertices = new ArrayList<>(vertices.size());
        for (ExecutionJobVertex jobVertex : executionGraph.getVerticesTopologically()) {
            sortedVertices.addAll(verticesMap.getOrDefault(jobVertex.getJobVertexId(), Collections.emptyList()));
        }
        return sortedVertices;
    }

    /**
     * Creates the wrapped region strategy from the execution graph's topology and result
     * partition availability checker.
     *
     * @param newJobVerticesTopological the newly added job vertices (not read by this method)
     * @throws IllegalStateException if called more than once
     */
    @Override
    public void notifyNewVertices(List<ExecutionJobVertex> newJobVerticesTopological) {
        Preconditions.checkState(
            restartPipelinedRegionStrategy == null,
            "notifyNewVertices() must be called only once");
        restartPipelinedRegionStrategy = new RestartPipelinedRegionStrategy(
            new DefaultFailoverTopology(executionGraph),
            executionGraph.getResultPartitionAvailabilityChecker());
    }

    @Override
    public String getStrategyName() {
        return "New Pipelined Region Failover";
    }

    /** Factory that creates {@link AdaptedRestartPipelinedRegionStrategyNG} instances. */
    public static class Factory implements FailoverStrategy.Factory {
        @Override
        public FailoverStrategy create(ExecutionGraph executionGraph) {
            return new AdaptedRestartPipelinedRegionStrategyNG(executionGraph);
        }
    }
}

