/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.common.utils;

import ai.djl.ndarray.NDArray;
import java.util.ArrayList;
import java.util.List;

/**
 * Non-maximum-suppression helpers operating on DJL {@link NDArray}s.
 */
public class NMSUtils {

    /**
     * Performs greedy non-maximum suppression (NMS) over a set of axis-aligned boxes.
     *
     * <p>Boxes are visited in descending score order. Each visited box is kept, and every box
     * still queued behind it whose IoU with it is strictly greater than {@code iouThreshold} is
     * discarded. Widths and heights are computed as {@code x2 - x1 + 1} and {@code y2 - y1 + 1},
     * i.e. coordinates are treated as inclusive pixel indices.
     *
     * <p>The queue shrinks by at least one element per iteration, so the method always
     * terminates, including for {@code iouThreshold >= 1} and for malformed (inverted) boxes.
     *
     * @param boxes 2-D array whose columns 0..3 are read as {@code x1, y1, x2, y2}
     *     (presumably shape (N, 4) — TODO confirm with callers)
     * @param scores per-box scores used for ordering; assumed 1-D with one entry per row of
     *     {@code boxes} — TODO confirm with callers
     * @param iouThreshold a queued box is suppressed when its IoU with a kept box is greater
     *     than this value
     * @return row indices into {@code boxes} of the kept boxes, highest score first; an empty
     *     array when {@code boxes} is empty
     */
    public static int[] nms(NDArray boxes, NDArray scores, float iouThreshold) {
        if (boxes.isEmpty()) {
            return new int[0];
        }
        NDArray x1 = boxes.get(":, 0");
        NDArray y1 = boxes.get(":, 1");
        NDArray x2 = boxes.get(":, 2");
        NDArray y2 = boxes.get(":, 3");
        // "+1" follows the inclusive-pixel convention used for the intersection below as well.
        NDArray areas = x2.sub(x1).add(1).mul(y2.sub(y1).add(1));
        // argSort is ascending; flipping axis 0 yields indices in descending score order.
        NDArray order = scores.argSort().flip(0);

        // NOTE(review): the intermediate NDArrays created below are not closed explicitly;
        // presumably they are released with the NDManager that owns the inputs — confirm this
        // is acceptable for large inputs.
        List<Integer> keep = new ArrayList<>();
        while (order.size() > 0) {
            int idx = (int) order.getLong(0L);
            keep.add(idx);
            if (order.size() == 1) {
                break;
            }
            // Drop the head explicitly and compare it only against the boxes queued behind it.
            // Relying on the head's self-IoU to remove it is unsafe. With iouThreshold >= 1 the
            // self-IoU (1.0) never exceeds the threshold. For inverted boxes the clamped
            // intersection is 0 while the union is negative, giving a self-IoU of -0.0.
            // Either case would leave the head queued forever.
            NDArray rest = order.get("1:");
            NDArray xx1 = x1.get(rest).maximum(x1.get(new long[] {idx}));
            NDArray yy1 = y1.get(rest).maximum(y1.get(new long[] {idx}));
            NDArray xx2 = x2.get(rest).minimum(x2.get(new long[] {idx}));
            NDArray yy2 = y2.get(rest).minimum(y2.get(new long[] {idx}));
            // Clamp at 0 so non-overlapping pairs contribute no intersection area.
            NDArray w = xx2.sub(xx1).add(1).maximum(0);
            NDArray h = yy2.sub(yy1).add(1).maximum(0);
            NDArray inter = w.mul(h);
            NDArray union = areas.get(rest).add(areas.get(new long[] {idx})).sub(inter);
            NDArray iou = inter.div(union);
            // Keep only sufficiently non-overlapping boxes; NaN IoU (zero union) compares false
            // and is therefore suppressed.
            order = rest.get(iou.lte(iouThreshold));
        }
        return keep.stream().mapToInt(Integer::intValue).toArray();
    }
}

