/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.io;

import java.io.Closeable;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.concurrent.FutureUtils;
import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.streaming.runtime.io.CheckpointBarrierHandler;
import org.apache.flink.streaming.runtime.tasks.SubtaskCheckpointCoordinator;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link CheckpointBarrierHandler} that starts a checkpoint as soon as the first barrier with a
 * new id is seen and, until barriers of that checkpoint have arrived on all open channels, hands
 * the buffers received on the remaining channels to the channel state writer.
 *
 * <p>State is kept in two places:
 *
 * <ul>
 *   <li>This class tracks barriers as they are <em>consumed</em> through {@link #processBarrier}
 *       and remembers, per channel, whether in-flight data may still precede that channel's
 *       barrier (see {@link #hasInflightData}). It is not thread-safe.
 *   <li>{@link ThreadSafeUnaligner} tracks barriers and buffers as they are <em>received</em>
 *       through the {@link BufferReceivedListener} callbacks. Its mutable state is guarded by its
 *       own monitor.
 * </ul>
 */
@Internal
@NotThreadSafe
public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {

    private static final Logger LOG = LoggerFactory.getLogger(CheckpointBarrierUnaligner.class);

    /** Task name; used as a prefix in log messages and in {@link #toString()}. */
    private final String taskName;

    /**
     * Per input channel: {@code true} while the currently consumed checkpoint is pending and the
     * barrier of that channel has not been consumed yet, i.e. while in-flight data may still
     * precede it.
     */
    private final Map<InputChannelInfo, Boolean> hasInflightBuffers;

    /** Barriers of {@link #currentConsumedCheckpointId} consumed so far; 0 means "not pending". */
    private int numBarrierConsumed;

    /** Highest checkpoint id seen on the consuming side; -1 before the first one. */
    private long currentConsumedCheckpointId = -1L;

    /** Receiving-side state, also exposed as the {@link BufferReceivedListener}. */
    private final ThreadSafeUnaligner threadSafeUnaligner;

    CheckpointBarrierUnaligner(
            SubtaskCheckpointCoordinator checkpointCoordinator,
            String taskName,
            AbstractInvokable toNotifyOnCheckpoint,
            InputGate... inputGates) {
        super(toNotifyOnCheckpoint);
        this.taskName = taskName;
        this.hasInflightBuffers = Arrays.stream(inputGates)
                .flatMap(gate -> gate.getChannelInfos().stream())
                .collect(Collectors.toMap(Function.identity(), info -> false));
        this.threadSafeUnaligner = new ThreadSafeUnaligner(
                Preconditions.checkNotNull(checkpointCoordinator), this, inputGates);
    }

    /**
     * Handles a checkpoint barrier consumed from {@code channelInfo}.
     *
     * <p>Barriers older than the current checkpoint are ignored. Barriers of the current
     * checkpoint are also ignored once it is no longer pending (e.g. after a cancellation or a
     * reset). A barrier of a newer checkpoint starts tracking that checkpoint and marks every
     * channel as possibly having in-flight data. The barrier's own channel is then marked as
     * drained, and the barrier is forwarded to the receiving side. The receiving side starts a new
     * checkpoint if its own current id is lower.
     */
    @Override
    public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception {
        final long barrierId = receivedBarrier.getId();
        final boolean isStale = currentConsumedCheckpointId > barrierId
                || (currentConsumedCheckpointId == barrierId && !isCheckpointPending());
        if (isStale) {
            return;
        }
        if (currentConsumedCheckpointId < barrierId) {
            currentConsumedCheckpointId = barrierId;
            numBarrierConsumed = 0;
            hasInflightBuffers.replaceAll((info, ignored) -> true);
        }
        // Past the guards above, currentConsumedCheckpointId == barrierId always holds.
        hasInflightBuffers.put(channelInfo, false);
        ++numBarrierConsumed;
        threadSafeUnaligner.notifyBarrierReceived(receivedBarrier, channelInfo);
    }

    /**
     * Reacts to an abort request for {@code checkpointId}. If {@code checkpointId} is newer than
     * the id known to the receiving side, the receiving side resets its pending checkpoint (if
     * any) and notifies the abort of that checkpoint with {@code exception}. If
     * {@code checkpointId} is newer than {@link #currentConsumedCheckpointId}, the consuming-side
     * state is reset as well.
     */
    @Override
    public void abortPendingCheckpoint(long checkpointId, CheckpointException exception) throws IOException {
        threadSafeUnaligner.tryAbortPendingCheckpoint(checkpointId, exception);
        if (checkpointId > currentConsumedCheckpointId) {
            resetPendingCheckpoint(checkpointId);
        }
    }

    /**
     * Handles a cancellation marker. If the receiving side accepts the cancellation, the abort of
     * the cancelled checkpoint is notified with {@code CHECKPOINT_DECLINED_ON_CANCELLATION_BARRIER}.
     * If the cancelled id is not older than the current consumed checkpoint, the consuming state is
     * reset and the cancelled id becomes current. As a result, later barriers of that id are ignored
     * by {@link #processBarrier}.
     */
    @Override
    public void processCancellationBarrier(CancelCheckpointMarker cancelBarrier) throws Exception {
        final long cancelledId = cancelBarrier.getCheckpointId();
        final boolean shouldAbort = threadSafeUnaligner.setCancelledCheckpointId(cancelledId);
        if (shouldAbort) {
            notifyAbort(
                    cancelledId,
                    new CheckpointException(CheckpointFailureReason.CHECKPOINT_DECLINED_ON_CANCELLATION_BARRIER));
        }
        if (cancelledId >= currentConsumedCheckpointId) {
            resetPendingCheckpoint(cancelledId);
            currentConsumedCheckpointId = cancelledId;
        }
    }

    /**
     * Handles the end of a partition. The receiving side decrements its open-channel count and
     * aborts its pending checkpoint, if any. The consuming side is reset; the id -1 is passed for
     * logging.
     */
    @Override
    public void processEndOfPartition() throws Exception {
        threadSafeUnaligner.onChannelClosed();
        resetPendingCheckpoint(-1L);
    }

    /**
     * If a checkpoint is pending on the consuming side, logs that it is skipped because of
     * {@code checkpointId} (-1 for end of partition). It then clears all per-channel in-flight
     * flags and the consumed-barrier count.
     */
    private void resetPendingCheckpoint(long checkpointId) {
        if (!isCheckpointPending()) {
            return;
        }
        LOG.warn(
                "{}: Received barrier or EndOfPartition(-1) {} before completing current checkpoint {}. Skipping current checkpoint.",
                taskName,
                checkpointId,
                currentConsumedCheckpointId);
        hasInflightBuffers.replaceAll((info, ignored) -> false);
        numBarrierConsumed = 0;
    }

    @Override
    public long getLatestCheckpointId() {
        return currentConsumedCheckpointId;
    }

    @Override
    public String toString() {
        return String.format("%s: last checkpoint: %d", taskName, currentConsumedCheckpointId);
    }

    /** Closes the base handler, then cancels the receiving side's pending all-barriers future. */
    @Override
    public void close() throws IOException {
        super.close();
        threadSafeUnaligner.close();
    }

    /**
     * Returns whether in-flight data may precede the barrier of {@code checkpointId} on
     * {@code channelInfo}. The result is {@code false} for checkpoints older than the current
     * consumed one and {@code true} for newer ones, since none of their barriers were consumed yet.
     * For the current checkpoint, the channel's tracked flag is returned.
     */
    @Override
    public boolean hasInflightData(long checkpointId, InputChannelInfo channelInfo) {
        if (checkpointId < currentConsumedCheckpointId) {
            return false;
        }
        if (checkpointId > currentConsumedCheckpointId) {
            return true;
        }
        return hasInflightBuffers.get(channelInfo);
    }

    /** Delegates to {@link ThreadSafeUnaligner#getAllBarriersReceivedFuture(long)}. */
    @Override
    public CompletableFuture<Void> getAllBarriersReceivedFuture(long checkpointId) {
        return threadSafeUnaligner.getAllBarriersReceivedFuture(checkpointId);
    }

    /** Exposes the receiving-side state as the buffer/barrier listener. */
    @Override
    public Optional<BufferReceivedListener> getBufferReceivedListener() {
        return Optional.of(threadSafeUnaligner);
    }

    /** A checkpoint is pending on the consuming side once at least one of its barriers was consumed. */
    @Override
    protected boolean isCheckpointPending() {
        return numBarrierConsumed > 0;
    }

    @VisibleForTesting
    int getNumOpenChannels() {
        return threadSafeUnaligner.getNumOpenChannels();
    }

    @VisibleForTesting
    ThreadSafeUnaligner getThreadSafeUnaligner() {
        return threadSafeUnaligner;
    }

    /**
     * Notifies the checkpoint started by {@code barrier}, unless the receiving side has already
     * moved on to a newer checkpoint. Scheduled through {@code executeInTaskThread} by
     * {@link ThreadSafeUnaligner#notifyBarrierReceived}.
     */
    private void notifyCheckpoint(CheckpointBarrier barrier) throws IOException {
        if (barrier.getId() >= threadSafeUnaligner.getCurrentCheckpointId()) {
            // NOTE(review): the second argument is presumably the alignment duration; confirm
            // against CheckpointBarrierHandler#notifyCheckpoint.
            super.notifyCheckpoint(barrier, 0L);
        }
    }

    /**
     * Receiving-side state of {@link CheckpointBarrierUnaligner}. It is registered as the
     * {@link BufferReceivedListener} and is also called by the enclosing handler. Every method
     * that touches mutable state is {@code synchronized} on this instance.
     */
    @ThreadSafe
    static class ThreadSafeUnaligner implements BufferReceivedListener, Closeable {

        /**
         * Sequence number handed to the channel state writer for received buffers.
         * NOTE(review): the value -2 appears to correspond to
         * {@code ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN}; keep the two in sync.
         */
        private static final int SEQUENCE_NUMBER_UNKNOWN = -2;

        /**
         * Per input channel: {@code true} while received buffers of that channel are written to
         * the channel state writer. That is the case after a checkpoint started and before the
         * channel's barrier for it arrived.
         */
        private final Map<InputChannelInfo, Boolean> storeNewBuffers;

        /** Barriers of {@link #currentReceivedCheckpointId} received so far. */
        private int numBarriersReceived;

        /**
         * Completed normally once barriers of the current checkpoint arrived on all open channels.
         * Completed exceptionally when a newer checkpoint starts before that; cancelled on
         * {@link #close()}.
         */
        private CompletableFuture<Void> allBarriersReceivedFuture = FutureUtils.completedVoidFuture();

        /** Highest checkpoint id seen on the receiving side; -1 before the first one. */
        private long currentReceivedCheckpointId = -1L;

        /** Number of input channels still open; decremented by {@link #onChannelClosed()}. */
        private int numOpenChannels;

        private final SubtaskCheckpointCoordinator checkpointCoordinator;

        private final CheckpointBarrierUnaligner handler;

        ThreadSafeUnaligner(
                SubtaskCheckpointCoordinator checkpointCoordinator,
                CheckpointBarrierUnaligner handler,
                InputGate... inputGates) {
            this.storeNewBuffers = Arrays.stream(inputGates)
                    .flatMap(gate -> gate.getChannelInfos().stream())
                    .collect(Collectors.toMap(Function.identity(), info -> false));
            this.numOpenChannels = this.storeNewBuffers.size();
            this.checkpointCoordinator = checkpointCoordinator;
            this.handler = handler;
        }

        /**
         * Records a barrier received on {@code channelInfo}.
         *
         * <p>A barrier with a newer id starts a new checkpoint (see {@link #handleNewCheckpoint})
         * and schedules the checkpoint notification in the task thread. The first barrier of the
         * current checkpoint on a channel stops buffer persistence for that channel. Once that has
         * happened for every open channel, {@link #allBarriersReceivedFuture} is completed.
         */
        public synchronized void notifyBarrierReceived(CheckpointBarrier barrier, InputChannelInfo channelInfo) throws IOException {
            final long barrierId = barrier.getId();
            if (currentReceivedCheckpointId < barrierId) {
                handleNewCheckpoint(barrier);
                handler.executeInTaskThread(() -> handler.notifyCheckpoint(barrier), "notifyCheckpoint", new Object[0]);
            }
            if (barrierId == currentReceivedCheckpointId && storeNewBuffers.get(channelInfo)) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("{}: Received barrier from channel {} @ {}.", handler.taskName, channelInfo, barrierId);
                }
                storeNewBuffers.put(channelInfo, false);
                if (++numBarriersReceived == numOpenChannels) {
                    allBarriersReceivedFuture.complete(null);
                }
            }
        }

        /**
         * Handles a buffer received on {@code channelInfo}. If its channel is still being
         * persisted, the buffer is handed to the channel state writer of the current checkpoint,
         * wrapped in an iterator whose cleanup action is {@code Buffer::recycleBuffer}. Otherwise
         * it is recycled immediately.
         */
        public synchronized void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) {
            if (!storeNewBuffers.get(channelInfo)) {
                buffer.recycleBuffer();
                return;
            }
            // No cast on `buffer`: an Object-typed argument would infer T = Object and break
            // both the Buffer::recycleBuffer consumer and the CloseableIterator<Buffer> target.
            checkpointCoordinator.getChannelStateWriter().addInputData(
                    currentReceivedCheckpointId,
                    channelInfo,
                    SEQUENCE_NUMBER_UNKNOWN,
                    CloseableIterator.ofElement(buffer, Buffer::recycleBuffer));
        }

        /** Cancels the current all-barriers future (no interruption). */
        @Override
        public synchronized void close() throws IOException {
            allBarriersReceivedFuture.cancel(false);
        }

        /**
         * Switches to the checkpoint of {@code barrier}. If the previous all-barriers future is
         * still open, it is completed exceptionally with {@code CHECKPOINT_DECLINED_SUBSUMED}. If
         * the previous checkpoint also had barriers received, its abort is scheduled in the task
         * thread. All channels then start persisting buffers, and the checkpoint coordinator is
         * told to initialize the new checkpoint.
         */
        private synchronized void handleNewCheckpoint(CheckpointBarrier barrier) throws IOException {
            final long barrierId = barrier.getId();
            if (!allBarriersReceivedFuture.isDone()) {
                final CheckpointException exception = new CheckpointException(
                        "Barrier id: " + barrierId, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
                if (isCheckpointPending()) {
                    LOG.warn(
                            "{}: Received checkpoint barrier for checkpoint {} before completing current checkpoint {}. Skipping current checkpoint.",
                            handler.taskName,
                            barrierId,
                            currentReceivedCheckpointId);
                    // Capture the id now; the field is overwritten below before the task thread runs.
                    final long subsumedCheckpointId = currentReceivedCheckpointId;
                    handler.executeInTaskThread(
                            () -> handler.notifyAbort(subsumedCheckpointId, exception), "notifyAbort", new Object[0]);
                }
                allBarriersReceivedFuture.completeExceptionally(exception);
            }
            currentReceivedCheckpointId = barrierId;
            storeNewBuffers.replaceAll((info, ignored) -> true);
            numBarriersReceived = 0;
            allBarriersReceivedFuture = new CompletableFuture<>();
            checkpointCoordinator.initCheckpoint(barrierId, barrier.getCheckpointOptions());
        }

        /**
         * Returns a future for "all barriers of {@code checkpointId} received". For checkpoints
         * older than the current one, an already completed future is returned. For the current
         * checkpoint, its live future is returned.
         *
         * @throws IllegalStateException if {@code checkpointId} is newer than any received barrier
         */
        synchronized CompletableFuture<Void> getAllBarriersReceivedFuture(long checkpointId) {
            if (checkpointId < currentReceivedCheckpointId) {
                return FutureUtils.completedVoidFuture();
            }
            if (checkpointId > currentReceivedCheckpointId) {
                throw new IllegalStateException("Checkpoint " + checkpointId + " has not been started at all");
            }
            return allBarriersReceivedFuture;
        }

        /**
         * Decrements the open-channel count. If barriers of the current checkpoint were already
         * received, the checkpoint is reset and its abort is notified with
         * {@code CHECKPOINT_DECLINED_INPUT_END_OF_STREAM}.
         */
        synchronized void onChannelClosed() throws IOException {
            --numOpenChannels;
            if (resetPendingCheckpoint()) {
                handler.notifyAbort(
                        currentReceivedCheckpointId,
                        new CheckpointException(CheckpointFailureReason.CHECKPOINT_DECLINED_INPUT_END_OF_STREAM));
            }
        }

        /**
         * Applies a cancellation of {@code cancelledId}.
         *
         * @return {@code false} if there is nothing to abort. That is the case when the id is
         *     older than the current one, or equal to it with no barriers received. Otherwise the
         *     pending state is reset, {@code cancelledId} becomes current, and {@code true} is
         *     returned.
         */
        synchronized boolean setCancelledCheckpointId(long cancelledId) {
            final boolean nothingToAbort = currentReceivedCheckpointId > cancelledId
                    || (currentReceivedCheckpointId == cancelledId && numBarriersReceived == 0);
            if (nothingToAbort) {
                return false;
            }
            resetPendingCheckpoint();
            currentReceivedCheckpointId = cancelledId;
            return true;
        }

        /**
         * If {@code checkpointId} is newer than the current checkpoint and the current one has
         * received barriers, resets it and notifies its abort with {@code exception}.
         */
        synchronized void tryAbortPendingCheckpoint(long checkpointId, CheckpointException exception) throws IOException {
            if (checkpointId > currentReceivedCheckpointId && resetPendingCheckpoint()) {
                handler.notifyAbort(currentReceivedCheckpointId, exception);
            }
        }

        /**
         * Stops buffer persistence on all channels and clears the barrier count, but only if any
         * barrier of the current checkpoint was received.
         *
         * <p>NOTE(review): {@link #allBarriersReceivedFuture} is not touched here. It stays pending
         * until the next {@link #handleNewCheckpoint} or {@link #close()}.
         *
         * @return whether anything was reset
         */
        private boolean resetPendingCheckpoint() {
            if (numBarriersReceived == 0) {
                return false;
            }
            storeNewBuffers.replaceAll((info, ignored) -> false);
            numBarriersReceived = 0;
            return true;
        }

        @VisibleForTesting
        synchronized int getNumOpenChannels() {
            return numOpenChannels;
        }

        @VisibleForTesting
        synchronized long getCurrentCheckpointId() {
            return currentReceivedCheckpointId;
        }

        /**
         * Whether any barrier of the current checkpoint was received. Synchronized like the other
         * accessors, because it reads monitor-guarded state.
         */
        @VisibleForTesting
        synchronized boolean isCheckpointPending() {
            return numBarriersReceived > 0;
        }
    }
}

