/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BooleanSupplier;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Collects the state that is redistributed to the subtasks of a single {@link ExecutionJobVertex}
 * when restoring from a checkpoint, together with the subtask rescale mappings needed to build
 * {@link InflightDataRescalingDescriptor}s for in-flight channel state (input channel state and
 * result subpartition state).
 *
 * <p>The per-subtask state maps ({@link #subManagedOperatorState} etc.) are never written by this
 * class itself; they are filled by code outside this class and read back through {@link
 * #getSubtaskState(OperatorInstanceID)}.
 */
class TaskStateAssignment {
    private static final Logger LOG = LoggerFactory.getLogger(TaskStateAssignment.class);

    final ExecutionJobVertex executionJobVertex;
    /** State of each operator of this vertex as stored in the checkpoint being restored. */
    final Map<OperatorID, OperatorState> oldState;
    /** Whether any operator of this vertex has at least one collected subtask state. */
    final boolean hasState;
    /** Whether any old subtask of {@link #inputOperatorID} has non-empty input channel state. */
    final boolean hasInputState;
    /** Whether any old subtask of {@link #outputOperatorID} has non-empty result subpartition state. */
    final boolean hasOutputState;
    /** Parallelism of the vertex after restore ({@link ExecutionJobVertex#getParallelism()}). */
    final int newParallelism;
    /**
     * Generated ID of the last entry of {@link ExecutionJobVertex#getOperatorIDs()}; the only
     * operator that can receive a non-trivial input rescaling descriptor.
     */
    final OperatorID inputOperatorID;
    /**
     * Generated ID of the first entry of {@link ExecutionJobVertex#getOperatorIDs()}; the only
     * operator that can receive a non-trivial output rescaling descriptor.
     */
    final OperatorID outputOperatorID;

    final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState;
    final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
    final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;

    /** New-to-old subtask mapping of the output side; null until first computed. */
    @Nullable private RescaleMappings outputSubtaskMappings;
    /** New-to-old subtask mapping of the input side; null until first computed. */
    @Nullable private RescaleMappings inputSubtaskMappings;

    /** Set once any input gate's {@link SubtaskStateMapper} reported itself as ambiguous. */
    boolean mayHaveAmbiguousSubtasks;

    /** Lazily resolved assignments of the consumers of each produced data set. */
    @Nullable private TaskStateAssignment[] downstreamAssignments;
    /** Lazily resolved assignments of the producers of each input. */
    @Nullable private TaskStateAssignment[] upstreamAssignments;

    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /**
     * Creates an empty assignment for {@code executionJobVertex}.
     *
     * @param executionJobVertex the vertex whose subtasks receive the state
     * @param oldState checkpointed state per operator; must contain entries for both the first and
     *     the last operator ID of the vertex
     * @param consumerAssignment lookup from produced data set ID to the consuming vertex's
     *     assignment (used to resolve downstream assignments)
     * @param vertexAssignments lookup from vertex to its assignment (used to resolve upstream
     *     assignments)
     * @throws NullPointerException if one of the lookup maps is null
     */
    public TaskStateAssignment(
            ExecutionJobVertex executionJobVertex,
            Map<OperatorID, OperatorState> oldState,
            Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment,
            Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) {
        this.executionJobVertex = executionJobVertex;
        this.oldState = oldState;
        this.hasState =
                oldState.values().stream()
                        .anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);

        this.newParallelism = executionJobVertex.getParallelism();
        this.consumerAssignment = Preconditions.checkNotNull(consumerAssignment);
        this.vertexAssignments = Preconditions.checkNotNull(vertexAssignments);

        // One entry per (operator, new subtask) pair at most; used as initial map capacity.
        final int expectedNumberOfSubtasks = newParallelism * oldState.size();
        this.subManagedOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.inputChannelStates = new HashMap<>(expectedNumberOfSubtasks);
        this.resultSubpartitionStates = new HashMap<>(expectedNumberOfSubtasks);
        this.subManagedKeyedState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawKeyedState = new HashMap<>(expectedNumberOfSubtasks);

        final List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        this.outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
        this.inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();

        this.hasInputState =
                oldState.get(inputOperatorID).getStates().stream()
                        .anyMatch(subState -> !subState.getInputChannelState().isEmpty());
        this.hasOutputState =
                oldState.get(outputOperatorID).getStates().stream()
                        .anyMatch(subState -> !subState.getResultSubpartitionState().isEmpty());
    }

    /**
     * Returns, for each produced data set of the vertex (in {@link
     * ExecutionJobVertex#getProducedDataSets()} order), the assignment of its consumer. The array
     * is computed on first access and cached; an entry is null if {@code consumerAssignment} has
     * no mapping for that data set.
     */
    public TaskStateAssignment[] getDownstreamAssignments() {
        if (downstreamAssignments == null) {
            downstreamAssignments =
                    Arrays.stream(executionJobVertex.getProducedDataSets())
                            .map(result -> consumerAssignment.get(result.getId()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return downstreamAssignments;
    }

    /**
     * Returns, for each input of the vertex (in {@link ExecutionJobVertex#getInputs()} order), the
     * assignment of that input's producer. The array is computed on first access and cached; an
     * entry is null if {@code vertexAssignments} has no mapping for the producer.
     */
    public TaskStateAssignment[] getUpstreamAssignments() {
        if (upstreamAssignments == null) {
            upstreamAssignments =
                    executionJobVertex.getInputs().stream()
                            .map(result -> vertexAssignments.get(result.getProducer()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return upstreamAssignments;
    }

    /**
     * Assembles the {@link OperatorSubtaskState} for one operator subtask from the collected
     * state maps. Missing map entries yield empty state collections.
     *
     * @param instanceID operator and new subtask index to build the state for
     * @return the assembled subtask state including input/output rescaling descriptors
     * @throws IllegalStateException if raw keyed state is present without managed keyed state
     */
    public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
        Preconditions.checkState(
                subManagedKeyedState.containsKey(instanceID)
                        || !subRawKeyedState.containsKey(instanceID),
                "If an operator has no managed key state, it should also not have a raw keyed state.");

        final StateObjectCollection<InputChannelStateHandle> inputState =
                getState(instanceID, inputChannelStates);
        final StateObjectCollection<ResultSubpartitionStateHandle> outputState =
                getState(instanceID, resultSubpartitionStates);

        // The input descriptor is built before the output descriptor; building it may compute
        // and cache subtask mappings, so this order is kept deliberately.
        return OperatorSubtaskState.builder()
                .setManagedOperatorState(getState(instanceID, subManagedOperatorState))
                .setRawOperatorState(getState(instanceID, subRawOperatorState))
                .setManagedKeyedState(getState(instanceID, subManagedKeyedState))
                .setRawKeyedState(getState(instanceID, subRawKeyedState))
                .setInputChannelState(inputState)
                .setResultSubpartitionState(outputState)
                .setInputRescalingDescriptor(createInputRescalingDescriptor(instanceID, inputState))
                .setOutputRescalingDescriptor(
                        createOutputRescalingDescriptor(instanceID, outputState))
                .build();
    }

    /** Builds the rescaling descriptor for the input channel state of {@code instanceID}. */
    private InflightDataRescalingDescriptor createInputRescalingDescriptor(
            OperatorInstanceID instanceID,
            StateObjectCollection<InputChannelStateHandle> inputState) {
        return createRescalingDescriptor(
                instanceID,
                inputOperatorID,
                getUpstreamAssignments(),
                // each upstream producer's output mapping describes the channels we read from
                assignment -> assignment.outputSubtaskMappings,
                assignment ->
                        assignment.getOutputMapping(
                                Arrays.asList(assignment.getDownstreamAssignments()).indexOf(this)),
                inputSubtaskMappings,
                () -> getInputMapping(0),
                inputState,
                // read lazily: getInputMapping may update the flag while the descriptor is built
                () -> mayHaveAmbiguousSubtasks);
    }

    /** Builds the rescaling descriptor for the result subpartition state of {@code instanceID}. */
    private InflightDataRescalingDescriptor createOutputRescalingDescriptor(
            OperatorInstanceID instanceID,
            StateObjectCollection<ResultSubpartitionStateHandle> outputState) {
        return createRescalingDescriptor(
                instanceID,
                outputOperatorID,
                getDownstreamAssignments(),
                // each downstream consumer's input mapping describes the channels we write to
                assignment -> assignment.inputSubtaskMappings,
                assignment ->
                        assignment.getInputMapping(
                                Arrays.asList(assignment.getUpstreamAssignments()).indexOf(this)),
                outputSubtaskMappings,
                () -> getOutputMapping(0),
                outputState,
                // ambiguous subtasks are only tracked for the input side
                () -> false);
    }

    /** Logs the created descriptor at debug level and returns it unchanged. */
    private InflightDataRescalingDescriptor log(
            InflightDataRescalingDescriptor descriptor, int subtask) {
        LOG.debug(
                "created {} for task={} subtask={}",
                descriptor,
                executionJobVertex.getName(),
                subtask);
        return descriptor;
    }

    /**
     * Creates the in-flight data rescaling descriptor for one side (input or output) of a subtask.
     *
     * @param instanceID the operator subtask the descriptor is for
     * @param expectedOperatorID the only operator of the vertex that may get a non-trivial
     *     descriptor for this side; all other operators get {@code NO_RESCALE}
     * @param connectedAssignments assignments on the other side of each gate/partition
     * @param mappingRetriever reads an already computed channel mapping from a connected
     *     assignment (may return null)
     * @param mappingCalculator computes the channel mapping for a connected assignment when none
     *     was computed yet
     * @param subtaskMappings this side's subtask mapping, or null if not computed yet
     * @param subtaskMappingCalculator computes this side's subtask mapping on demand
     * @param state the channel state assigned to this subtask on this side
     * @param mayHaveAmbiguousSubtasks evaluated after the mappings are computed; if true the
     *     ambiguous targets of the subtask mapping are included in the descriptor
     * @return {@code NO_RESCALE} when no rescaling is needed, otherwise a new descriptor
     * @throws IllegalStateException if an unmapped new subtask has state assigned
     */
    private InflightDataRescalingDescriptor createRescalingDescriptor(
            OperatorInstanceID instanceID,
            OperatorID expectedOperatorID,
            TaskStateAssignment[] connectedAssignments,
            Function<TaskStateAssignment, RescaleMappings> mappingRetriever,
            Function<TaskStateAssignment, RescaleMappings> mappingCalculator,
            @Nullable RescaleMappings subtaskMappings,
            Supplier<RescaleMappings> subtaskMappingCalculator,
            StateObjectCollection<?> state,
            BooleanSupplier mayHaveAmbiguousSubtasks) {
        if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }

        final RescaleMappings[] rescaledChannelsMappings =
                Arrays.stream(connectedAssignments)
                        .map(mappingRetriever)
                        .toArray(RescaleMappings[]::new);

        // Nothing was computed on either side, i.e. no channel state required any mapping.
        if (subtaskMappings == null
                && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }

        // checkSubtaskMapping enforces a single mapping for all gates/partitions of this side,
        // so computing it for index 0 is sufficient.
        final RescaleMappings effectiveSubtaskMappings =
                subtaskMappings != null ? subtaskMappings : subtaskMappingCalculator.get();

        final int subtaskIndex = instanceID.getSubtaskId();
        final int[] oldSubtaskInstances = effectiveSubtaskMappings.getMappedIndexes(subtaskIndex);
        if (oldSubtaskInstances.length == 0) {
            Preconditions.checkState(
                    state.isEmpty(), "Unmapped new subtask should not have any state assigned");
            return log(InflightDataRescalingDescriptor.NO_RESCALE, subtaskIndex);
        }

        for (int partition = 0; partition < rescaledChannelsMappings.length; partition++) {
            if (rescaledChannelsMappings[partition] == null) {
                rescaledChannelsMappings[partition] =
                        mappingCalculator.apply(connectedAssignments[partition]);
            }
        }

        if (effectiveSubtaskMappings.isIdentity()
                && Arrays.stream(rescaledChannelsMappings).allMatch(RescaleMappings::isIdentity)) {
            return log(InflightDataRescalingDescriptor.NO_RESCALE, subtaskIndex);
        }

        final Set<Integer> ambiguousSubtasks =
                mayHaveAmbiguousSubtasks.getAsBoolean()
                        ? effectiveSubtaskMappings.getAmbiguousTargets()
                        : Collections.emptySet();
        return log(
                new InflightDataRescalingDescriptor(
                        oldSubtaskInstances, rescaledChannelsMappings, ambiguousSubtasks),
                subtaskIndex);
    }

    /**
     * Looks up the state handles of {@code instanceID} in {@code statesByInstance}.
     *
     * @return the handles wrapped in a {@link StateObjectCollection}, or an empty collection if
     *     the instance has no entry
     */
    private <T extends StateObject> StateObjectCollection<T> getState(
            OperatorInstanceID instanceID, Map<OperatorInstanceID, List<T>> statesByInstance) {
        final List<T> handles = statesByInstance.get(instanceID);
        if (handles == null) {
            return StateObjectCollection.empty();
        }
        return new StateObjectCollection<T>(handles);
    }

    /**
     * Computes the new-to-old subtask mapping for the produced data set at {@code partitionIndex},
     * using the upstream {@link SubtaskStateMapper} of the consuming job edge, and records it as
     * this vertex's output subtask mapping.
     *
     * @param partitionIndex index into {@link ExecutionJobVertex#getProducedDataSets()}
     * @return the recorded output subtask mapping
     * @throws NullPointerException if the consuming edge has no upstream subtask state mapper
     * @throws IllegalStateException if a different output mapping was recorded before
     */
    public RescaleMappings getOutputMapping(int partitionIndex) {
        final TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
        final IntermediateResult output = executionJobVertex.getProducedDataSets()[partitionIndex];
        final int gateIndex = downstreamAssignment.executionJobVertex.getInputs().indexOf(output);

        final SubtaskStateMapper mapper =
                Preconditions.checkNotNull(
                        downstreamAssignment
                                .executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(gateIndex)
                                .getUpstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        oldState.get(outputOperatorID).getParallelism(), newParallelism);

        outputSubtaskMappings = checkSubtaskMapping(outputSubtaskMappings, mapping);
        return outputSubtaskMappings;
    }

    /**
     * Computes the new-to-old subtask mapping for input gate {@code gateIndex}, using that gate's
     * downstream {@link SubtaskStateMapper}, records it as this vertex's input subtask mapping and
     * remembers whether the mapper is ambiguous.
     *
     * @param gateIndex index into the job vertex's inputs
     * @return the recorded input subtask mapping
     * @throws NullPointerException if the gate has no downstream subtask state mapper
     * @throws IllegalStateException if a different input mapping was recorded before
     */
    public RescaleMappings getInputMapping(int gateIndex) {
        final SubtaskStateMapper mapper =
                Preconditions.checkNotNull(
                        executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(gateIndex)
                                .getDownstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        oldState.get(inputOperatorID).getParallelism(), newParallelism);

        inputSubtaskMappings = checkSubtaskMapping(inputSubtaskMappings, mapping);
        mayHaveAmbiguousSubtasks |= mapper.isAmbiguous();
        return inputSubtaskMappings;
    }

    @Override
    public String toString() {
        return "TaskStateAssignment for " + executionJobVertex.getName();
    }

    /**
     * Ensures that all gates/partitions of one side agree on the subtask mapping.
     *
     * @param oldMapping the previously recorded mapping, or null if none yet
     * @param mapping the newly computed mapping
     * @return {@code mapping} if nothing was recorded yet, otherwise {@code oldMapping}
     * @throws IllegalStateException if both mappings are present and differ
     */
    private static RescaleMappings checkSubtaskMapping(
            @Nullable RescaleMappings oldMapping, RescaleMappings mapping) {
        if (oldMapping == null) {
            return mapping;
        }
        if (!oldMapping.equals(mapping)) {
            throw new IllegalStateException(
                    "Incompatible subtask mappings: are multiple operators ingesting/producing"
                            + " intermediate results with varying degrees of parallelism? Found "
                            + oldMapping
                            + " and "
                            + mapping
                            + ".");
        }
        return oldMapping;
    }
}

