/*
 * Decompiled with CFR 0.152.
 */
package io.milvus.common.utils;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.util.ArrayList;
import java.util.List;

/**
 * Conversion helpers between float32 values and the two 16-bit float encodings fp16 (IEEE 754
 * binary16: 1 sign, 5 exponent, 10 mantissa bits) and bf16 (bfloat16: the upper 16 bits of a
 * float32), plus helpers that pack/unpack 16-bit values into/from {@link ByteBuffer}s.
 *
 * <p>All methods are static and keep no mutable state; the buffer readers do, however, reset the
 * position of the buffer passed to them.
 */
public class Float16Utils {
    /** Number of bytes occupied by one 16-bit (fp16, bf16 or raw short) element. */
    private static final int BYTES_PER_ELEMENT = 2;

    /** float32 bit pattern of positive infinity (0x7F800000). */
    private static final int F32_INFINITY_BITS = Float.floatToIntBits(Float.POSITIVE_INFINITY);

    /** float32 sign bit. */
    private static final int F32_SIGN_MASK = 0x80000000;

    /** float32 bits of 2^16: magnitudes at or above this always encode to fp16 Inf (or NaN). */
    private static final int F16_OVERFLOW_BITS = (127 + 16) << 23;

    /** float32 bits of 2^-14, the smallest normal fp16 magnitude (0x38800000). */
    private static final int F16_MIN_NORMAL_BITS = 113 << 23;

    /**
     * float32 bits of 0.5f (0x3F000000). The ulp of 0.5f is 2^-24, the fp16 subnormal step, so
     * adding it lets the FPU round a tiny value to the fp16 subnormal grid; the low mantissa bits
     * of the sum are then the fp16 encoding.
     */
    private static final int DENORM_MAGIC_BITS = ((127 - 15) + (23 - 10) + 1) << 23;

    /**
     * Rebiases the exponent from float32 (bias 127) to fp16 (bias 15) and adds 0xFFF, the first
     * half of the round-to-nearest-even bias (the second half is the kept mantissa LSB).
     */
    private static final int FP16_ROUNDING_CONST = ((15 - 127) << 23) + 0xFFF;

    /** fp16 exponent field (0x7C00) shifted into the float32 exponent position. */
    private static final int FP16_SHIFTED_EXP = 0x7C00 << 13;

    /** Exponent rebias from fp16 (bias 15) to float32 (bias 127). */
    private static final int FP16_TO_F32_EXP_ADJUST = (127 - 15) << 23;

    /** Extra exponent adjustment that turns an all-ones fp16 exponent into an all-ones float32 one. */
    private static final int FP16_INF_NAN_EXP_ADJUST = (128 - 16) << 23;

    /** fp16 encoding of positive infinity. */
    private static final short FP16_POSITIVE_INFINITY = (short) 0x7C00;

    /** fp16 encoding of the quiet NaN produced for NaN inputs. */
    private static final short FP16_QUIET_NAN = (short) 0x7E00;

    /**
     * Converts a float32 to bf16, rounding to nearest with ties to even.
     *
     * <p>Because {@link Float#floatToIntBits} collapses every NaN to 0x7FC00000, any NaN input
     * yields the bf16 NaN 0x7FC0.
     *
     * @param input the value to convert
     * @return the bf16 bit pattern (upper 16 bits of the rounded float32)
     */
    public static short floatToBf16(float input) {
        int bits = Float.floatToIntBits(input);
        // Adding 0x7FFF plus the lowest bit that survives truncation makes an exact tie round up
        // only when that kept bit is odd, i.e. round half to even.
        int keptLsb = (bits >> 16) & 1;
        bits += 0x7FFF + keptLsb;
        return (short) (bits >> 16);
    }

    /**
     * Converts a bf16 bit pattern to float32. The conversion is exact.
     *
     * @param input the bf16 bit pattern
     * @return the float32 value whose upper 16 bits are {@code input} and lower 16 bits are zero
     */
    public static float bf16ToFloat(short input) {
        // The sign-extended high half of the int is shifted out, so only the 16 input bits remain.
        int bits = input << 16;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Converts a float32 to fp16 (IEEE 754 binary16), rounding to nearest with ties to even.
     *
     * <p>Magnitudes that do not fit in fp16 become +/-Inf, NaN becomes 0x7E00, and values below the
     * fp16 normal range become fp16 subnormals (or signed zero).
     *
     * @param input the value to convert
     * @return the fp16 bit pattern
     */
    public static short floatToFp16(float input) {
        int bits = Float.floatToIntBits(input);
        int sign = bits & F32_SIGN_MASK;
        // Work on the magnitude only; the sign is re-applied at the end.
        bits ^= sign;

        short output;
        if (bits >= F16_OVERFLOW_BITS) {
            // Overflow, Inf or NaN: only NaN has bits strictly above the infinity pattern.
            output = bits > F32_INFINITY_BITS ? FP16_QUIET_NAN : FP16_POSITIVE_INFINITY;
        } else if (bits < F16_MIN_NORMAL_BITS) {
            // Zero or subnormal result: let float addition do the rounding onto the 2^-24 grid,
            // then strip the magic value's bits to leave the fp16 mantissa.
            float aligned = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC_BITS);
            output = (short) (Float.floatToIntBits(aligned) - DENORM_MAGIC_BITS);
        } else {
            // Normal result: rebias the exponent and round the 13 dropped mantissa bits to even.
            int mantissaOdd = (bits >> 13) & 1;
            bits += FP16_ROUNDING_CONST;
            bits += mantissaOdd;
            output = (short) (bits >> 13);
        }
        // (short) (0x80000000 >> 16) is (short) 0x8000, the fp16 sign bit.
        return (short) (output | (short) (sign >> 16));
    }

    /**
     * Converts an fp16 (IEEE 754 binary16) bit pattern to float32. The conversion is exact;
     * subnormals are renormalized and Inf/NaN are preserved.
     *
     * @param input the fp16 bit pattern
     * @return the corresponding float32 value
     */
    public static float fp16ToFloat(short input) {
        // Exponent and mantissa moved into float32 position (sign handled separately).
        int bits = (input & 0x7FFF) << 13;
        int exp = bits & FP16_SHIFTED_EXP;
        bits += FP16_TO_F32_EXP_ADJUST;

        if (exp == FP16_SHIFTED_EXP) {
            // Inf/NaN: push the exponent the rest of the way to all ones.
            bits += FP16_INF_NAN_EXP_ADJUST;
        } else if (exp == 0) {
            // Zero/subnormal: add an implicit leading one at 2^-14, then subtract 2^-14 in float
            // arithmetic so the FPU renormalizes the result.
            bits += 1 << 23;
            float renormalized = Float.intBitsToFloat(bits) - Float.intBitsToFloat(F16_MIN_NORMAL_BITS);
            bits = Float.floatToIntBits(renormalized);
        }

        bits |= (input & 0x8000) << 16;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Encodes a float32 vector as bf16 values in a little-endian byte buffer.
     *
     * @param vector the values to encode; must not be null nor contain null elements
     * @return a little-endian heap buffer of {@code 2 * vector.size()} bytes whose position is left
     *     at the end (it is not flipped), or {@code null} if {@code vector} is empty
     */
    public static ByteBuffer f32VectorToBf16Buffer(List<Float> vector) {
        if (vector.isEmpty()) {
            return null;
        }
        ByteBuffer buf = allocateLittleEndian(vector.size());
        for (Float val : vector) {
            buf.putShort(floatToBf16(val.floatValue()));
        }
        return buf;
    }

    /**
     * Decodes fp16 values from a byte buffer into a float32 list.
     *
     * <p>Side effect: {@code buf} is rewound. Every complete 16-bit element from index 0 up to the
     * buffer's limit is decoded, using the buffer's current byte order. Buffers produced by this
     * class are little-endian; a buffer from elsewhere (e.g. {@code ByteBuffer.wrap}, which defaults
     * to big-endian) must have its order set by the caller.
     *
     * @param buf the encoded fp16 data
     * @return a new mutable list with one float per 16-bit element
     */
    public static List<Float> fp16BufferToVector(ByteBuffer buf) {
        ShortBuffer halves = shortViewFromStart(buf);
        int count = halves.limit();
        List<Float> vector = new ArrayList<Float>(count);
        for (int i = 0; i < count; ++i) {
            vector.add(fp16ToFloat(halves.get(i)));
        }
        return vector;
    }

    /**
     * Encodes a float32 vector as fp16 values in a little-endian byte buffer.
     *
     * @param vector the values to encode; must not be null nor contain null elements
     * @return a little-endian heap buffer of {@code 2 * vector.size()} bytes whose position is left
     *     at the end (it is not flipped), or {@code null} if {@code vector} is empty
     */
    public static ByteBuffer f32VectorToFp16Buffer(List<Float> vector) {
        if (vector.isEmpty()) {
            return null;
        }
        ByteBuffer buf = allocateLittleEndian(vector.size());
        for (Float val : vector) {
            buf.putShort(floatToFp16(val.floatValue()));
        }
        return buf;
    }

    /**
     * Decodes bf16 values from a byte buffer into a float32 list.
     *
     * <p>Side effect: {@code buf} is rewound. Every complete 16-bit element from index 0 up to the
     * buffer's limit is decoded, using the buffer's current byte order. Buffers produced by this
     * class are little-endian; a buffer from elsewhere must have its order set by the caller.
     *
     * @param buf the encoded bf16 data
     * @return a new mutable list with one float per 16-bit element
     */
    public static List<Float> bf16BufferToVector(ByteBuffer buf) {
        ShortBuffer halves = shortViewFromStart(buf);
        int count = halves.limit();
        List<Float> vector = new ArrayList<Float>(count);
        for (int i = 0; i < count; ++i) {
            vector.add(bf16ToFloat(halves.get(i)));
        }
        return vector;
    }

    /**
     * Packs raw 16-bit values (already fp16/bf16 encoded) into a little-endian byte buffer.
     *
     * @param vector the values to pack; must not be null nor contain null elements
     * @return a little-endian heap buffer of {@code 2 * vector.size()} bytes whose position is left
     *     at the end (it is not flipped), or {@code null} if {@code vector} is empty
     */
    public static ByteBuffer f16VectorToBuffer(List<Short> vector) {
        if (vector.isEmpty()) {
            return null;
        }
        ByteBuffer buf = allocateLittleEndian(vector.size());
        for (Short val : vector) {
            buf.putShort(val);
        }
        return buf;
    }

    /**
     * Unpacks raw 16-bit values from a byte buffer without any float conversion.
     *
     * <p>Side effect: {@code buf} is rewound. Every complete 16-bit element from index 0 up to the
     * buffer's limit is read, using the buffer's current byte order.
     *
     * @param buf the packed 16-bit data
     * @return a new mutable list with one short per 16-bit element
     */
    public static List<Short> bufferToF16Vector(ByteBuffer buf) {
        ShortBuffer halves = shortViewFromStart(buf);
        int count = halves.limit();
        List<Short> vector = new ArrayList<Short>(count);
        for (int i = 0; i < count; ++i) {
            vector.add(halves.get(i));
        }
        return vector;
    }

    /** Allocates a heap buffer sized for {@code elementCount} 16-bit values, in little-endian order. */
    private static ByteBuffer allocateLittleEndian(int elementCount) {
        ByteBuffer buf = ByteBuffer.allocate(BYTES_PER_ELEMENT * elementCount);
        buf.order(ByteOrder.LITTLE_ENDIAN);
        return buf;
    }

    /** Rewinds {@code buf} and returns a 16-bit view from index 0 to its limit, in its byte order. */
    private static ShortBuffer shortViewFromStart(ByteBuffer buf) {
        buf.rewind();
        return buf.asShortBuffer();
    }
}

