/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.utils;

import cn.smartjavaai.face.enums.SimilarityType;

/**
 * Static helpers for comparing two face feature vectors.
 *
 * <p>All arithmetic is accumulated in {@code double} and only the final result is narrowed to
 * {@code float}. This limits rounding error on long embeddings and prevents the intermediate
 * squares of very small or very large components from underflowing or overflowing.
 */
public class SimilarityUtil {

    /** Prefix of the error raised for a null or unhandled {@link SimilarityType}. */
    private static final String UNSUPPORTED_TYPE_MESSAGE =
            "\u4e0d\u652f\u6301\u7684\u76f8\u4f3c\u5ea6\u8ba1\u7b97\u7c7b\u578b: ";

    /**
     * Compares two feature vectors with the requested metric.
     *
     * <p>Score per metric:
     * <ul>
     *   <li>{@code IP}: the dot product; with {@code normalizeScore}, {@code (dot + 1) / 2}.</li>
     *   <li>{@code L2}: the Euclidean <em>distance</em> (smaller means more similar); with
     *       {@code normalizeScore}, {@code 1 / (1 + distance)} (larger means more similar).</li>
     *   <li>{@code COSINE}: the cosine of the angle between the vectors, clamped to
     *       {@code [-1, 1]}; with {@code normalizeScore}, {@code (cosine + 1) / 2}. If either
     *       vector has a zero norm the result is {@code 0}, whatever {@code normalizeScore} is.</li>
     * </ul>
     *
     * @param features1 first vector; must be non-null and non-empty
     * @param features2 second vector; must be non-null, non-empty and the same length as
     *     {@code features1}
     * @param similarityType metric to apply
     * @param normalizeScore whether to map the raw metric value as described above
     * @return the computed score
     * @throws IllegalArgumentException if either vector is null or empty, the lengths differ, or
     *     {@code similarityType} is null or not handled here
     */
    public static float calculate(float[] features1, float[] features2,
                                  SimilarityType similarityType, boolean normalizeScore) {
        validateInput(features1, features2);
        if (similarityType == null) {
            // Switching on a null enum throws a message-less NullPointerException; report it as an
            // unsupported type instead, consistent with how invalid vectors are reported.
            throw new IllegalArgumentException(UNSUPPORTED_TYPE_MESSAGE + similarityType);
        }
        switch (similarityType) {
            case IP:
                return innerProductSimilarity(features1, features2, normalizeScore);
            case L2:
                return euclideanSimilarity(features1, features2, normalizeScore);
            case COSINE:
                return cosineSimilarity(features1, features2, normalizeScore);
            default:
                throw new IllegalArgumentException(UNSUPPORTED_TYPE_MESSAGE + similarityType);
        }
    }

    /**
     * Inner-product score: the raw dot product, or {@code (dot + 1) / 2} when normalizing.
     *
     * <p>The normalized mapping only lands in {@code [0, 1]} when the dot product itself lies in
     * {@code [-1, 1]}, e.g. for unit-length vectors; other inputs are not clamped.
     */
    private static float innerProductSimilarity(float[] v1, float[] v2, boolean normalize) {
        double dot = sumOfProducts(v1, v2);
        return (float) (normalize ? (dot + 1.0) / 2.0 : dot);
    }

    /**
     * Euclidean score.
     *
     * <p>When {@code normalize} is false this returns the raw distance, where smaller means more
     * similar. Otherwise it returns {@code 1 / (1 + distance)}, which lies in {@code (0, 1]} for
     * finite distances and equals 1 for identical vectors.
     */
    private static float euclideanSimilarity(float[] v1, float[] v2, boolean normalize) {
        double dist = Math.sqrt(sumOfSquaredDifferences(v1, v2));
        return (float) (normalize ? 1.0 / (1.0 + dist) : dist);
    }

    /**
     * Cosine score: the cosine clamped to {@code [-1, 1]}, or {@code (cosine + 1) / 2} when
     * normalizing. Returns {@code 0} when either vector has a zero norm.
     */
    private static float cosineSimilarity(float[] v1, float[] v2, boolean normalize) {
        double norm1 = Math.sqrt(sumOfSquares(v1));
        double norm2 = Math.sqrt(sumOfSquares(v2));
        if (norm1 <= 0.0 || norm2 <= 0.0) {
            // Cosine is undefined for a zero vector; return 0 rather than dividing by zero.
            return 0.0f;
        }
        double cosine = sumOfProducts(v1, v2) / (norm1 * norm2);
        // Rounding can push the ratio slightly outside [-1, 1] (e.g. for near-identical vectors).
        // Clamp so the normalized score stays within [0, 1]. NaN inputs still propagate as NaN.
        cosine = Math.max(-1.0, Math.min(1.0, cosine));
        return (float) (normalize ? (cosine + 1.0) / 2.0 : cosine);
    }

    /**
     * Dot product of two vectors; {@code 0} for two empty vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     * @throws NullPointerException if either vector is null
     */
    public static float dotProduct(float[] v1, float[] v2) {
        requireSameLength(v1, v2);
        return (float) sumOfProducts(v1, v2);
    }

    /**
     * Euclidean (L2) distance between two vectors; {@code 0} for two empty vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     * @throws NullPointerException if either vector is null
     */
    public static float euclideanDistance(float[] v1, float[] v2) {
        requireSameLength(v1, v2);
        return (float) Math.sqrt(sumOfSquaredDifferences(v1, v2));
    }

    /**
     * Euclidean (L2) norm of a vector; {@code 0} for an empty vector.
     *
     * @throws NullPointerException if {@code vector} is null
     */
    public static float vectorNorm(float[] vector) {
        return (float) Math.sqrt(sumOfSquares(vector));
    }

    /** Rejects null, empty, or differently sized vectors before any metric is computed. */
    private static void validateInput(float[] v1, float[] v2) {
        if (v1 == null || v2 == null) {
            throw new IllegalArgumentException("\u7279\u5f81\u5411\u91cf\u4e0d\u80fd\u4e3anull");
        }
        if (v1.length == 0 || v2.length == 0) {
            throw new IllegalArgumentException("\u7279\u5f81\u5411\u91cf\u4e0d\u80fd\u4e3a\u7a7a");
        }
        requireSameLength(v1, v2);
    }

    /** Throws if the two (non-null) vectors differ in length, so no element is silently skipped. */
    private static void requireSameLength(float[] v1, float[] v2) {
        if (v1.length != v2.length) {
            throw new IllegalArgumentException("\u7279\u5f81\u5411\u91cf\u957f\u5ea6\u4e0d\u4e00\u81f4: " + v1.length + " vs " + v2.length);
        }
    }

    /** Sum of element-wise products, accumulated in double; assumes equal lengths. */
    private static double sumOfProducts(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    /** Sum of squared element-wise differences, accumulated in double; assumes equal lengths. */
    private static double sumOfSquaredDifferences(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = (double) a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    /** Sum of squared elements, accumulated in double. */
    private static double sumOfSquares(float[] vector) {
        double sum = 0.0;
        for (float v : vector) {
            sum += (double) v * v;
        }
        return sum;
    }
}

