/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.DependencyTracker;
import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.function.Predicate;
import org.nd4j.imports.VariableUtils;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractSession<T, O> {
    /** Logger for this class; used for trace-level messages while a session executes. */
    private static final Logger log = LoggerFactory.getLogger(AbstractSession.class);
    /** Frame name given to constant, variable, placeholder and zero-input-op steps (always iteration 0, no parent frame). */
    public static final String OUTER_FRAME = "main";
    /** The graph this session executes. Never null (checked in the constructor). */
    protected final SameDiff sameDiff;
    /** Values produced so far, keyed by variable name + frame + iteration + parent frame. Cleared at the start of every {@code output} call. */
    protected final Map<VarId, T> nodeOutputs = new HashMap<VarId, T>();
    /** Per-variable lists of values. Cleared at the start of every {@code output} call; not otherwise touched by the code visible in this class. */
    protected final Map<VarId, List<T>> tensorArrays = new HashMap<VarId, List<T>>();
    /** Tracks which execution steps depend on which other steps, and which steps have all of their dependencies satisfied. */
    protected final DependencyTracker<ExecStep, ExecStep> dt = new DependencyTracker();
    /** Names of the variables that must be computed for the current {@code output} call (filled by {@code initSubgraph}). */
    protected final Set<String> subgraph = new HashSet<String>();
    /** Names of the ops that produce the variables in {@link #subgraph} (filled by {@code initSubgraph}). */
    protected final Set<String> subgraphOps = new HashSet<String>();
    /** Names of subgraph ops for which {@code initSubgraph} counted zero op inputs and zero variable control dependencies. */
    protected final Set<String> zeroInputOpsInSubgraph = new HashSet<String>();

    /**
     * Creates a session bound to the given graph.
     *
     * @param sameDiff the graph to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public AbstractSession(@NonNull SameDiff sameDiff) {
        // Fully qualified so that no additional import is required.
        this.sameDiff = java.util.Objects.requireNonNull(sameDiff, "sameDiff is marked non-null but is null");
    }

    /**
     * Reports whether a value has been stored for the given variable at the given frame position.
     *
     * @param variable        variable name
     * @param frame           frame name
     * @param iteration       iteration number within the frame
     * @param parentFrameIter parent frame position, or null for none
     * @return true if {@link #nodeOutputs} has an entry for that key
     */
    public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        return this.nodeOutputs.containsKey(new VarId(variable, frame, iteration, parentFrameIter));
    }

    /**
     * Returns the stored value for the given variable and frame position, failing if there is none.
     * Equivalent to calling the five-argument overload with {@code enforceExistence = true}.
     *
     * @param variable        variable name
     * @param frame           frame name
     * @param iteration       iteration number within the frame
     * @param parentFrameIter parent frame position, or null for none
     * @return the stored value
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        final boolean mustExist = true;
        return get(variable, frame, iteration, parentFrameIter, mustExist);
    }

    /**
     * Returns the stored value for the given variable and frame position.
     *
     * @param variable         variable name
     * @param frame            frame name
     * @param iteration        iteration number within the frame
     * @param parentFrameIter  parent frame position, or null for none
     * @param enforceExistence when true, a missing (null) value is rejected via
     *                         {@code Preconditions.checkNotNull}; when false, null is returned instead
     * @return the stored value, or null if absent and {@code enforceExistence} is false
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence) {
        T value = this.nodeOutputs.get(new VarId(variable, frame, iteration, parentFrameIter));
        if (!enforceExistence) {
            return value;
        }
        Preconditions.checkNotNull(value, "No output found for variable %s (frame %s, iteration %s)", (Object)variable, (Object)frame, (Object)iteration);
        return value;
    }

    /**
     * Executes the part of the graph needed to produce the requested variables and returns their values.
     * <p>
     * The method (1) resets per-call state and determines the required subgraph, (2) checks that every
     * placeholder the subgraph needs has a value, (3) seeds the dependency tracker with all variables,
     * constants, placeholders and zero-input ops, and (4) repeatedly picks a step whose dependencies are all
     * satisfied, executes it, stores its outputs and registers the steps that depend on it, until every
     * required variable has been produced.
     *
     * @param variables           names of the variables whose values are returned; must not be null
     * @param placeholderValues   placeholder name to value; may be null if no placeholder is needed
     * @param batch               passed through unchanged to {@link #getOutputs}
     * @param requiredActivations additional variables that must be computed (but are not necessarily
     *                            returned); null is treated as empty
     * @param listeners           passed through unchanged to {@link #getOutputs}
     * @param at                  passed to {@link #preprocessPlaceholders} and {@link #getOutputs};
     *                            null is replaced by {@code At.defaultAt()}
     * @return the result of {@link #postProcessOutput} applied to the map of collected output values
     * @throws NullPointerException  if {@code variables} is null
     * @throws IllegalStateException if both {@code variables} and {@code requiredActivations} are empty, a
     *                               requested variable is unknown, a required placeholder has no value, or
     *                               execution cannot make progress
     */
    public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at) {
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
        // Default the optional arguments BEFORE they are inspected. The emptiness check below used to run
        // first and threw a NullPointerException for (empty variables, null requiredActivations) instead of
        // the intended precondition failure.
        if (requiredActivations == null) {
            requiredActivations = Collections.emptySet();
        }
        if (at == null) {
            at = At.defaultAt();
        }
        Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty");
        for (String s : variables) {
            Preconditions.checkState(this.sameDiff.variableMap().containsKey(s), "Requested output variable %s does not exist in SameDiff instance", s);
        }
        Set<String> reqOutputVariablesSet = new HashSet<String>(variables);
        placeholderValues = this.preprocessPlaceholders(placeholderValues, at);

        // ---- Step 1: reset per-call state and work out which variables/ops are actually needed ----
        this.dt.clear();
        this.subgraph.clear();
        this.subgraphOps.clear();
        // Must be reset together with the other subgraph sets; otherwise zero-input ops recorded by an
        // earlier call would be scheduled again below even if they are not part of this call's subgraph.
        this.zeroInputOpsInSubgraph.clear();
        this.nodeOutputs.clear();
        this.tensorArrays.clear();
        Set<String> userRequestedUnique = new HashSet<String>(variables);
        Set<String> allRequired = new HashSet<String>(requiredActivations);
        allRequired.addAll(variables);
        this.initSubgraph(allRequired);

        // ---- Step 2: fail early if a placeholder that the subgraph needs has no value ----
        List<String> phNames = this.sameDiff.inputs();
        if (placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)) {
            for (String phName : phNames) {
                boolean required = variables.contains(phName);
                if (!required) {
                    Variable phVar = this.sameDiff.getVariables().get(phName);
                    if (phVar.getInputsForOp() != null) {
                        // NOTE(review): getInputsForOp() is compared against subgraphOps everywhere else in
                        // this class, but against subgraph (variable names) here. Kept as-is - confirm
                        // whether subgraphOps was intended.
                        for (String consumer : phVar.getInputsForOp()) {
                            if (this.subgraph.contains(consumer)) {
                                required = true;
                                break;
                            }
                        }
                    }
                }
                if (required && (placeholderValues == null || !placeholderValues.containsKey(phName))) {
                    throw new IllegalStateException("An input placeholder \"" + phName + "\" is required to calculate the requested outputs, but a placeholder value was not provided");
                }
            }
        }

        // ---- Step 3: seed the dependency tracker. Everything that needs no op output depends on "start" ----
        ExecStep start = new ExecStep(ExecType.EXEC_START, "", null);
        for (SDVariable v : this.sameDiff.variables()) {
            VariableType vt = v.getVariableType();
            if (vt != VariableType.VARIABLE && vt != VariableType.CONSTANT) {
                continue;
            }
            ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT;
            ExecStep varStep = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null));
            this.dt.addDependency(varStep, start);
            Variable varMeta = this.sameDiff.getVariables().get(v.name());
            if (varMeta.getControlDeps() != null) {
                this.addVarControlDeps(varStep, varMeta);
            }
        }
        for (String phName : phNames) {
            ExecStep phStep = new ExecStep(ExecType.PLACEHOLDER, phName, new FrameIter(OUTER_FRAME, 0, null));
            this.dt.addDependency(phStep, start);
            Variable phMeta = this.sameDiff.getVariables().get(phName);
            if (phMeta.getControlDeps() != null) {
                this.addVarControlDeps(phStep, phMeta);
            }
        }
        for (String opName : this.zeroInputOpsInSubgraph) {
            ExecStep opStep = new ExecStep(ExecType.OP, opName, new FrameIter(OUTER_FRAME, 0, null));
            this.dt.addDependency(opStep, start);
        }
        this.dt.markSatisfied(start, true);

        // ---- Step 4: execute steps until every required variable has been produced ----
        Map<String, T> out = new HashMap<String, T>();
        Set<String> allExecuted = new HashSet<String>();
        int step = 0;
        String currentFrame = OUTER_FRAME;
        int currentFrameIter = 0;
        FrameIter currParentFrame = null;
        ExecStepPredicate predicate = new ExecStepPredicate();
        while (allExecuted.size() < allRequired.size()) {
            if (!this.dt.hasNewAllSatisfied()) {
                // Nothing is runnable but outputs are still missing: execFailed always throws.
                this.execFailed(userRequestedUnique, out, allRequired, allExecuted, step);
            }
            // Prefer a step in the frame/iteration of the previously executed step; otherwise take any.
            predicate.setCurrentFrame(currentFrame);
            predicate.setCurrentFrameIter(currentFrameIter);
            predicate.setCurrParentFrame(currParentFrame);
            ExecStep es = this.dt.getFirstNewAllSatisfiedMatching(predicate);
            if (es == null) {
                es = (ExecStep)this.dt.getNewAllSatisfied();
            }
            currentFrame = es.getFrameIter().getFrame();
            currentFrameIter = es.getFrameIter().getIteration();
            currParentFrame = es.getFrameIter().getParentFrame();
            log.trace("Beginning execution step {}: {}", (Object)step, (Object)es);

            FrameIter outFrameIter;
            boolean skipDepUpdate = false;
            boolean skipMarkSatisfied = false;
            if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) {
                VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
                T arr = this.getConstantOrVariable(es.getName());
                Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
                this.nodeOutputs.put(vid, arr);
                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (userRequestedUnique.contains(es.getName())) {
                    out.put(es.getName(), arr);
                }
                if (allRequired.contains(es.getName())) {
                    allExecuted.add(es.getName());
                }
            } else if (es.getType() == ExecType.PLACEHOLDER) {
                VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
                // A placeholder that nothing requires may legitimately have no value (stored as null).
                T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
                this.nodeOutputs.put(vid, phVal);
                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (allRequired.contains(es.getName())) {
                    Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(es.getName()), "No array was provided for the placeholder variable \"%s\" that is required for execution", es.getName());
                    out.put(es.getName(), placeholderValues.get(es.getName()));
                    allExecuted.add(es.getName());
                }
            } else if (es.getType() == ExecType.OP) {
                String opName = es.getName();
                SameDiffOp op = this.sameDiff.getOps().get(opName);
                DifferentialFunction o = op.getOp();

                // Work out the frame/iteration the op's outputs live in.
                if (o instanceof Enter) {
                    // Enter: outputs go to iteration 0 of the new frame, whose parent is the current frame.
                    String outFrame = ((Enter)o).getFrameName();
                    outFrameIter = new FrameIter(outFrame, 0, es.getFrameIter());
                } else if (o instanceof Exit) {
                    // Exit: outputs go back to the parent frame position.
                    String outFrame = es.getFrameIter().getParentFrame().getFrame();
                    int outIter = es.getFrameIter().getParentFrame().getIteration();
                    FrameIter outParentFrame = es.getFrameIter().getParentFrame().getParentFrame();
                    outFrameIter = new FrameIter(outFrame, outIter, outParentFrame);
                } else if (o instanceof NextIteration) {
                    outFrameIter = es.getFrameIter().clone();
                    // NOTE(review): the iteration is written back unchanged. addDependenciesForOp already
                    // creates NextIteration steps at iteration + 1, so this appears intentional.
                    outFrameIter.setIteration(outFrameIter.getIteration());
                } else {
                    outFrameIter = es.getFrameIter();
                }

                // Collect the op's inputs from the dependencies that were recorded for this step.
                Set<VarId> inputs = null;
                Set<VarId> allIterInputs = null;
                Set<String> constAndPhInputs = null;
                DependencyList dl = this.dt.getDependencies(es);
                List<String> inputNames = op.getInputsToOp();
                if (inputNames != null && !inputNames.isEmpty()) {
                    inputs = new HashSet<VarId>();
                    allIterInputs = new HashSet<VarId>();
                    constAndPhInputs = new HashSet<String>();
                    // dl is used as a raw type here; the tracker only ever stores ExecStep instances.
                    @SuppressWarnings("unchecked")
                    List<ExecStep> deps = dl.getDependencies();
                    if (deps != null && !deps.isEmpty()) {
                        for (ExecStep dep : deps) {
                            switch (dep.getType()) {
                                case OP:
                                case SWITCH_L:
                                case SWITCH_R: {
                                    // Every output of the dependency op that this op consumes is an input,
                                    // located at the frame position of the dependency step.
                                    SameDiffOp toExecOp = this.sameDiff.getOps().get(es.getName());
                                    List<String> inputsToExecOp = toExecOp.getInputsToOp();
                                    SameDiffOp inputOp = this.sameDiff.getOps().get(dep.getName());
                                    List<String> inputOpOutNames = inputOp.getOutputsOfOp();
                                    for (String inName : inputsToExecOp) {
                                        if (!inputOpOutNames.contains(inName)) {
                                            continue;
                                        }
                                        inputs.add(new VarId(inName, dep.getFrameIter().getFrame(), dep.getFrameIter().getIteration(), dep.getFrameIter().getParentFrame()));
                                    }
                                    break;
                                }
                                case VARIABLE: {
                                    inputs.add(new VarId(dep.getName(), OUTER_FRAME, 0, null));
                                    break;
                                }
                                case CONSTANT:
                                case PLACEHOLDER: {
                                    constAndPhInputs.add(dep.getName());
                                    break;
                                }
                                default: {
                                    throw new UnsupportedOperationException("Not yet implemented: " + dep.getType());
                                }
                            }
                        }
                    }
                }

                // Execute the op and store its outputs.
                O parameterizedOp = this.getAndParameterizeOp(opName, outFrameIter, inputs, allIterInputs, constAndPhInputs, placeholderValues, reqOutputVariablesSet);
                T[] opOutputValues = this.getOutputs(parameterizedOp, outFrameIter, inputs, allIterInputs, constAndPhInputs, listeners, at, batch, reqOutputVariablesSet);
                List<String> opOutVarNames = op.getOutputsOfOp();
                Preconditions.checkState(opOutputValues.length == opOutVarNames.size(), "Unexpected number of outputs from executed op %s: got %s outputs when %s outputs were expected (%s)", (Object)parameterizedOp.getClass().getSimpleName(), (Object)opOutputValues.length, (Object)opOutVarNames.size(), opOutVarNames);
                for (int i = 0; i < opOutputValues.length; ++i) {
                    // A Switch op leaves the output of the branch that was not taken as null.
                    if (opOutputValues[i] == null && op.getOp() instanceof Switch) {
                        continue;
                    }
                    String n = opOutVarNames.get(i);
                    VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame());
                    this.nodeOutputs.put(vid, opOutputValues[i]);
                    if (userRequestedUnique.contains(n)) {
                        out.put(n, opOutputValues[i]);
                    }
                    if (allRequired.contains(n)) {
                        allExecuted.add(n);
                    }
                }

                // Control-flow ops are marked satisfied under a different ExecStep than the one executed,
                // so the generic dependency update / mark-satisfied at the bottom of the loop is skipped.
                if (o instanceof Switch) {
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    int nullCount = (opOutputValues[0] == null ? 1 : 0) + (opOutputValues[1] == null ? 1 : 0);
                    Preconditions.checkState(nullCount == 1, "Expected exactly one output to be present for switch ops, got %s", nullCount);
                    boolean left = opOutputValues[0] != null;
                    ExecStep branch = left ? new ExecStep(ExecType.SWITCH_L, es.getName(), es.getFrameIter()) : new ExecStep(ExecType.SWITCH_R, es.getName(), es.getFrameIter());
                    this.updateDescendantDeps(branch, outFrameIter);
                    this.dt.markSatisfied(branch, true);
                } else if (o instanceof Enter) {
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    Enter e = (Enter)o;
                    FrameIter fi = new FrameIter(e.getFrameName(), 0, es.getFrameIter());
                    ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi);
                    this.updateDescendantDeps(exec, fi);
                    this.dt.markSatisfied(exec, true);
                } else if (o instanceof Exit) {
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    FrameIter fi = es.getFrameIter().getParentFrame();
                    ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi);
                    this.updateDescendantDeps(exec, fi);
                    this.dt.markSatisfied(exec, true);
                }

                // If this op is a control dependency for something, release the matching CONTROL_DEP step.
                List<String> cdFor = op.getControlDepFor();
                if (cdFor != null) {
                    ExecStep cdEs = new ExecStep(ExecType.CONTROL_DEP, opName, null);
                    if (!this.dt.isSatisfied(cdEs)) {
                        this.dt.markSatisfied(cdEs, true);
                    }
                }
            } else {
                throw new RuntimeException("Unknown ExecStep: " + es);
            }
            if (!skipDepUpdate) {
                this.updateDescendantDeps(es, outFrameIter);
            }
            if (!skipMarkSatisfied) {
                this.dt.markSatisfied(es, true);
            }
            ++step;
        }
        return this.postProcessOutput(out);
    }

    /**
     * Makes the given step depend on a {@code CONTROL_DEP} step for each control dependency of the variable.
     *
     * @param es the step that must wait for the control dependencies
     * @param v  the variable whose control dependencies are read; if it has none, nothing is registered
     */
    protected void addVarControlDeps(ExecStep es, Variable v) {
        List<String> controlDepNames = v.getControlDeps();
        if (controlDepNames == null) {
            return;
        }
        for (String depName : controlDepNames) {
            this.dt.addDependency(es, new ExecStep(ExecType.CONTROL_DEP, depName, null));
        }
    }

    /**
     * Reports that execution cannot make progress: builds a diagnostic message, prints the graph summary to
     * standard output and throws. This method never returns normally.
     *
     * @param userRequestedUnique names of the outputs the user asked for
     * @param out                 output values collected so far (may also hold entries the user did not
     *                            request, so its size is not used for counting)
     * @param allRequired         names of every variable that has to be computed
     * @param allExecuted         names of the required variables computed so far
     * @param step                index of the execution step at which the failure was detected
     * @throws IllegalStateException always
     */
    protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, Set<String> allRequired, Set<String> allExecuted, int step) {
        // Determine the missing outputs first and count THEM; "requested - out.size()" is wrong whenever
        // out contains values that were not user-requested.
        Set<String> missing = new HashSet<String>();
        for (String requested : userRequestedUnique) {
            if (!out.containsKey(requested)) {
                missing.add(requested);
            }
        }
        int missingCount = missing.size();
        // Remaining = required minus executed (the operands were previously swapped, giving a value <= 0).
        int remainingToExecute = allRequired.size() - allExecuted.size();

        StringBuilder sb = new StringBuilder();
        sb.append("No variable are available for execution at step ").append(step).append(": ").append(missingCount).append(" requested output values remaining, ").append(remainingToExecute).append(" variables required to be executed remaining");
        if (missingCount <= 10) {
            sb.append(". Missing variables: ");
            sb.append(missing);
        } else {
            sb.append(". First 10 missing variables: ");
            int appended = 0;
            for (String name : missing) {
                if (appended == 10) {
                    break;
                }
                if (appended > 0) {
                    sb.append(",");
                }
                sb.append(name);
                ++appended;
            }
        }
        System.out.println(this.sameDiff.summary());
        throw new IllegalStateException(sb.toString());
    }

    /**
     * Registers dependencies for every subgraph op that consumes something the given step produced.
     * <ul>
     *   <li>OP step: for each output variable of the op, the ops that take it as an input and the ops
     *       listed by {@code getControlDepsForOp()} are processed.</li>
     *   <li>VARIABLE / CONSTANT / PLACEHOLDER step: the ops that take that variable as an input.</li>
     *   <li>SWITCH_L / SWITCH_R step: the ops that take the first / second output of the switch op as an
     *       input.</li>
     * </ul>
     * Ops that are not in {@link #subgraphOps} are ignored.
     *
     * @param justExecuted the step that has just been executed
     * @param outFrameIter frame position handed to {@link #addDependenciesForOp} for each consumer
     * @throws UnsupportedOperationException for any other step type
     */
    protected void updateDescendantDeps(ExecStep justExecuted, FrameIter outFrameIter) {
        final ExecType type = justExecuted.getType();
        final String name = justExecuted.getName();

        if (type == ExecType.OP) {
            for (String outName : this.sameDiff.getOps().get(name).getOutputsOfOp()) {
                Variable outVar = this.sameDiff.getVariables().get(outName);
                if (outVar == null) {
                    continue;
                }
                List<String> consumers = outVar.getInputsForOp();
                if (consumers != null) {
                    for (String consumer : consumers) {
                        if (this.subgraphOps.contains(consumer)) {
                            this.addDependenciesForOp(consumer, outFrameIter);
                        }
                    }
                }
                List<String> controlled = outVar.getControlDepsForOp();
                if (controlled != null) {
                    for (String controlledOp : controlled) {
                        if (this.subgraphOps.contains(controlledOp)) {
                            this.addDependenciesForOp(controlledOp, outFrameIter);
                        }
                    }
                }
            }
            return;
        }

        // All remaining step types resolve to exactly one produced variable.
        Variable produced;
        if (type == ExecType.VARIABLE || type == ExecType.CONSTANT || type == ExecType.PLACEHOLDER) {
            produced = this.sameDiff.getVariables().get(name);
        } else if (type == ExecType.SWITCH_L || type == ExecType.SWITCH_R) {
            List<String> switchOutputs = this.sameDiff.getOps().get(name).getOutputsOfOp();
            int branchIdx = type == ExecType.SWITCH_L ? 0 : 1;
            produced = this.sameDiff.getVariables().get(switchOutputs.get(branchIdx));
        } else {
            throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted);
        }
        if (produced == null) {
            return;
        }
        List<String> consumers = produced.getInputsForOp();
        if (consumers == null) {
            return;
        }
        for (String consumer : consumers) {
            if (this.subgraphOps.contains(consumer)) {
                this.addDependenciesForOp(consumer, outFrameIter);
            }
        }
    }

    /**
     * Registers, in the dependency tracker, everything the given op has to wait for before it can run at
     * the given frame position.
     * <ul>
     *   <li>Merge: an OR-dependency on the steps producing its first two inputs.</li>
     *   <li>NextIteration: the step is created at {@code iteration + 1} and depends on its inputs at the
     *       given (current) position.</li>
     *   <li>Any other op: a dependency on the step producing each input.</li>
     * </ul>
     * The op's control dependencies ({@code getControlDeps()}) are added in all cases. For ops other than
     * NextIteration nothing is done if the tracker already has dependencies recorded for the step.
     *
     * @param opName       name of the op to register
     * @param depFrameIter frame position the op's inputs are looked up in
     */
    protected void addDependenciesForOp(String opName, FrameIter depFrameIter) {
        SameDiffOp op = this.sameDiff.getOps().get(opName);
        List<String> inputs = op.getInputsToOp();
        List<String> cdOps = op.getControlDeps();
        // NOTE(review): the original also read op.getVarControlDeps() and looped over it with an empty
        // body. That dead code was removed; variable control dependencies of the op are therefore
        // (still) not registered here - confirm whether they should be.
        ExecStep es = new ExecStep(ExecType.OP, opName, depFrameIter);
        if (!(op.getOp() instanceof NextIteration) && this.dt.hasDependency(es)) {
            return;
        }
        if (op.getOp() instanceof Merge) {
            // Merge runs as soon as EITHER of its two inputs is available.
            Variable v0 = this.sameDiff.getVariables().get(inputs.get(0));
            Variable v1 = this.sameDiff.getVariables().get(inputs.get(1));
            ExecStep or0 = this.getExecStepForVar(v0.getName(), depFrameIter);
            ExecStep or1 = this.getExecStepForVar(v1.getName(), depFrameIter);
            this.dt.addOrDependency(es, or0, or1);
        } else if (op.getOp() instanceof NextIteration) {
            FrameIter nextIter = depFrameIter.clone();
            nextIter.setIteration(nextIter.getIteration() + 1);
            es = new ExecStep(ExecType.OP, opName, nextIter);
            for (String input : inputs) {
                this.dt.addDependency(es, this.getExecStepForVar(input, depFrameIter));
            }
        } else if (inputs != null) {
            // inputs may be null for an op that is only reached through a control dependency and takes
            // no data inputs; iterating it unguarded would throw a NullPointerException.
            for (String input : inputs) {
                this.dt.addDependency(es, this.getExecStepForVar(input, depFrameIter));
            }
        }
        if (cdOps != null) {
            for (String controlDep : cdOps) {
                this.dt.addDependency(es, this.getExecStepForVar(controlDep, depFrameIter));
            }
        }
    }

    /**
     * Returns the execution step that produces the given variable.
     * <ul>
     *   <li>VARIABLE / PLACEHOLDER / CONSTANT variables map to a step of the same kind in
     *       {@link #OUTER_FRAME}, iteration 0.</li>
     *   <li>Outputs of a Switch op map to {@code SWITCH_L} (first output) or {@code SWITCH_R} (second
     *       output) at {@code frameIter}.</li>
     *   <li>Outputs of a constant Enter op map to an OP step at iteration 0 of a clone of
     *       {@code frameIter}; parent frames reached through a chain of Enter ops are also reset to
     *       iteration 0.</li>
     *   <li>Everything else maps to an OP step for the producing op at {@code frameIter}.</li>
     * </ul>
     *
     * @param varName   name of the variable
     * @param frameIter frame position in which the variable is needed
     * @return the step whose execution makes the variable available
     * @throws IllegalStateException if no producing op is found, or the variable is not one of the two
     *                               outputs of its Switch op
     */
    protected ExecStep getExecStepForVar(String varName, FrameIter frameIter) {
        Variable v = this.sameDiff.getVariables().get(varName);
        VariableType vt = v.getVariable().getVariableType();

        // Variables that are not produced by an op always live in the outer frame, iteration 0.
        ExecType leafType = null;
        if (vt == VariableType.VARIABLE) {
            leafType = ExecType.VARIABLE;
        } else if (vt == VariableType.PLACEHOLDER) {
            leafType = ExecType.PLACEHOLDER;
        } else if (vt == VariableType.CONSTANT) {
            leafType = ExecType.CONSTANT;
        }
        if (leafType != null) {
            return new ExecStep(leafType, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        }

        // No producing op recorded under this name: retry under the name with its suffix stripped.
        if (v.getOutputOfOp() == null) {
            v = this.sameDiff.getVariables().get(VariableUtils.stripVarSuffix(v.getName()));
        }
        String outOfOp = v.getOutputOfOp();
        SameDiffOp sdo = this.sameDiff.getOps().get(outOfOp);
        if (sdo == null) {
            throw new IllegalStateException("Samediff output op named " + v.getName() + " did not have any ops associated with it.");
        }

        if (sdo.getOp() instanceof Switch) {
            List<String> opOutputs = sdo.getOutputsOfOp();
            switch (opOutputs.indexOf(v.getName())) {
                case 0:
                    return new ExecStep(ExecType.SWITCH_L, outOfOp, frameIter);
                case 1:
                    return new ExecStep(ExecType.SWITCH_R, outOfOp, frameIter);
                default:
                    throw new IllegalStateException("Expected variable \"" + v.getName() + "\" to be an output of operation \"" + outOfOp + "\", but op output variables are: " + opOutputs);
            }
        }

        boolean constantEnter = sdo.getOp() instanceof Enter && ((Enter)sdo.getOp()).isConstant();
        if (!constantEnter) {
            return new ExecStep(ExecType.OP, outOfOp, frameIter);
        }

        FrameIter fi = frameIter.clone();
        fi.setIteration(0);
        String inVarName = sdo.getInputsToOp().get(0);
        FrameIter ancestor = fi.getParentFrame();
        while (ancestor != null) {
            Variable inVar = this.sameDiff.getVariables().get(inVarName);
            String producerName = inVar.getOutputOfOp();
            if (producerName == null) {
                break;
            }
            SameDiffOp producer = this.sameDiff.getOps().get(producerName);
            if (!(producer.getOp() instanceof Enter)) {
                break;
            }
            // NOTE(review): constness is checked on the ORIGINAL Enter (sdo), not on the Enter just found
            // (producer). This mirrors the existing behavior exactly - confirm which op was intended.
            if (!((Enter)sdo.getOp()).isConstant()) {
                break;
            }
            ancestor.setIteration(0);
            inVarName = producer.getInputsToOp().get(0);
            ancestor = ancestor.getParentFrame();
        }
        return new ExecStep(ExecType.OP, outOfOp, fi);
    }

    /**
     * Determines which variables and ops have to be executed to obtain the given variables, by walking
     * backwards from them through op inputs and control dependencies. Results are accumulated in
     * {@link #subgraph}, {@link #subgraphOps} and {@link #zeroInputOpsInSubgraph}; this method only adds to
     * those sets and does not clear them.
     *
     * @param variables names of the variables that are required
     */
    protected void initSubgraph(Set<String> variables) {
        LinkedList<String> processingQueue = new LinkedList<String>(variables);
        while (!processingQueue.isEmpty()) {
            String varName = processingQueue.remove();
            if (this.subgraph.contains(varName)) {
                // Already handled. Everything it needs was queued when it was first processed, so
                // re-scanning its inputs could only add duplicates to the queue.
                continue;
            }

            DifferentialFunction producer = this.sameDiff.getVariableOutputOp(varName);
            String opName = producer == null ? null : producer.getOwnName();
            // May be null: either the variable has no producing op, or the op reports no inputs.
            String[] opInputs = opName == null ? null : this.sameDiff.getInputsForOp(this.sameDiff.getOpById(opName));
            Variable currVar = this.sameDiff.getVariables().get(varName);
            log.trace("Adding {} to subgraph for output.", varName);
            List<String> controlDeps = currVar.getControlDeps();

            // Variable control dependencies count as inputs when deciding whether the op can start at once.
            int numInputs = opInputs == null ? 0 : opInputs.length;
            if (controlDeps != null) {
                numInputs += controlDeps.size();
            }
            if (numInputs == 0 && opName != null) {
                this.zeroInputOpsInSubgraph.add(opName);
            }
            this.subgraph.add(varName);
            if (opName != null) {
                this.subgraphOps.add(opName);
            }

            // Whatever this variable waits for has to be computed as well.
            if (controlDeps != null) {
                for (String dep : controlDeps) {
                    if (!this.subgraph.contains(dep)) {
                        processingQueue.add(dep);
                    }
                }
            }
            if (opName == null) {
                continue;
            }
            // To execute the producing op we need its inputs ... (guarded: opInputs is nullable, see above)
            if (opInputs != null) {
                for (String input : opInputs) {
                    if (!this.subgraph.contains(input)) {
                        processingQueue.add(input);
                    }
                }
            }
            // ... and its control dependencies.
            List<String> opControlDeps = this.sameDiff.getOps().get(opName).getControlDeps();
            if (opControlDeps != null) {
                for (String dep : opControlDeps) {
                    if (!this.subgraph.contains(dep)) {
                        processingQueue.add(dep);
                    }
                }
            }
        }
    }

    /**
     * Hook applied to the placeholder values at the start of {@code output}, before any validation of
     * placeholders. This base implementation returns the map unchanged.
     *
     * @param placeholders placeholder values as passed to {@code output}; may be null
     * @param at           the {@code At} in effect for the call (unused by this implementation)
     * @return the placeholder map to use for execution
     */
    protected Map<String, T> preprocessPlaceholders(Map<String, T> placeholders, At at) {
        return placeholders;
    }

    /**
     * Hook applied to the collected output values just before {@code output} returns them. This base
     * implementation returns the map unchanged.
     *
     * @param output variable name to computed value
     * @return the map that {@code output} returns to its caller
     */
    protected Map<String, T> postProcessOutput(Map<String, T> output) {
        return output;
    }

    /**
     * Returns the value of the constant or variable with the given name. {@code output} calls this once for
     * every CONSTANT / VARIABLE step it executes and rejects a null result.
     *
     * @param var1 name of the constant or variable
     * @return the value; must not be null when called from {@code output}
     */
    public abstract T getConstantOrVariable(String var1);

    /**
     * Obtains the op with the given name, prepared for execution. The result is handed by {@code output}
     * straight to {@link #getOutputs}. Parameter meanings below are taken from the call in {@code output}.
     *
     * @param var1 name of the op
     * @param var2 frame position the op's outputs will be stored at
     * @param var3 inputs that are op outputs or variables (null if the op has no inputs)
     * @param var4 passed as an empty set by {@code output} (null if the op has no inputs);
     *             presumably inputs needed from all iterations - confirm with implementations
     * @param var5 names of the inputs that are constants or placeholders (null if the op has no inputs)
     * @param var6 placeholder values after {@code preprocessPlaceholders}; may be null
     * @param var7 names of the variables the user requested as outputs
     * @return the op, ready to be passed to {@link #getOutputs}
     */
    public abstract O getAndParameterizeOp(String var1, FrameIter var2, Set<VarId> var3, Set<VarId> var4, Set<String> var5, Map<String, T> var6, Set<String> var7);

    /**
     * Executes the given op and returns its output values. {@code output} requires the returned array to
     * have exactly one element per output variable of the op; for Switch ops it expects exactly one of the
     * two elements to be null. Parameter meanings below are taken from the call in {@code output}.
     *
     * @param var1 the op returned by {@link #getAndParameterizeOp}
     * @param var2 frame position the outputs will be stored at
     * @param var3 inputs that are op outputs or variables (null if the op has no inputs)
     * @param var4 passed as an empty set by {@code output} (null if the op has no inputs);
     *             presumably inputs needed from all iterations - confirm with implementations
     * @param var5 names of the inputs that are constants or placeholders (null if the op has no inputs)
     * @param var6 listeners as passed to {@code output}
     * @param var7 the {@code At} in effect for the call (never null when called from {@code output})
     * @param var8 the batch as passed to {@code output}
     * @param var9 names of the variables the user requested as outputs
     * @return the op's output values, in the order of the op's output variables
     */
    public abstract T[] getOutputs(O var1, FrameIter var2, Set<VarId> var3, Set<VarId> var4, Set<String> var5, List<Listener> var6, At var7, MultiDataSet var8, Set<String> var9);

    /**
     * Finds the first {@code VarId} with the given variable name, searching {@code varIds} first and
     * {@code varIds2} second. Either collection may be null, in which case it is skipped.
     *
     * @param name                variable name to look for
     * @param varIds              first collection to search; may be null
     * @param varIds2             second collection to search; may be null
     * @param exceptionOnNotFound whether to throw instead of returning null when nothing matches
     * @return the matching {@code VarId}, or null if none was found and {@code exceptionOnNotFound} is false
     * @throws RuntimeException if nothing matches and {@code exceptionOnNotFound} is true
     */
    protected static VarId lookup(String name, Collection<VarId> varIds, Collection<VarId> varIds2, boolean exceptionOnNotFound) {
        VarId found = null;
        if (varIds != null) {
            found = AbstractSession.lookup(name, varIds, false);
        }
        if (found == null && varIds2 != null) {
            found = AbstractSession.lookup(name, varIds2, false);
        }
        if (found != null || !exceptionOnNotFound) {
            return found;
        }
        throw new RuntimeException("Could not find VarId for input \"" + name + "\"");
    }

    /**
     * Finds the first {@code VarId} in the collection whose variable name equals {@code name}.
     *
     * @param name                variable name to look for
     * @param varIds              collection to search; must not be null
     * @param exceptionOnNotFound whether to throw instead of returning null when nothing matches
     * @return the matching {@code VarId}, or null if none was found and {@code exceptionOnNotFound} is false
     * @throws RuntimeException if nothing matches and {@code exceptionOnNotFound} is true
     */
    protected static VarId lookup(String name, Collection<VarId> varIds, boolean exceptionOnNotFound) {
        Iterator<VarId> candidates = varIds.iterator();
        while (candidates.hasNext()) {
            VarId candidate = candidates.next();
            if (candidate.getVariable().equals(name)) {
                return candidate;
            }
        }
        if (!exceptionOnNotFound) {
            return null;
        }
        throw new RuntimeException("Could not find VarId to input " + name);
    }

    /**
     * @return the live (not copied) map of values computed so far, keyed by variable and frame position
     */
    public Map<VarId, T> getNodeOutputs() {
        return nodeOutputs;
    }

    /**
     * @return the live (not copied) map of per-variable value lists held by this session
     */
    public Map<VarId, List<T>> getTensorArrays() {
        return tensorArrays;
    }

    /**
     * Predicate over {@link ExecStep}s that accepts a step only when the step's {@link FrameIter} carries the
     * same frame name, the same iteration number and the same parent frame as the values configured here.
     * <p>
     * All three fields are mutable through the setters below. Equality, hash code and string form are derived
     * from those three fields.
     */
    protected class ExecStepPredicate
    implements Predicate<ExecStep> {
        /** Frame name a step must have to be accepted. */
        protected String currentFrame;
        /** Iteration number a step must have to be accepted. */
        protected int currentFrameIter;
        /** Parent frame a step must have to be accepted; {@code null} matches only steps with no parent frame. */
        protected FrameIter currParentFrame;

        /**
         * Creates a predicate with all fields at their defaults ({@code null} / {@code 0}).
         * {@link #setCurrentFrame(String)} must be called before {@link #test(ExecStep)} is used.
         */
        public ExecStepPredicate() {
        }

        /**
         * Creates a fully configured predicate.
         *
         * @param currentFrame     frame name to match
         * @param currentFrameIter iteration number to match
         * @param currParentFrame  parent frame to match, or {@code null} to match steps without a parent frame
         */
        public ExecStepPredicate(String currentFrame, int currentFrameIter, FrameIter currParentFrame) {
            this.currentFrame = currentFrame;
            this.currentFrameIter = currentFrameIter;
            this.currParentFrame = currParentFrame;
        }

        /**
         * Checks whether the given step belongs to the configured frame, iteration and parent frame.
         *
         * @param execStep the step to check; it and its {@link FrameIter} must be non-null
         * @return {@code true} only if frame name, iteration and parent frame all match
         * @throws NullPointerException if {@code currentFrame} has not been set, or if {@code execStep} or its
         *                              {@link FrameIter} is {@code null}
         */
        public boolean test(ExecStep execStep) {
            FrameIter stepFrameIter = execStep.getFrameIter();
            if (!this.currentFrame.equals(stepFrameIter.getFrame())) {
                return false;
            }
            if (this.currentFrameIter != stepFrameIter.getIteration()) {
                return false;
            }
            FrameIter stepParentFrame = stepFrameIter.getParentFrame();
            // Null-safe comparison: a null configured parent must simply not match a non-null step parent.
            // The previous form "(a == null && b == null || a.equals(b))" dereferenced a null 'a' in that case.
            return this.currParentFrame == null ? stepParentFrame == null : this.currParentFrame.equals(stepParentFrame);
        }

        public String getCurrentFrame() {
            return this.currentFrame;
        }

        public int getCurrentFrameIter() {
            return this.currentFrameIter;
        }

        public FrameIter getCurrParentFrame() {
            return this.currParentFrame;
        }

        public void setCurrentFrame(String currentFrame) {
            this.currentFrame = currentFrame;
        }

        public void setCurrentFrameIter(int currentFrameIter) {
            this.currentFrameIter = currentFrameIter;
        }

        public void setCurrParentFrame(FrameIter currParentFrame) {
            this.currParentFrame = currParentFrame;
        }

        /**
         * Two predicates are equal when their frame name, iteration and parent frame are all equal
         * (with {@code null} equal only to {@code null}).
         */
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ExecStepPredicate)) {
                return false;
            }
            ExecStepPredicate other = (ExecStepPredicate)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getCurrentFrameIter() != other.getCurrentFrameIter()) {
                return false;
            }
            String thisFrame = this.getCurrentFrame();
            String otherFrame = other.getCurrentFrame();
            if (thisFrame == null ? otherFrame != null : !thisFrame.equals(otherFrame)) {
                return false;
            }
            FrameIter thisParent = this.getCurrParentFrame();
            FrameIter otherParent = other.getCurrParentFrame();
            return thisParent == null ? otherParent == null : thisParent.equals(otherParent);
        }

        /** Hook letting subclasses veto equality with instances of other types. */
        protected boolean canEqual(Object other) {
            return other instanceof ExecStepPredicate;
        }

        /** Hash code consistent with {@link #equals(Object)}; {@code null} fields contribute the constant 43. */
        public int hashCode() {
            int result = 1;
            result = result * 59 + this.getCurrentFrameIter();
            String frame = this.getCurrentFrame();
            result = result * 59 + (frame == null ? 43 : frame.hashCode());
            FrameIter parent = this.getCurrParentFrame();
            result = result * 59 + (parent == null ? 43 : parent.hashCode());
            return result;
        }

        public String toString() {
            return "AbstractSession.ExecStepPredicate(currentFrame=" + this.getCurrentFrame() + ", currentFrameIter=" + this.getCurrentFrameIter() + ", currParentFrame=" + this.getCurrParentFrame() + ")";
        }
    }

    /**
     * Immutable description of one execution step: an {@link ExecType}, a name and the {@link FrameIter}
     * it is associated with. Equality and hash code are based on all three fields.
     */
    protected static class ExecStep {
        protected final ExecType type;
        protected final String name;
        protected final FrameIter frameIter;

        /**
         * @param execType  type of the step; must not be {@code null}
         * @param name      name of the step; must not be {@code null}
         * @param frameIter frame/iteration of the step (not null-checked here)
         * @throws NullPointerException if {@code execType} or {@code name} is {@code null}
         */
        protected ExecStep(@NonNull ExecType execType, @NonNull String name, FrameIter frameIter) {
            if (execType == null) {
                throw new NullPointerException("execType is marked non-null but is null");
            }
            if (name == null) {
                throw new NullPointerException("name is marked non-null but is null");
            }
            this.frameIter = frameIter;
            this.name = name;
            this.type = execType;
        }

        /**
         * Builds a {@link VarId} from this step's name and the frame, iteration and parent frame of its
         * {@link FrameIter}.
         */
        protected VarId toVarId() {
            FrameIter fi = this.frameIter;
            return new VarId(this.name, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
        }

        public ExecType getType() {
            return type;
        }

        public String getName() {
            return name;
        }

        public FrameIter getFrameIter() {
            return frameIter;
        }

        public String toString() {
            StringBuilder sb = new StringBuilder("ExecStep(");
            sb.append(this.type);
            sb.append(",name=\"").append(this.name).append("\",");
            sb.append(this.frameIter);
            sb.append(")");
            return sb.toString();
        }

        /** Field-wise equality on type, name and frameIter; {@code null} equals only {@code null}. */
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ExecStep)) {
                return false;
            }
            ExecStep that = (ExecStep) o;
            if (!that.canEqual(this)) {
                return false;
            }

            ExecType myType = getType();
            ExecType theirType = that.getType();
            boolean sameType = myType == null ? theirType == null : myType.equals(theirType);
            if (!sameType) {
                return false;
            }

            String myName = getName();
            String theirName = that.getName();
            boolean sameName = myName == null ? theirName == null : myName.equals(theirName);
            if (!sameName) {
                return false;
            }

            FrameIter myFrameIter = getFrameIter();
            FrameIter theirFrameIter = that.getFrameIter();
            if (myFrameIter == null) {
                return theirFrameIter == null;
            }
            return myFrameIter.equals(theirFrameIter);
        }

        /** Hook letting subclasses veto equality with instances of other types. */
        protected boolean canEqual(Object other) {
            return other instanceof ExecStep;
        }

        /** Hash code consistent with {@link #equals(Object)}; {@code null} fields contribute the constant 43. */
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            ExecType t = getType();
            int hash = prime + (t == null ? nullHash : t.hashCode());
            String n = getName();
            hash = hash * prime + (n == null ? nullHash : n.hashCode());
            FrameIter fi = getFrameIter();
            hash = hash * prime + (fi == null ? nullHash : fi.hashCode());
            return hash;
        }
    }

    protected static enum ExecType {
        OP,
        VARIABLE,
        CONSTANT,
        PLACEHOLDER,
        SWITCH_L,
        SWITCH_R,
        EXEC_START,
        CONTROL_DEP;

    }

    /**
     * Identifier made of a variable name, a frame name, an iteration number and a parent {@link FrameIter}.
     * Equality and hash code are computed from all four fields. The fields are mutable through setters, so
     * changing an instance while it is used as a hash-map key will change its hash code.
     */
    public static class VarId {
        private String variable;
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        /**
         * @param variable    variable name
         * @param frame       frame name
         * @param iteration   iteration number
         * @param parentFrame parent frame; may be {@code null}
         */
        public VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
            this.parentFrame = parentFrame;
            this.iteration = iteration;
            this.frame = frame;
            this.variable = variable;
        }

        public String toString() {
            StringBuilder sb = new StringBuilder("VarId(\"");
            sb.append(this.variable).append("\",\"");
            sb.append(this.frame).append("\",");
            sb.append(this.iteration);
            sb.append(",parent=").append(this.parentFrame);
            sb.append(")");
            return sb.toString();
        }

        /** Creates a new {@link FrameIter} from this id's frame, iteration and parent frame. */
        public FrameIter toFrameIter() {
            return new FrameIter(frame, iteration, parentFrame);
        }

        public String getVariable() {
            return variable;
        }

        public String getFrame() {
            return frame;
        }

        public int getIteration() {
            return iteration;
        }

        public FrameIter getParentFrame() {
            return parentFrame;
        }

        public void setVariable(String variable) {
            this.variable = variable;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public void setIteration(int iteration) {
            this.iteration = iteration;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        /** Field-wise equality on iteration, variable, frame and parentFrame; {@code null} equals only {@code null}. */
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof VarId)) {
                return false;
            }
            VarId that = (VarId) o;
            if (!that.canEqual(this)) {
                return false;
            }
            if (getIteration() != that.getIteration()) {
                return false;
            }

            String myVariable = getVariable();
            String theirVariable = that.getVariable();
            boolean sameVariable = myVariable == null ? theirVariable == null : myVariable.equals(theirVariable);
            if (!sameVariable) {
                return false;
            }

            String myFrame = getFrame();
            String theirFrame = that.getFrame();
            boolean sameFrame = myFrame == null ? theirFrame == null : myFrame.equals(theirFrame);
            if (!sameFrame) {
                return false;
            }

            FrameIter myParent = getParentFrame();
            FrameIter theirParent = that.getParentFrame();
            if (myParent == null) {
                return theirParent == null;
            }
            return myParent.equals(theirParent);
        }

        /** Hook letting subclasses veto equality with instances of other types. */
        protected boolean canEqual(Object other) {
            return other instanceof VarId;
        }

        /** Hash code consistent with {@link #equals(Object)}; {@code null} fields contribute the constant 43. */
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int hash = prime + getIteration();
            String v = getVariable();
            hash = hash * prime + (v == null ? nullHash : v.hashCode());
            String f = getFrame();
            hash = hash * prime + (f == null ? nullHash : f.hashCode());
            FrameIter p = getParentFrame();
            hash = hash * prime + (p == null ? nullHash : p.hashCode());
            return hash;
        }
    }
}

