/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.heap;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
import org.apache.flink.runtime.state.StateSnapshot;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.heap.AbstractStateTableSnapshot;
import org.apache.flink.runtime.state.heap.InternalKeyContext;
import org.apache.flink.runtime.state.heap.StateTable;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
import org.apache.flink.util.Preconditions;

/**
 * A {@link StateTable} that keeps its state in plain nested {@link HashMap}s.
 *
 * <p>Layout: one slot per key group of the table's key-group range. Each non-null slot maps
 * namespace -> (key -> state). Slots and per-namespace maps are created lazily on the first write
 * into them, and a per-namespace map is dropped again once its last key is removed.
 *
 * @param <K> type of the keys
 * @param <N> type of the namespaces
 * @param <S> type of the state values
 */
@Internal
public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {

    /**
     * Per-key-group maps, indexed by {@code keyGroupIndex - keyGroupOffset}. A slot stays
     * {@code null} until something is written into that key group.
     */
    private final Map<N, Map<K, S>>[] state;

    /** First key group of the key-group range this table is responsible for. */
    private final int keyGroupOffset;

    /**
     * Creates an empty table covering the key-group range of {@code keyContext}.
     *
     * @param keyContext provides the current key / key group and the key-group range
     * @param metaInfo meta information (serializers, snapshot transformer) of the state
     */
    public NestedMapsStateTable(
            InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
        super(keyContext, metaInfo);
        this.keyGroupOffset = keyContext.getKeyGroupRange().getStartKeyGroup();
        // Java forbids generic array creation; the raw array only ever holds Map<N, Map<K, S>>.
        @SuppressWarnings("unchecked")
        Map<N, Map<K, S>>[] maps =
                (Map<N, Map<K, S>>[]) new Map[keyContext.getKeyGroupRange().getNumberOfKeyGroups()];
        this.state = maps;
    }

    /** Returns the internal (live, mutable) per-key-group array. Intended for tests only. */
    @VisibleForTesting
    public Map<N, Map<K, S>>[] getState() {
        return state;
    }

    /**
     * Returns the namespace map of the given key group.
     *
     * @return the map, or {@code null} if nothing was written to that key group yet or the index
     *     lies outside this table's key-group range
     */
    @VisibleForTesting
    Map<N, Map<K, S>> getMapForKeyGroup(int keyGroupIndex) {
        final int pos = indexToOffset(keyGroupIndex);
        if (pos >= 0 && pos < state.length) {
            return state[pos];
        }
        return null;
    }

    /**
     * Installs the namespace map for the given key group.
     *
     * @throws IllegalArgumentException if the key group is outside this table's key-group range
     */
    private void setMapForKeyGroup(int keyGroupId, Map<N, Map<K, S>> map) {
        final int pos = indexToOffset(keyGroupId);
        // Explicit range check instead of relying on ArrayIndexOutOfBoundsException for control flow.
        if (pos < 0 || pos >= state.length) {
            throw new IllegalArgumentException(
                    "Key group index " + keyGroupId + " is out of range of key group range ["
                            + keyGroupOffset + ", " + (keyGroupOffset + state.length) + ").");
        }
        state[pos] = map;
    }

    /** Translates an absolute key-group index into a position in {@link #state}. */
    private int indexToOffset(int index) {
        return index - keyGroupOffset;
    }

    /** Returns the total number of (key, namespace) -> state mappings across all key groups. */
    @Override
    public int size() {
        int count = 0;
        for (Map<N, Map<K, S>> namespaceMap : state) {
            if (namespaceMap == null) {
                continue;
            }
            for (Map<K, S> keyMap : namespaceMap.values()) {
                if (keyMap != null) {
                    count += keyMap.size();
                }
            }
        }
        return count;
    }

    // ------------------------------------------------------------------------
    //  Access with the current key / key group taken from the key context
    // ------------------------------------------------------------------------

    @Override
    public S get(N namespace) {
        return get(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public boolean containsKey(N namespace) {
        return containsKey(
                keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public void put(N namespace, S state) {
        put(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, state);
    }

    @Override
    public S putAndGetOld(N namespace, S state) {
        return putAndGetOld(
                keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, state);
    }

    @Override
    public void remove(N namespace) {
        remove(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public S removeAndGetOld(N namespace) {
        return removeAndGetOld(
                keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    /** Looks up the state of an explicit key; the key group is computed from the key. */
    @Override
    public S get(K key, N namespace) {
        final int keyGroup =
                KeyGroupRangeAssignment.assignToKeyGroup(key, keyContext.getNumberOfKeyGroups());
        return get(key, keyGroup, namespace);
    }

    /**
     * Returns all keys that have state in the given namespace.
     *
     * <p>NOTE(review): the stream is lazy and iterates the live maps, so mutating the table while
     * consuming it may fail.
     */
    @Override
    public Stream<K> getKeys(N namespace) {
        return Arrays.stream(state)
                .filter(Objects::nonNull)
                .map(namespaces -> namespaces.getOrDefault(namespace, Collections.emptyMap()))
                .flatMap(namespaceState -> namespaceState.keySet().stream());
    }

    // ------------------------------------------------------------------------
    //  Access with an explicit key and key group
    // ------------------------------------------------------------------------

    private boolean containsKey(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
        if (namespaceMap == null) {
            return false;
        }
        final Map<K, S> keyedMap = namespaceMap.get(namespace);
        return keyedMap != null && keyedMap.containsKey(key);
    }

    /** Returns the state for (key, namespace) in the given key group, or {@code null} if absent. */
    S get(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
        if (namespaceMap == null) {
            return null;
        }
        final Map<K, S> keyedMap = namespaceMap.get(namespace);
        return keyedMap == null ? null : keyedMap.get(key);
    }

    @Override
    public void put(K key, int keyGroupIndex, N namespace, S value) {
        putAndGetOld(key, keyGroupIndex, namespace, value);
    }

    /** Stores {@code value} and returns the previously stored state, or {@code null}. */
    private S putAndGetOld(K key, int keyGroupIndex, N namespace, S value) {
        checkKeyNamespacePreconditions(key, namespace);
        return getOrCreateKeyedMap(keyGroupIndex, namespace).put(key, value);
    }

    /**
     * Returns the key -> state map for {@code namespace} in the given key group, creating the
     * key-group slot and the namespace map if they do not exist yet.
     *
     * @throws IllegalArgumentException if the key group is outside this table's key-group range
     */
    private Map<K, S> getOrCreateKeyedMap(int keyGroupIndex, N namespace) {
        Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
        if (namespaceMap == null) {
            namespaceMap = new HashMap<>();
            setMapForKeyGroup(keyGroupIndex, namespaceMap);
        }
        return namespaceMap.computeIfAbsent(namespace, ignored -> new HashMap<>());
    }

    private void remove(K key, int keyGroupIndex, N namespace) {
        removeAndGetOld(key, keyGroupIndex, namespace);
    }

    /**
     * Removes the state for (key, namespace) and returns it, or {@code null} if absent. Drops the
     * namespace map once it becomes empty.
     */
    private S removeAndGetOld(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
        if (namespaceMap == null) {
            return null;
        }
        final Map<K, S> keyedMap = namespaceMap.get(namespace);
        if (keyedMap == null) {
            return null;
        }
        final S removed = keyedMap.remove(key);
        if (keyedMap.isEmpty()) {
            namespaceMap.remove(namespace);
        }
        return removed;
    }

    /** Fails with a {@link NullPointerException} if key or namespace is {@code null}. */
    private void checkKeyNamespacePreconditions(K key, N namespace) {
        Preconditions.checkNotNull(
                key, "No key set. This method should not be called outside of a keyed context.");
        Preconditions.checkNotNull(namespace, "Provided namespace is null.");
    }

    /** Returns the number of keys that have state in the given namespace, over all key groups. */
    @Override
    public int sizeOfNamespace(Object namespace) {
        int count = 0;
        for (Map<N, Map<K, S>> namespaceMap : state) {
            if (namespaceMap == null) {
                continue;
            }
            final Map<K, S> keyMap = namespaceMap.get(namespace);
            count += keyMap != null ? keyMap.size() : 0;
        }
        return count;
    }

    /**
     * Replaces the current key's state in {@code namespace} with
     * {@code transformation.apply(previousState, value)}; {@code previousState} is {@code null}
     * when no state existed yet.
     */
    @Override
    public <T> void transform(N namespace, T value, StateTransformationFunction<S, T> transformation)
            throws Exception {
        final K key = keyContext.getCurrentKey();
        checkKeyNamespacePreconditions(key, namespace);
        final int keyGroupIndex = keyContext.getCurrentKeyGroupIndex();
        final Map<K, S> keyedMap = getOrCreateKeyedMap(keyGroupIndex, namespace);
        keyedMap.put(key, transformation.apply(keyedMap.get(key), value));
    }

    /** Counts all key -> state mappings over all namespaces of one key group map. */
    private static <K, N, S> int countMappingsInKeyGroup(Map<N, Map<K, S>> keyGroupMap) {
        int count = 0;
        for (Map<K, S> namespaceMap : keyGroupMap.values()) {
            count += namespaceMap.size();
        }
        return count;
    }

    @Override
    @Nonnull
    public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
        return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
    }

    /**
     * Snapshot of a {@link NestedMapsStateTable} that writes one key group at a time.
     *
     * <p>NOTE(review): this snapshot does not copy the maps. It reads the owning table's live maps
     * when {@link #writeStateInKeyGroup} is called, so it reflects the table at write time.
     */
    static class NestedMapsStateTableSnapshot<K, N, S>
            extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
            implements StateSnapshot.StateKeyGroupWriter {

        private final TypeSerializer<K> keySerializer;
        private final TypeSerializer<N> namespaceSerializer;
        private final TypeSerializer<S> stateSerializer;

        /** Optional transformer; {@code null} means values are written unchanged. */
        private final StateSnapshotTransformer<S> snapshotFilter;

        NestedMapsStateTableSnapshot(
                NestedMapsStateTable<K, N, S> owningTable,
                StateSnapshotTransformer<S> snapshotFilter) {
            super(owningTable);
            this.snapshotFilter = snapshotFilter;
            this.keySerializer = owningTable.keyContext.getKeySerializer();
            this.namespaceSerializer = owningTable.metaInfo.getNamespaceSerializer();
            this.stateSerializer = owningTable.metaInfo.getStateSerializer();
        }

        @Override
        @Nonnull
        public StateSnapshot.StateKeyGroupWriter getKeyGroupWriter() {
            return this;
        }

        @Override
        @Nonnull
        public StateMetaInfoSnapshot getMetaInfoSnapshot() {
            return owningStateTable.metaInfo.snapshot();
        }

        /**
         * Writes the mapping count of the key group followed by one (namespace, key, state) record
         * per mapping. A key group without any map is written as a count of 0.
         */
        @Override
        public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId)
                throws IOException {
            final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
            if (keyGroupMap == null) {
                dov.writeInt(0);
                return;
            }
            final Map<N, Map<K, S>> filteredMappings = filterMappingsInKeyGroupIfNeeded(keyGroupMap);
            dov.writeInt(countMappingsInKeyGroup(filteredMappings));
            for (Map.Entry<N, Map<K, S>> namespaceEntry : filteredMappings.entrySet()) {
                final N namespace = namespaceEntry.getKey();
                for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                    writeElement(namespace, keyEntry, dov);
                }
            }
        }

        /** Serializes one (namespace, key, state) record. */
        private void writeElement(N namespace, Map.Entry<K, S> keyEntry, DataOutputView dov)
                throws IOException {
            namespaceSerializer.serialize(namespace, dov);
            keySerializer.serialize(keyEntry.getKey(), dov);
            stateSerializer.serialize(keyEntry.getValue(), dov);
        }

        /** Returns the key-group map itself when no filter is set, otherwise a filtered copy. */
        private Map<N, Map<K, S>> filterMappingsInKeyGroupIfNeeded(Map<N, Map<K, S>> keyGroupMap) {
            return snapshotFilter == null ? keyGroupMap : filterMappingsInKeyGroup(keyGroupMap);
        }

        /**
         * Builds a copy of the key-group map in which every value was passed through the snapshot
         * filter. Values for which the filter returns {@code null} are dropped, and namespaces left
         * without any value are omitted.
         */
        private Map<N, Map<K, S>> filterMappingsInKeyGroup(Map<N, Map<K, S>> keyGroupMap) {
            final Map<N, Map<K, S>> filtered = new HashMap<>();
            for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
                final Map<K, S> filteredNamespaceMap = new HashMap<>();
                for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                    final S transformedValue = snapshotFilter.filterOrTransform(keyEntry.getValue());
                    if (transformedValue != null) {
                        filteredNamespaceMap.put(keyEntry.getKey(), transformedValue);
                    }
                }
                if (!filteredNamespaceMap.isEmpty()) {
                    filtered.put(namespaceEntry.getKey(), filteredNamespaceMap);
                }
            }
            return filtered;
        }
    }
}

