/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.shade.guava.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataSet
implements org.nd4j.linalg.dataset.api.DataSet {
    private static final Logger log = LoggerFactory.getLogger(DataSet.class);
    private static final long serialVersionUID = 1935520764586513365L;
    // Bit flags for the single header byte written by save(OutputStream) and read back by
    // load(InputStream). Each set bit announces that the corresponding section is present.
    private static final byte BITMASK_FEATURES_PRESENT = 1;
    private static final byte BITMASK_LABELS_PRESENT = 2;
    // Set instead of BITMASK_LABELS_PRESENT when labels and features are the same object;
    // the labels are then not written separately and are restored by reusing the features.
    private static final byte BITMASK_LABELS_SAME_AS_FEATURES = 4;
    private static final byte BITMASK_FEATURE_MASK_PRESENT = 8;
    private static final byte BITMASK_LABELS_MASK_PRESENT = 16;
    // Set when a non-empty example meta data list follows the arrays.
    private static final byte BITMASK_METADATA_PRESET = 32;
    private List<String> columnNames = new ArrayList<String>();
    // Names looked up by getLabelName(int); empty until label names are defined.
    private List<String> labelNames = new ArrayList<String>();
    private INDArray features;
    private INDArray labels;
    // Optional mask arrays; null when absent.
    private INDArray featuresMask;
    private INDArray labelsMask;
    // Optional per-example meta data; null when absent.
    private List<Serializable> exampleMetaData;
    // transient: excluded from Java serialization; set by markAsPreProcessed().
    private transient boolean preProcessed = false;

    /**
     * Creates a DataSet with no features and no labels by delegating to
     * {@link #DataSet(INDArray, INDArray)} with two {@code null} arguments.
     */
    public DataSet() {
        this((INDArray) null, (INDArray) null);
    }

    /**
     * @return the example meta data list currently held by this DataSet (the live
     *         reference, not a copy); may be {@code null}
     */
    @Override
    public List<Serializable> getExampleMetaData() {
        return exampleMetaData;
    }

    /**
     * Returns the example meta data viewed as a list of the requested type.
     *
     * <p>No element is checked against {@code metaDataType}: the stored list is simply
     * re-typed, so a wrong type only surfaces as a {@link ClassCastException} when the
     * caller reads an element.
     *
     * @param metaDataType the element type the caller expects (not used for validation)
     * @param <T>          the expected meta data type
     * @return the stored meta data list, possibly {@code null}
     */
    @Override
    @SuppressWarnings("unchecked") // deliberate re-typing of the stored list; see Javadoc
    public <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType) {
        return (List<T>) this.exampleMetaData;
    }

    /**
     * Replaces the example meta data list. The list is stored by reference, not copied.
     *
     * @param exampleMetaData the new meta data list, or {@code null} to clear it
     */
    @Override
    @SuppressWarnings("unchecked") // elements are only ever read back as Serializable
    public void setExampleMetaData(List<? extends Serializable> exampleMetaData) {
        this.exampleMetaData = (List<Serializable>) exampleMetaData;
    }

    /**
     * Creates a DataSet without mask arrays.
     *
     * @param first  the features
     * @param second the labels
     */
    public DataSet(INDArray first, INDArray second) {
        this(first, second, (INDArray) null, (INDArray) null);
    }

    /**
     * Creates a DataSet from the given arrays. The arrays are stored by reference.
     * After the fields are assigned, {@code commit()} is called on the ND4J executioner.
     *
     * @param features     the features, may be {@code null}
     * @param labels       the labels, may be {@code null}
     * @param featuresMask mask for the features, may be {@code null}
     * @param labelsMask   mask for the labels, may be {@code null}
     */
    public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        this.featuresMask = featuresMask;
        this.labelsMask = labelsMask;
        this.features = features;
        this.labels = labels;

        Nd4j.getExecutioner().commit();
    }

    /**
     * @return {@code true} once {@link #markAsPreProcessed()} has been called on this instance
     */
    public boolean isPreProcessed() {
        return preProcessed;
    }

    /**
     * Flags this DataSet as pre-processed; {@link #isPreProcessed()} returns {@code true} afterwards.
     */
    public void markAsPreProcessed() {
        preProcessed = true;
    }

    /**
     * @return a new DataSet whose features and labels are both {@code null}
     */
    public static DataSet empty() {
        INDArray none = null;
        return new DataSet(none, none);
    }

    /**
     * Merges several DataSets into one.
     *
     * <p>DataSets reporting {@code isEmpty()} are skipped when collecting arrays. Among the
     * remaining ones, features must be either present in all or absent in all, and likewise
     * labels. Features/labels and their masks are combined through
     * {@code DataSetUtil.mergeFeatures} and {@code DataSetUtil.mergeLabels}.
     *
     * <p>Example meta data is carried over only if <em>every</em> input (empty ones included)
     * has a meta data list whose size equals its number of examples; otherwise the result
     * has no meta data.
     *
     * @param data the DataSets to merge; must not be an empty list
     * @return the merged DataSet
     * @throws IllegalArgumentException if {@code data} is empty
     * @throws IllegalStateException    if features or labels are present in some inputs but not others
     */
    public static DataSet merge(List<? extends org.nd4j.linalg.dataset.api.DataSet> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Unable to merge empty dataset");
        }

        // Pass 1: count the non-empty inputs and check features/labels presence is consistent.
        int usable = 0;
        boolean featuresSeen = false;
        boolean labelsSeen = false;
        boolean isFirst = true;
        for (org.nd4j.linalg.dataset.api.DataSet ds : data) {
            if (ds.isEmpty()) {
                continue;
            }
            usable++;

            boolean hasFeatures = ds.getFeatures() != null;
            boolean featuresMismatch = featuresSeen ? !hasFeatures : (!isFirst && hasFeatures);
            if (featuresMismatch) {
                throw new IllegalStateException("Cannot merge features: encountered null features in one or more DataSets");
            }

            boolean hasLabels = ds.getLabels() != null;
            boolean labelsMismatch = labelsSeen ? !hasLabels : (!isFirst && hasLabels);
            if (labelsMismatch) {
                throw new IllegalStateException("Cannot merge labels: enountered null labels in one or more DataSets");
            }

            featuresSeen = featuresSeen || hasFeatures;
            labelsSeen = labelsSeen || hasLabels;
            isFirst = false;
        }

        // Pass 2: collect arrays; mask arrays are allocated lazily on the first mask found.
        INDArray[] featureParts = new INDArray[usable];
        INDArray[] labelParts = new INDArray[usable];
        INDArray[] featureMaskParts = null;
        INDArray[] labelMaskParts = null;
        int slot = 0;
        for (org.nd4j.linalg.dataset.api.DataSet ds : data) {
            if (ds.isEmpty()) {
                continue;
            }
            featureParts[slot] = ds.getFeatures();
            labelParts[slot] = ds.getLabels();

            INDArray fMask = ds.getFeaturesMaskArray();
            if (fMask != null) {
                if (featureMaskParts == null) {
                    featureMaskParts = new INDArray[usable];
                }
                featureMaskParts[slot] = fMask;
            }

            INDArray lMask = ds.getLabelsMaskArray();
            if (lMask != null) {
                if (labelMaskParts == null) {
                    labelMaskParts = new INDArray[usable];
                }
                labelMaskParts[slot] = lMask;
            }
            slot++;
        }

        Pair<INDArray, INDArray> mergedFeatures = DataSetUtil.mergeFeatures(featureParts, featureMaskParts);
        Pair<INDArray, INDArray> mergedLabels = DataSetUtil.mergeLabels(labelParts, labelMaskParts);
        DataSet merged = new DataSet(mergedFeatures.getFirst(), mergedLabels.getFirst(),
                mergedFeatures.getSecond(), mergedLabels.getSecond());

        // Pass 3: concatenate meta data, giving up entirely at the first incomplete input.
        List<Serializable> combinedMeta = null;
        for (org.nd4j.linalg.dataset.api.DataSet ds : data) {
            List<Serializable> partMeta = ds.getExampleMetaData();
            if (partMeta == null || partMeta.size() != ds.numExamples()) {
                combinedMeta = null;
                break;
            }
            if (combinedMeta == null) {
                combinedMeta = new ArrayList<>();
            }
            combinedMeta.addAll(partMeta);
        }
        if (combinedMeta != null) {
            merged.setExampleMetaData(combinedMeta);
        }
        return merged;
    }

    /**
     * Returns a new DataSet holding the examples selected by
     * {@code NDArrayIndex.interval(from, to)} along the first dimension
     * (presumably a half-open range; confirm against NDArrayIndex).
     * Mask arrays are sliced the same way when this DataSet reports having them.
     * Example meta data is not carried over.
     *
     * @param from start index of the range
     * @param to   end index of the range
     * @return the sub-range as a new DataSet
     */
    @Override
    public org.nd4j.linalg.dataset.api.DataSet getRange(int from, int to) {
        if (!hasMaskArrays()) {
            return new DataSet(features.get(NDArrayIndex.interval(from, to)),
                    labels.get(NDArrayIndex.interval(from, to)));
        }

        INDArray featureMaskSlice = null;
        if (featuresMask != null) {
            featureMaskSlice = featuresMask.get(NDArrayIndex.interval(from, to));
        }
        INDArray labelMaskSlice = null;
        if (labelsMask != null) {
            labelMaskSlice = labelsMask.get(NDArrayIndex.interval(from, to));
        }
        return new DataSet(features.get(NDArrayIndex.interval(from, to)),
                labels.get(NDArrayIndex.interval(from, to)),
                featureMaskSlice,
                labelMaskSlice);
    }

    /**
     * Replaces the contents of this DataSet with data read from a stream in the format
     * produced by {@link #save(OutputStream)}: one header byte of {@code BITMASK_*} flags,
     * followed by the arrays that are flagged as present (features, labels, features mask,
     * labels mask, in that order) and finally a Java-serialized meta data list.
     *
     * <p>Sections not flagged in the header are reset to {@code null}, including the
     * example meta data, so nothing from a previous load survives. The stream is always
     * closed, on success and on failure.
     *
     * <p>SECURITY NOTE: meta data is read with Java native deserialization
     * ({@link ObjectInputStream}); only load streams from trusted sources.
     *
     * @param from the stream to read from; it is closed by this method
     * @throws RuntimeException wrapping any failure while reading
     */
    @Override
    public void load(InputStream from) {
        // Avoid double-buffering when the caller already supplies a BufferedInputStream.
        InputStream buffered = from instanceof BufferedInputStream ? from : new BufferedInputStream(from);
        try (DataInputStream dis = new DataInputStream(buffered)) {
            byte included = dis.readByte();
            boolean hasFeatures = (included & BITMASK_FEATURES_PRESENT) != 0;
            boolean hasLabels = (included & BITMASK_LABELS_PRESENT) != 0;
            boolean hasLabelsSameAsFeatures = (included & BITMASK_LABELS_SAME_AS_FEATURES) != 0;
            boolean hasFeaturesMask = (included & BITMASK_FEATURE_MASK_PRESENT) != 0;
            boolean hasLabelsMask = (included & BITMASK_LABELS_MASK_PRESENT) != 0;
            boolean hasMetaData = (included & BITMASK_METADATA_PRESET) != 0;

            // Read order must match the write order in save(OutputStream).
            this.features = hasFeatures ? Nd4j.read(dis) : null;
            if (hasLabels) {
                this.labels = Nd4j.read(dis);
            } else {
                // "Same as features" restores the shared reference instead of a second copy.
                this.labels = hasLabelsSameAsFeatures ? this.features : null;
            }
            this.featuresMask = hasFeaturesMask ? Nd4j.read(dis) : null;
            this.labelsMask = hasLabelsMask ? Nd4j.read(dis) : null;

            if (hasMetaData) {
                ObjectInputStream ois = new ObjectInputStream(dis);
                @SuppressWarnings("unchecked") // written by save() as List<Serializable>
                List<Serializable> meta = (List<Serializable>) ois.readObject();
                this.exampleMetaData = meta;
            } else {
                // Do not keep meta data belonging to whatever this instance held before.
                this.exampleMetaData = null;
            }
        } catch (Exception e) {
            throw new RuntimeException("Error loading DataSet", e);
        }
    }

    /**
     * Loads this DataSet from a file by opening it with a 1 MiB read buffer and
     * delegating to {@link #load(InputStream)}.
     *
     * @param from the file to read
     * @throws RuntimeException wrapping any {@link IOException} from opening or closing the file
     */
    @Override
    public void load(File from) {
        final int bufferSize = 0x100000;
        try (FileInputStream fileIn = new FileInputStream(from);
             BufferedInputStream bufferedIn = new BufferedInputStream(fileIn, bufferSize)) {
            load(bufferedIn);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Writes this DataSet to a stream in the format understood by {@link #load(InputStream)}.
     *
     * <p>Layout: one header byte of {@code BITMASK_*} flags, then each present array in the
     * order features, labels, features mask, labels mask, then the example meta data list
     * (Java-serialized) when it is non-empty. If labels and features are the same object,
     * only the features are written and the "same as features" flag is set.
     *
     * <p>The stream is flushed and closed by this method, on success and on failure.
     * Failures are reported to the caller rather than only logged, so a truncated file is
     * never mistaken for a successful save.
     *
     * @param to the stream to write to; it is closed by this method
     * @throws RuntimeException wrapping any failure while writing
     */
    @Override
    public void save(OutputStream to) {
        boolean hasMetaData = this.exampleMetaData != null && !this.exampleMetaData.isEmpty();
        boolean labelsAreFeatures = this.labels != null && this.labels == this.features;

        byte included = 0;
        if (this.features != null) {
            included |= BITMASK_FEATURES_PRESENT;
        }
        if (this.labels != null) {
            included |= labelsAreFeatures ? BITMASK_LABELS_SAME_AS_FEATURES : BITMASK_LABELS_PRESENT;
        }
        if (this.featuresMask != null) {
            included |= BITMASK_FEATURE_MASK_PRESENT;
        }
        if (this.labelsMask != null) {
            included |= BITMASK_LABELS_MASK_PRESENT;
        }
        if (hasMetaData) {
            included |= BITMASK_METADATA_PRESET;
        }

        try (BufferedOutputStream bos = new BufferedOutputStream(to);
             DataOutputStream dos = new DataOutputStream(bos)) {
            dos.writeByte(included);
            if (this.features != null) {
                Nd4j.write(this.features, dos);
            }
            if (this.labels != null && !labelsAreFeatures) {
                Nd4j.write(this.labels, dos);
            }
            if (this.featuresMask != null) {
                Nd4j.write(this.featuresMask, dos);
            }
            if (this.labelsMask != null) {
                Nd4j.write(this.labelsMask, dos);
            }
            if (hasMetaData) {
                // Not closed separately: closing it would close the shared underlying stream early.
                ObjectOutputStream oos = new ObjectOutputStream(bos);
                oos.writeObject(this.exampleMetaData);
                oos.flush();
            }
            dos.flush();
        } catch (Exception e) {
            throw new RuntimeException("Error saving DataSet", e);
        }
    }

    /**
     * Saves this DataSet to a file, overwriting any existing content, by delegating to
     * {@link #save(OutputStream)} through a buffered file stream.
     *
     * @param to the destination file
     * @throws RuntimeException wrapping any {@link IOException} from opening or closing the file
     */
    @Override
    public void save(File to) {
        final boolean append = false;
        try (FileOutputStream fileOut = new FileOutputStream(to, append);
             BufferedOutputStream bufferedOut = new BufferedOutputStream(fileOut)) {
            save(bufferedOut);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Not implemented in this class.
     *
     * @return always {@code null}
     */
    @Override
    public DataSetIterator iterateWithMiniBatches() {
        DataSetIterator none = null;
        return none;
    }

    /**
     * @return always the empty string; this implementation carries no identifier
     */
    @Override
    public String id() {
        final String noId = "";
        return noId;
    }

    /**
     * @return the features array held by this DataSet (the stored reference), possibly {@code null}
     */
    @Override
    public INDArray getFeatures() {
        return features;
    }

    /**
     * Replaces the features array. The array is stored by reference.
     *
     * @param features the new features, may be {@code null}
     */
    @Override
    public void setFeatures(INDArray features) {
        INDArray replacement = features;
        this.features = replacement;
    }

    /**
     * Counts how many label tensors (taken along dimension 1) have their {@code iamax}
     * at each index.
     *
     * @return a map from label index to number of occurrences; empty when there are no labels
     * @throws IllegalStateException if the BLAS wrapper reports a negative {@code iamax}
     */
    @Override
    public Map<Integer, Double> labelCounts() {
        Map<Integer, Double> counts = new HashMap<>();
        if (this.labels == null) {
            return counts;
        }
        long nTensors = this.labels.tensorsAlongDimension(1);
        for (int i = 0; i < nTensors; i++) {
            // One iamax per tensor is enough; the result was previously computed twice and one discarded.
            int maxIdx = Nd4j.getBlasWrapper().iamax(this.labels.tensorAlongDimension(i, 1));
            if (maxIdx < 0) {
                throw new IllegalStateException("Please check the iamax implementation for " + Nd4j.getBlasWrapper().getClass().getName());
            }
            Double seen = counts.get(maxIdx);
            counts.put(maxIdx, seen == null ? 1.0 : seen + 1.0);
        }
        return counts;
    }

    /**
     * Creates a copy of this DataSet in which features, labels and both mask arrays are
     * duplicated with {@code dup()}. Arrays that are {@code null} here stay {@code null}
     * in the copy. Column names and label names are passed to the copy as the same list
     * references. Example meta data is not copied.
     *
     * @return the copy
     */
    @Override
    public DataSet copy() {
        INDArray sourceFeatures = this.getFeatures();
        INDArray sourceLabels = this.getLabels();
        // Null-safe: a DataSet holding only features (or only labels) can still be copied.
        DataSet ret = new DataSet(
                sourceFeatures == null ? null : sourceFeatures.dup(),
                sourceLabels == null ? null : sourceLabels.dup());
        if (this.getLabelsMaskArray() != null) {
            ret.setLabelsMaskArray(this.getLabelsMaskArray().dup());
        }
        if (this.getFeaturesMaskArray() != null) {
            ret.setFeaturesMaskArray(this.getFeaturesMaskArray().dup());
        }
        ret.setColumnNames(this.getColumnNames());
        ret.setLabelNames(this.getLabelNames());
        return ret;
    }

    /**
     * Returns a new DataSet whose features are this DataSet's features reshaped to
     * {@code [rows, cols]}; the labels reference is reused unchanged.
     *
     * @param rows number of rows of the reshaped features
     * @param cols number of columns of the reshaped features
     * @return the new DataSet
     */
    @Override
    public DataSet reshape(int rows, int cols) {
        long[] targetShape = {rows, cols};
        INDArray reshaped = getFeatures().reshape(targetShape);
        return new DataSet(reshaped, getLabels());
    }

    /**
     * Multiplies the features in place by a scalar.
     *
     * @param num the factor
     */
    @Override
    public void multiplyBy(double num) {
        INDArray target = getFeatures();
        target.muli(Nd4j.scalar(num));
    }

    /**
     * Divides the features in place by a scalar.
     *
     * @param num the divisor
     */
    @Override
    public void divideBy(int num) {
        INDArray target = getFeatures();
        target.divi(Nd4j.scalar(num));
    }

    /**
     * Shuffles the examples using the current wall-clock time in milliseconds as the seed.
     */
    @Override
    public void shuffle() {
        shuffle(System.currentTimeMillis());
    }

    /**
     * Shuffles features, labels and any mask arrays together with {@code Nd4j.shuffle},
     * using a {@link Random} created from {@code seed}. Does nothing when there are fewer
     * than two examples.
     *
     * <p>Example meta data, when present, is reordered with a map built from a second
     * {@code Random} created from the same seed.
     * NOTE(review): this presumably reproduces the permutation applied to the arrays;
     * confirm against {@code Nd4j.shuffle} and {@code ArrayUtil.buildInterleavedVector}.
     *
     * @param seed the random seed
     */
    public void shuffle(long seed) {
        if (numExamples() < 2) {
            return;
        }

        List<INDArray> toShuffle = new ArrayList<>();
        List<int[]> shuffleDims = new ArrayList<>();

        // For every array, dimensions 1..rank-1 are passed alongside it.
        toShuffle.add(getFeatures());
        shuffleDims.add(ArrayUtil.range(1, getFeatures().rank()));

        toShuffle.add(getLabels());
        shuffleDims.add(ArrayUtil.range(1, getLabels().rank()));

        if (featuresMask != null) {
            toShuffle.add(getFeaturesMaskArray());
            shuffleDims.add(ArrayUtil.range(1, getFeaturesMaskArray().rank()));
        }
        if (labelsMask != null) {
            toShuffle.add(getLabelsMaskArray());
            shuffleDims.add(ArrayUtil.range(1, getLabelsMaskArray().rank()));
        }

        Nd4j.shuffle(toShuffle, new Random(seed), shuffleDims);

        if (exampleMetaData != null) {
            int[] permutation = ArrayUtil.buildInterleavedVector(new Random(seed), numExamples());
            ArrayUtil.shuffleWithMap(exampleMetaData, permutation);
        }
    }

    /**
     * Clamps every feature value into {@code [min, max]} in place: values below
     * {@code min} become {@code min}, values above {@code max} become {@code max}.
     *
     * <p>Values are read with {@code getDouble}, the same accessor {@code binarize(double)}
     * uses, instead of casting {@code getScalar(i).element()} to {@code Double}; that cast
     * depends on the boxed type {@code element()} returns and fails with a
     * {@link ClassCastException} for arrays whose elements are not boxed as {@code Double}.
     *
     * @param min lower bound
     * @param max upper bound
     */
    @Override
    public void squishToRange(double min, double max) {
        INDArray values = this.getFeatures();
        long count = values.length();
        for (int i = 0; i < count; i++) {
            double curr = values.getDouble((long) i);
            if (curr < min) {
                values.put(i, Nd4j.scalar(min));
            } else if (curr > max) {
                values.put(i, Nd4j.scalar(max));
            }
        }
    }

    /**
     * Rescales the features by delegating to {@code FeatureUtil.scaleMinMax}.
     *
     * @param min lower bound passed to the scaler
     * @param max upper bound passed to the scaler
     */
    @Override
    public void scaleMinAndMax(double min, double max) {
        INDArray target = getFeatures();
        FeatureUtil.scaleMinMax(min, max, target);
    }

    /**
     * Scales the features by delegating to {@code FeatureUtil.scaleByMax}.
     */
    @Override
    public void scale() {
        INDArray target = getFeatures();
        FeatureUtil.scaleByMax(target);
    }

    /**
     * Replaces the features with the horizontal stack ({@code Nd4j.hstack}) of the
     * current features and {@code toAdd}.
     *
     * @param toAdd the array to append
     */
    @Override
    public void addFeatureVector(INDArray toAdd) {
        INDArray widened = Nd4j.hstack(getFeatures(), toAdd);
        setFeatures(widened);
    }

    /**
     * Overwrites one row of the features with the given vector.
     *
     * @param feature the row content
     * @param example the row index to overwrite
     */
    @Override
    public void addFeatureVector(INDArray feature, int example) {
        INDArray target = getFeatures();
        target.putRow(example, feature);
    }

    /**
     * Fits a fresh {@code NormalizerStandardize} on this DataSet and then applies it
     * to this DataSet.
     */
    @Override
    public void normalize() {
        NormalizerStandardize standardizer = new NormalizerStandardize();
        standardizer.fit(this);
        standardizer.transform(this);
    }

    /**
     * Binarizes the features with a cutoff of {@code 0.0}; see {@link #binarize(double)}.
     */
    @Override
    public void binarize() {
        final double defaultCutoff = 0.0;
        binarize(defaultCutoff);
    }

    /**
     * Sets every feature value to 1 if it is strictly greater than {@code cutoff},
     * otherwise to 0. Values are read through a flattened ({@code reshape(-1)}) version
     * of the features and written back by linear index.
     *
     * @param cutoff the threshold
     */
    @Override
    public void binarize(double cutoff) {
        INDArray flat = getFeatures().reshape(-1L);
        for (int idx = 0; (long) idx < getFeatures().length(); idx++) {
            int bit = flat.getDouble((long) idx) > cutoff ? 1 : 0;
            getFeatures().putScalar((long) idx, bit);
        }
    }

    /**
     * Subtracts the mean along dimension 0 from the features and divides by the standard
     * deviation along dimension 0 (plus {@code Nd4j.EPS_THRESHOLD}). Both statistics are
     * computed before the features are modified.
     *
     * @deprecated retained as in the original API; {@link #normalize()} is the other
     *             normalization entry point in this class
     */
    @Override
    @Deprecated
    public void normalizeZeroMeanZeroUnitVariance() {
        INDArray original = getFeatures();
        INDArray means = original.mean(0);
        INDArray stds = original.std(0);

        setFeatures(original.subiRowVector(means));

        // Epsilon is added to the divisor before dividing.
        stds.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
        setFeatures(getFeatures().diviRowVector(stds));
    }

    /**
     * @return the size of dimension 1 of the features, narrowed to {@code int}
     */
    @Override
    public int numInputs() {
        long width = getFeatures().size(1);
        return (int) width;
    }

    /**
     * Checks that features and labels have the same size along dimension 0.
     *
     * @throws IllegalStateException if the sizes differ
     */
    @Override
    public void validate() {
        long featureRows = getFeatures().size(0);
        long labelRows = getLabels().size(0);
        if (featureRows == labelRows) {
            return;
        }
        throw new IllegalStateException("Invalid dataset");
    }

    /**
     * @return the {@code iamax} of the labels array as reported by the ND4J BLAS wrapper
     */
    @Override
    public int outcome() {
        INDArray current = getLabels();
        return Nd4j.getBlasWrapper().iamax(current);
    }

    /**
     * Replaces the labels with a newly created {@code [numExamples, labels]} array.
     *
     * @param labels the number of label columns
     */
    @Override
    public void setNewNumberOfLabels(int labels) {
        int rows = numExamples();
        setLabels(Nd4j.create(rows, labels));
    }

    /**
     * Overwrites the label row of one example with the outcome vector produced by
     * {@code FeatureUtil.toOutcomeVector(label, numOutcomes())}.
     *
     * @param example index of the example; must be in {@code [0, numExamples())}
     * @param label   index of the label; must be in {@code [0, numOutcomes())}
     * @throws IllegalArgumentException if either index is out of range
     */
    @Override
    public void setOutcome(int example, int label) {
        // Valid indices are 0-based, so the upper bounds are exclusive.
        if (example < 0 || example >= this.numExamples()) {
            throw new IllegalArgumentException("No example at " + example);
        }
        int outcomes = this.numOutcomes();
        if (label < 0 || label >= outcomes) {
            throw new IllegalArgumentException("Illegal label");
        }
        INDArray outcome = FeatureUtil.toOutcomeVector(label, outcomes);
        this.getLabels().putRow(example, outcome);
    }

    /**
     * Returns the example at index {@code i} as its own DataSet. When this DataSet holds
     * exactly one example, this instance itself is returned. Example meta data is not
     * carried over to a newly created DataSet.
     *
     * @param i the example index
     * @return the single-example DataSet
     * @throws IllegalArgumentException if {@code i} is negative or not less than the number of examples
     */
    @Override
    public DataSet get(int i) {
        int total = numExamples();
        if (i < 0 || i >= total) {
            throw new IllegalArgumentException("invalid example number: must be 0 to " + (total - 1) + ", got " + i);
        }
        if (total == 1) {
            // i is necessarily 0 here
            return this;
        }
        INDArray featureRow = getHelper(features, i);
        INDArray labelRow = getHelper(labels, i);
        INDArray featureMaskRow = getHelper(featuresMask, i);
        INDArray labelMaskRow = getHelper(labelsMask, i);
        return new DataSet(featureRow, labelRow, featureMaskRow, labelMaskRow);
    }

    /**
     * Fetches each requested example with {@link #get(int)} and merges them, in the given
     * order, into one DataSet.
     *
     * @param i the example indices
     * @return the merged DataSet
     */
    @Override
    public DataSet get(int[] i) {
        List<DataSet> selected = new ArrayList<>();
        for (int pos = 0; pos < i.length; pos++) {
            selected.add(get(i[pos]));
        }
        return DataSet.merge(selected);
    }

    /**
     * Splits the examples into consecutive groups of at most {@code num} and merges each
     * group back into a DataSet.
     *
     * @param num the partition size
     * @return the batches, in order
     */
    @Override
    public List<DataSet> batchBy(int num) {
        List<DataSet> batches = Lists.newArrayList();
        List<List<DataSet>> partitions = Lists.partition(asList(), num);
        for (List<DataSet> partition : partitions) {
            batches.add(DataSet.merge(partition));
        }
        return batches;
    }

    /**
     * Keeps only the examples whose {@link #outcome()} is one of the given labels and
     * merges them into a new DataSet. This DataSet is not modified.
     *
     * @param labels the outcome indices to keep
     * @return the merged DataSet of matching examples
     * @throws IllegalArgumentException (from {@code merge}) if no example matches
     */
    @Override
    public DataSet filterBy(int[] labels) {
        // Hash lookup instead of a linear List.contains per example.
        HashSet<Integer> wanted = new HashSet<>();
        for (int label : labels) {
            wanted.add(label);
        }
        List<DataSet> kept = new ArrayList<>();
        for (DataSet example : this.asList()) {
            if (wanted.contains(example.outcome())) {
                kept.add(example);
            }
        }
        return DataSet.merge(kept);
    }

    /**
     * Filters this DataSet down to the examples whose outcome is in {@code labels}
     * (see {@link #filterBy(int[])}) and re-encodes the labels so that {@code labels[k]}
     * becomes outcome index {@code k} in a new {@code [examples, labels.length]} matrix.
     * Only features and labels of this DataSet are replaced.
     *
     * @param labels the outcome indices to keep, in their new order
     * @throws IllegalStateException if the number of remapped labels does not match the new
     *                               label matrix, or an example's outcome has no mapping
     */
    @Override
    public void filterAndStrip(int[] labels) {
        DataSet filtered = filterBy(labels);

        // old outcome index -> new (compact) outcome index
        Map<Integer, Integer> remap = new HashMap<>();
        for (int pos = 0; pos < labels.length; pos++) {
            remap.put(labels[pos], pos);
        }

        List<Integer> remapped = new ArrayList<>();
        for (int ex = 0; ex < filtered.numExamples(); ex++) {
            int oldOutcome = filtered.get(ex).outcome();
            remapped.add(remap.get(oldOutcome));
        }

        INDArray newLabelMatrix = Nd4j.create(filtered.numExamples(), labels.length);
        if (newLabelMatrix.rows() != remapped.size()) {
            throw new IllegalStateException("Inconsistent label sizes");
        }
        for (int row = 0; row < newLabelMatrix.rows(); row++) {
            Integer target = remapped.get(row);
            if (target == null) {
                throw new IllegalStateException("Label not found on row " + row);
            }
            newLabelMatrix.putRow(row, FeatureUtil.toOutcomeVector(target.intValue(), labels.length));
        }

        setFeatures(filtered.getFeatures());
        setLabels(newLabelMatrix);
    }

    /**
     * Partitions the examples into consecutive groups of at most {@code num} and merges
     * each group into a DataSet.
     *
     * @param num the partition size
     * @return the batches, in order
     */
    @Override
    public List<DataSet> dataSetBatches(int num) {
        List<List<DataSet>> groups = Lists.partition(asList(), num);
        List<DataSet> batches = new ArrayList<>();
        for (int g = 0; g < groups.size(); g++) {
            batches.add(DataSet.merge(groups.get(g)));
        }
        return batches;
    }

    /**
     * Calls {@link #sortByLabel()} and then returns {@link #batchByNumLabels()}.
     *
     * @return the batches produced after sorting
     */
    @Override
    public List<DataSet> sortAndBatchByNumLabels() {
        sortByLabel();
        List<DataSet> batches = batchByNumLabels();
        return batches;
    }

    /**
     * @return the result of {@link #batchBy(int)} with the batch size set to {@link #numOutcomes()}
     */
    @Override
    public List<DataSet> batchByNumLabels() {
        int batchSize = numOutcomes();
        return batchBy(batchSize);
    }

    /**
     * Splits this DataSet into one DataSet per example. Each element holds the example's
     * slice of the features, labels and mask arrays (arrays that are {@code null} here are
     * {@code null} in every element) and, when available, a singleton list with the
     * example's meta data.
     *
     * <p>No rank is read from features or labels up front, so a DataSet with only features
     * or only labels can be split as well; slicing is delegated to the null-safe helper.
     *
     * @return the per-example DataSets, in order
     * @throws IllegalStateException if a non-null array has a rank outside 2..5
     */
    @Override
    public List<DataSet> asList() {
        int total = this.numExamples();
        List<DataSet> list = new ArrayList<>(total);
        for (int i = 0; i < total; ++i) {
            INDArray featuresHere = this.getHelper(this.getFeatures(), i);
            INDArray featureMaskHere = this.getHelper(this.featuresMask, i);
            INDArray labelsHere = this.getHelper(this.labels, i);
            INDArray labelMaskHere = this.getHelper(this.labelsMask, i);
            DataSet ds = new DataSet(featuresHere, labelsHere, featureMaskHere, labelMaskHere);
            // Meta data may be shorter than the example count; attach only what exists.
            if (this.exampleMetaData != null && this.exampleMetaData.size() > i) {
                ds.setExampleMetaData(Collections.singletonList(this.exampleMetaData.get(i)));
            }
            list.add(ds);
        }
        return list;
    }

    /**
     * Extracts example {@code i} from an array of rank 2 to 5 by indexing the first
     * dimension with {@code interval(i, i, true)} and taking everything along the
     * remaining dimensions.
     *
     * @param from the array to slice, may be {@code null}
     * @param i    the example index
     * @return the slice, or {@code null} if {@code from} is {@code null}
     * @throws IllegalStateException if the rank is outside 2..5
     */
    private INDArray getHelper(INDArray from, int i) {
        if (from == null) {
            return null;
        }
        int rank = from.rank();
        if (rank < 2 || rank > 5) {
            throw new IllegalStateException("Cannot convert to list: feature set rank must be in range 2 to 5 inclusive. Got shape: " + Arrays.toString(from.shape()));
        }
        long example = i;
        if (rank == 2) {
            return from.get(NDArrayIndex.interval(example, example, true), NDArrayIndex.all());
        }
        if (rank == 3) {
            return from.get(NDArrayIndex.interval(example, example, true), NDArrayIndex.all(), NDArrayIndex.all());
        }
        if (rank == 4) {
            return from.get(NDArrayIndex.interval(example, example, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
        }
        // rank == 5
        return from.get(NDArrayIndex.interval(example, example, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    }

    /**
     * Shuffles this DataSet in place with a seed drawn from {@code rng} and then splits
     * it with {@link #splitTestAndTrain(int)}.
     *
     * @param numHoldout number of examples in the first part of the split
     * @param rng        source of the shuffle seed
     * @return the split
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout, Random rng) {
        shuffle(rng.nextLong());
        return splitTestAndTrain(numHoldout);
    }

    /**
     * Splits this DataSet into two parts without shuffling: the first holds examples
     * {@code [0, numHoldout)} and the second the remainder, as selected by
     * {@code NDArrayIndex.interval}. Features and labels of rank 2, 3 or 4 are supported;
     * mask arrays are sliced as rank-2 arrays. Example meta data, if present, is divided
     * at the same position.
     *
     * @param numHoldout number of examples in the first part; must be less than the number of examples
     * @return the two parts wrapped in a {@code SplitTestAndTrain}
     * @throws IllegalStateException         if this DataSet has one example or fewer
     * @throws IllegalArgumentException      if {@code numHoldout >= numExamples()}
     * @throws UnsupportedOperationException if features or labels have an unsupported rank
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout) {
        int numExamples = this.numExamples();
        if (numExamples <= 1) {
            throw new IllegalStateException("Cannot split DataSet with <= 1 rows (data set has " + numExamples + " example)");
        }
        if (numHoldout >= numExamples) {
            throw new IllegalArgumentException("Unable to split on size equal or larger than the number of rows (# numExamples=" + numExamples + ", numHoldout=" + numHoldout + ")");
        }
        DataSet first = new DataSet();
        DataSet second = new DataSet();
        switch (this.features.rank()) {
            case 2: {
                first.setFeatures(this.features.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all()));
                second.setFeatures(this.features.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all()));
                break;
            }
            case 3: {
                first.setFeatures(this.features.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all(), NDArrayIndex.all()));
                second.setFeatures(this.features.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all(), NDArrayIndex.all()));
                break;
            }
            case 4: {
                first.setFeatures(this.features.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
                second.setFeatures(this.features.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
                break;
            }
            default: {
                throw new UnsupportedOperationException("Features rank: " + this.features.rank());
            }
        }
        switch (this.labels.rank()) {
            case 2: {
                first.setLabels(this.labels.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all()));
                second.setLabels(this.labels.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all()));
                break;
            }
            case 3: {
                first.setLabels(this.labels.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all(), NDArrayIndex.all()));
                second.setLabels(this.labels.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all(), NDArrayIndex.all()));
                break;
            }
            case 4: {
                first.setLabels(this.labels.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
                second.setLabels(this.labels.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
                break;
            }
            default: {
                // Report the rank of the array that was actually rejected (the labels).
                throw new UnsupportedOperationException("Labels rank: " + this.labels.rank());
            }
        }
        if (this.featuresMask != null) {
            first.setFeaturesMaskArray(this.featuresMask.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all()));
            second.setFeaturesMaskArray(this.featuresMask.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all()));
        }
        if (this.labelsMask != null) {
            first.setLabelsMaskArray(this.labelsMask.get(NDArrayIndex.interval(0, numHoldout), NDArrayIndex.all()));
            second.setLabelsMaskArray(this.labelsMask.get(NDArrayIndex.interval(numHoldout, numExamples), NDArrayIndex.all()));
        }
        if (this.exampleMetaData != null) {
            // The meta data list may be shorter than the example count; copy only what exists.
            int available = this.exampleMetaData.size();
            List<Serializable> meta1 = new ArrayList<>();
            List<Serializable> meta2 = new ArrayList<>();
            for (int i = 0; i < numHoldout && i < available; ++i) {
                meta1.add(this.exampleMetaData.get(i));
            }
            for (int i = numHoldout; i < numExamples && i < available; ++i) {
                meta2.add(this.exampleMetaData.get(i));
            }
            first.setExampleMetaData(meta1);
            second.setExampleMetaData(meta2);
        }
        return new SplitTestAndTrain(first, second);
    }

    /**
     * @return the labels array held by this DataSet (the stored reference), possibly {@code null}
     */
    @Override
    public INDArray getLabels() {
        return labels;
    }

    /**
     * Looks up the name of a label by index.
     *
     * @param idx the label index
     * @return the label name stored at {@code idx}
     * @throws IllegalStateException if no label names are defined, or {@code idx} is not
     *                               less than the number of label names
     */
    @Override
    public String getLabelName(int idx) {
        if (labelNames.isEmpty()) {
            throw new IllegalStateException("Label names are not defined on this dataset. Add label names in order to use getLabelName with an id.");
        }
        if (idx >= labelNames.size()) {
            throw new IllegalStateException("Index requested is longer than the number of labels used for classification.");
        }
        return labelNames.get(idx);
    }

    /**
     * Resolves each label index contained in {@code idxs} to its label name.
     *
     * <p>The values stored in {@code idxs} are the indices that are looked up (read by
     * linear position and truncated to {@code int}), not the positions themselves.
     *
     * @param idxs array of label indices
     * @return the label names, in the order the indices appear in {@code idxs}
     * @throws IllegalStateException if label names are not defined or an index is too large
     */
    @Override
    public List<String> getLabelNames(INDArray idxs) {
        List<String> ret = new ArrayList<>();
        long count = idxs.length();
        for (int i = 0; i < count; i++) {
            int labelIdx = (int) idxs.getDouble((long) i);
            ret.add(this.getLabelName(labelIdx));
        }
        return ret;
    }

    /**
     * Replaces the labels array. The array is stored by reference.
     *
     * @param labels the new labels, may be {@code null}
     */
    @Override
    public void setLabels(INDArray labels) {
        INDArray replacement = labels;
        this.labels = replacement;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void sortByLabel() {
        void var6_11;
        Queue q;
        HashMap<Integer, ArrayDeque<DataSet>> map = new HashMap<Integer, ArrayDeque<DataSet>>();
        List<DataSet> data = this.asList();
        int numLabels = this.numOutcomes();
        int examples = this.numExamples();
        for (DataSet dataSet : data) {
            int label = dataSet.outcome();
            q = (ArrayDeque<DataSet>)map.get(label);
            if (q == null) {
                q = new ArrayDeque<DataSet>();
                map.put(label, (ArrayDeque<DataSet>)q);
            }
            q.add(dataSet);
        }
        for (Map.Entry entry : map.entrySet()) {
            log.info("Label " + entry + " has " + ((Queue)entry.getValue()).size() + " elements");
        }
        boolean optimal = true;
        boolean bl = false;
        while (var6_11 < examples) {
            if (optimal) {
                for (int j = 0; j < numLabels; ++j) {
                    q = (Queue)map.get(j);
                    if (q == null) {
                        optimal = false;
                    } else {
                        DataSet next = (DataSet)q.poll();
                        if (next != null) {
                            this.addRow(next, (int)var6_11);
                            ++var6_11;
                            continue;
                        }
                        optimal = false;
                    }
                    break;
                }
            } else {
                DataSet add = null;
                for (Queue q2 : map.values()) {
                    if (q2.isEmpty()) continue;
                    add = (DataSet)q2.poll();
                    break;
                }
                this.addRow(add, (int)var6_11);
            }
            ++var6_11;
        }
    }

    /**
     * Overwrites row {@code i} of the features and labels with the features and labels of
     * {@code d}. Mask arrays and example meta data are not touched.
     *
     * @param d the single-example DataSet providing the row content; must not be {@code null}
     * @param i the row to overwrite; must be in {@code [0, numExamples())}
     * @throws IllegalArgumentException if {@code d} is {@code null} or {@code i} is out of range
     */
    @Override
    public void addRow(DataSet d, int i) {
        // Rows are 0-based, so numExamples() itself is already out of range.
        if (d == null || i < 0 || i >= this.numExamples()) {
            throw new IllegalArgumentException("Invalid index for adding a row");
        }
        this.getFeatures().putRow(i, d.getFeatures());
        this.getLabels().putRow(i, d.getLabels());
    }

    /**
     * @param data the DataSet to inspect
     * @return the maximum value found in {@code data}'s labels, converted to {@code float}
     *         and then truncated to {@code int}
     */
    private int getLabel(DataSet data) {
        float max = data.getLabels().maxNumber().floatValue();
        return (int) max;
    }

    /**
     * @return the sum of the features along dimension 1
     */
    @Override
    public INDArray exampleSums() {
        INDArray source = getFeatures();
        return source.sum(1);
    }

    /**
     * @return the maximum of the features along dimension 1
     */
    @Override
    public INDArray exampleMaxs() {
        INDArray source = getFeatures();
        return source.max(1);
    }

    /**
     * @return the mean of the features along dimension 1
     */
    @Override
    public INDArray exampleMeans() {
        INDArray source = getFeatures();
        return source.mean(1);
    }

    /**
     * Samples without replacement using ND4J's default random number generator.
     *
     * @param numSamples the number of examples to draw
     * @return the sampled DataSet
     */
    @Override
    public DataSet sample(int numSamples) {
        org.nd4j.linalg.api.rng.Random defaultRng = Nd4j.getRandom();
        return sample(numSamples, defaultRng);
    }

    /**
     * Samples without replacement using the given random number generator.
     *
     * @param numSamples the number of examples to draw
     * @param rng        the random number generator
     * @return the sampled DataSet
     */
    @Override
    public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng) {
        final boolean withReplacement = false;
        return sample(numSamples, rng, withReplacement);
    }

    /**
     * Samples using ND4J's default random number generator.
     *
     * @param numSamples      the number of examples to draw
     * @param withReplacement whether the same example may be drawn more than once
     * @return the sampled DataSet
     */
    @Override
    public DataSet sample(int numSamples, boolean withReplacement) {
        org.nd4j.linalg.api.rng.Random defaultRng = Nd4j.getRandom();
        return sample(numSamples, defaultRng, withReplacement);
    }

    /**
     * Draws random examples from this DataSet and merges them into a new DataSet.
     *
     * <p>Without replacement, each example is drawn at most once, so at most
     * {@link #numExamples()} examples are returned even if {@code numSamples} is larger;
     * sampling stops as soon as every example has been drawn, without adding a duplicate.
     *
     * @param numSamples      the number of examples to draw
     * @param rng             the random number generator
     * @param withReplacement whether the same example may be drawn more than once
     * @return the merged sample
     * @throws IllegalArgumentException (from {@code merge}) if nothing was sampled
     */
    @Override
    public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
        int total = this.numExamples();
        HashSet<Integer> added = new HashSet<>();
        List<DataSet> toMerge = new ArrayList<>();
        for (int i = 0; i < numSamples; ++i) {
            // Check exhaustion before drawing, so the loop never ends by re-adding an example.
            if (!withReplacement && added.size() >= total) {
                break;
            }
            int picked = rng.nextInt(total);
            if (!withReplacement) {
                while (added.contains(picked)) {
                    picked = rng.nextInt(total);
                }
            }
            added.add(picked);
            toMerge.add(this.get(picked));
        }
        return DataSet.merge(toMerge);
    }

    /**
     * Rounds every feature value in place using
     * {@code MathUtils.roundDouble(value, roundTo)}.
     *
     * @param roundTo the rounding argument forwarded to {@code MathUtils.roundDouble}
     */
    @Override
    public void roundToTheNearest(int roundTo) {
        INDArray feats = this.getFeatures();
        long length = feats.length();
        // A long index matches the long length, so it cannot wrap on very large arrays.
        for (long i = 0; i < length; ++i) {
            // Read via getDouble rather than casting element() to Double: the boxed
            // element type follows the array's data type, so that cast can fail
            // with ClassCastException on non-double (e.g. float) features.
            double curr = feats.getDouble(i);
            feats.putScalar(i, MathUtils.roundDouble(curr, roundTo));
        }
    }

    /**
     * Returns the size of the labels array along dimension 1.
     *
     * @return {@code getLabels().size(1)} narrowed to an {@code int}
     */
    @Override
    public int numOutcomes() {
        long outcomes = getLabels().size(1);
        return (int) outcomes;
    }

    /**
     * Returns the size along dimension 0 of the features, falling back to the
     * labels when there are no features.
     *
     * @return the example count, or {@code 0} if both features and labels are {@code null}
     */
    @Override
    public int numExamples() {
        INDArray feats = getFeatures();
        if (feats != null) {
            return (int) feats.size(0);
        }
        INDArray lbls = getLabels();
        return lbls != null ? (int) lbls.size(0) : 0;
    }

    /**
     * Renders features, labels and any mask arrays as text, replacing every
     * {@code ';'} in the array dumps with a newline.
     *
     * @return the formatted text, or an empty string (after logging at info
     *         level) when the features or labels are {@code null}
     */
    @Override
    public String toString() {
        if (features == null || labels == null) {
            log.info("Features or labels are null values");
            return "";
        }
        StringBuilder out = new StringBuilder();
        out.append("===========INPUT===================\n");
        out.append(getFeatures().toString().replace(";", "\n"));
        out.append("\n=================OUTPUT==================\n");
        out.append(getLabels().toString().replace(";", "\n"));
        if (featuresMask != null) {
            out.append("\n===========INPUT MASK===================\n");
            out.append(getFeaturesMaskArray().toString().replace(";", "\n"));
        }
        if (labelsMask != null) {
            out.append("\n===========OUTPUT MASK===================\n");
            out.append(getLabelsMaskArray().toString().replace(";", "\n"));
        }
        return out.toString();
    }

    /**
     * Returns the stored label names.
     *
     * @return the {@code labelNames} field as-is (no copy is made)
     */
    @Override
    @Deprecated
    public List<String> getLabelNames() {
        return labelNames;
    }

    /**
     * Returns the stored label names.
     *
     * @return the {@code labelNames} field as-is (no copy is made)
     */
    @Override
    public List<String> getLabelNamesList() {
        return labelNames;
    }

    /**
     * Replaces the stored label names.
     *
     * @param labelNames the new list; the reference is stored directly, not copied
     */
    @Override
    public void setLabelNames(List<String> labelNames) {
        this.labelNames = labelNames;
    }

    /**
     * Returns the stored column names.
     *
     * @return the {@code columnNames} field as-is (no copy is made)
     */
    @Override
    @Deprecated
    public List<String> getColumnNames() {
        return columnNames;
    }

    /**
     * Replaces the stored column names.
     *
     * @param columnNames the new list; the reference is stored directly, not copied
     */
    @Override
    @Deprecated
    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    /**
     * Splits this data set by a training fraction.
     *
     * <p>The training count is {@code fractionTrain * numExamples()} truncated
     * to an {@code int}, raised to 1 if it would otherwise be zero or negative,
     * and then passed to {@code splitTestAndTrain(int)}.
     *
     * @param fractionTrain the training fraction; must be strictly between 0.0 and 1.0
     * @return the split produced by {@code splitTestAndTrain(int)}
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(double fractionTrain) {
        boolean inRange = fractionTrain > 0.0 && fractionTrain < 1.0;
        Preconditions.checkArgument(inRange, "Train fraction must be > 0.0 and < 1.0 - got %s", fractionTrain);
        int truncated = (int) (fractionTrain * (double) numExamples());
        int numTrain = Math.max(1, truncated);
        return splitTestAndTrain(numTrain);
    }

    /**
     * Iterates over the collection returned by {@code asList()}.
     *
     * @return an iterator obtained from {@code asList()}
     */
    @Override
    public Iterator<DataSet> iterator() {
        return asList().iterator();
    }

    /**
     * Returns the features mask.
     *
     * @return the {@code featuresMask} field, possibly {@code null}
     */
    @Override
    public INDArray getFeaturesMaskArray() {
        return featuresMask;
    }

    /**
     * Replaces the features mask.
     *
     * @param featuresMask the new mask; the reference is stored directly
     */
    @Override
    public void setFeaturesMaskArray(INDArray featuresMask) {
        this.featuresMask = featuresMask;
    }

    /**
     * Returns the labels mask.
     *
     * @return the {@code labelsMask} field, possibly {@code null}
     */
    @Override
    public INDArray getLabelsMaskArray() {
        return labelsMask;
    }

    /**
     * Replaces the labels mask.
     *
     * @param labelsMask the new mask; the reference is stored directly
     */
    @Override
    public void setLabelsMaskArray(INDArray labelsMask) {
        this.labelsMask = labelsMask;
    }

    /**
     * Reports whether at least one mask array is set.
     *
     * @return {@code true} if the labels mask or the features mask is non-null
     */
    @Override
    public boolean hasMaskArrays() {
        return !(labelsMask == null && featuresMask == null);
    }

    /**
     * Compares this data set with another object.
     *
     * <p>Two data sets are equal when their features, labels, features mask
     * and labels mask are pairwise equal or both {@code null}.
     *
     * @param o the object to compare against
     * @return {@code true} if {@code o} is a {@code DataSet} with matching arrays
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DataSet)) {
            return false;
        }
        DataSet other = (DataSet) o;
        return DataSet.equalOrBothNull(features, other.features)
                && DataSet.equalOrBothNull(labels, other.labels)
                && DataSet.equalOrBothNull(featuresMask, other.featuresMask)
                && DataSet.equalOrBothNull(labelsMask, other.labelsMask);
    }

    /**
     * Null-tolerant equality for two arrays.
     *
     * @param first  the first array, may be {@code null}
     * @param second the second array, may be {@code null}
     * @return {@code true} if both are {@code null}, or both are non-null and
     *         {@code first.equals(second)}
     */
    private static boolean equalOrBothNull(INDArray first, INDArray second) {
        if (first == null) {
            return second == null;
        }
        return second != null && first.equals(second);
    }

    /**
     * Combines the hash codes of features, labels, features mask and labels
     * mask (in that order) with the usual multiply-by-31 scheme; a
     * {@code null} array contributes 0.
     *
     * @return the combined hash code
     */
    @Override
    public int hashCode() {
        INDArray[] parts = {getFeatures(), getLabels(), getFeaturesMaskArray(), getLabelsMaskArray()};
        int result = 0;
        for (INDArray part : parts) {
            int partHash = part == null ? 0 : part.hashCode();
            result = 31 * result + partHash;
        }
        return result;
    }

    /**
     * Estimates the memory used by this data set's arrays.
     *
     * <p>For each of features, labels, features mask and labels mask that is
     * non-null, adds {@code length() * Nd4j.sizeOfDataType(dataType())}.
     * A {@code null} array contributes 0.
     *
     * @return the summed size of all non-null arrays
     */
    @Override
    public long getMemoryFootprint() {
        // Features get the same null guard as the other arrays: elsewhere in this
        // class null features are a legal state, and the unguarded dereference
        // threw a NullPointerException for such data sets.
        long reqMem = this.features == null ? 0L : this.features.length() * (long) Nd4j.sizeOfDataType(this.features.dataType());
        reqMem += this.labels == null ? 0L : this.labels.length() * (long) Nd4j.sizeOfDataType(this.labels.dataType());
        reqMem += this.featuresMask == null ? 0L : this.featuresMask.length() * (long) Nd4j.sizeOfDataType(this.featuresMask.dataType());
        reqMem += this.labelsMask == null ? 0L : this.labelsMask.length() * (long) Nd4j.sizeOfDataType(this.labelsMask.dataType());
        return reqMem;
    }

    /**
     * Replaces each non-null array with the result of its {@code migrate()}
     * call. Does nothing when the memory manager reports no current workspace.
     */
    @Override
    public void migrate() {
        if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
            return;
        }
        if (features != null) {
            features = features.migrate();
        }
        if (labels != null) {
            labels = labels.migrate();
        }
        if (featuresMask != null) {
            featuresMask = featuresMask.migrate();
        }
        if (labelsMask != null) {
            labelsMask = labelsMask.migrate();
        }
    }

    /**
     * Replaces each non-null array (features, labels, features mask, labels
     * mask) with the result of its {@code detach()} call.
     */
    @Override
    public void detach() {
        if (features != null) {
            features = features.detach();
        }
        if (labels != null) {
            labels = labels.detach();
        }
        if (featuresMask != null) {
            featuresMask = featuresMask.detach();
        }
        if (labelsMask != null) {
            labelsMask = labelsMask.detach();
        }
    }

    /**
     * Reports whether this data set holds no arrays at all.
     *
     * @return {@code true} if features, labels and both masks are all {@code null}
     */
    @Override
    public boolean isEmpty() {
        boolean anyPresent = features != null || labels != null || featuresMask != null || labelsMask != null;
        return !anyPresent;
    }

    /**
     * Converts this data set into a {@code MultiDataSet} with a single input
     * and a single output.
     *
     * <p>Each of features, labels, features mask and labels mask is wrapped
     * in a one-element array when non-null and passed as {@code null}
     * otherwise. The example metadata is passed through unchanged.
     *
     * @return a new {@code MultiDataSet} built from this data set's arrays
     */
    @Override
    public org.nd4j.linalg.dataset.api.MultiDataSet toMultiDataSet() {
        INDArray f = this.getFeatures();
        INDArray l = this.getLabels();
        INDArray fMask = this.getFeaturesMaskArray();
        INDArray lMask = this.getLabelsMaskArray();

        INDArray[] fNew = f == null ? null : new INDArray[]{f};
        INDArray[] lNew = l == null ? null : new INDArray[]{l};
        // The earlier version assigned fMaskNew only in the null branch, so a
        // non-null features mask was never handed to the MultiDataSet (and the
        // variable was not definitely assigned). Wrap it like the other arrays.
        INDArray[] fMaskNew = fMask == null ? null : new INDArray[]{fMask};
        INDArray[] lMaskNew = lMask == null ? null : new INDArray[]{lMask};

        return new MultiDataSet(fNew, lNew, fMaskNew, lMaskNew, this.exampleMetaData);
    }
}

