package com.tydic.dyc.ssc.repository.aop;

import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

@Slf4j
@Aspect
@Component
public class BatchInsertAspect {

    /** Used to resolve {@code ${key}} / {@code ${key:default}} batch-size placeholders. */
    private final Environment environment;

    public BatchInsertAspect(Environment environment) {
        this.environment = environment;
    }

    /*
     * Splits a large collection argument into fixed-size batches.
     *
     * Every method annotated with @BatchInsert is intercepted. The argument at
     * BatchInsert#paramIndex() must be a Collection (or null). When it holds more elements than
     * the configured batch size, the target method is invoked once per batch, sequentially and in
     * the collection's iteration order. The per-batch return values are then merged (see
     * mergeResults).
     *
     * NOTE(review): each batch is handed to the target as a java.util.List, so the annotated
     * parameter should be declared as List or Collection. A Set-typed parameter would likely
     * fail with an argument type mismatch.
     *
     * An exception thrown by any batch propagates immediately and the remaining batches are
     * skipped. Whether earlier batches are rolled back depends on the caller's transaction
     * setup, not on this aspect.
     */

    /**
     * Intercepts every method annotated with {@link BatchInsert}.
     *
     * @param joinPoint   the intercepted invocation
     * @param batchInsert the annotation instance bound from the intercepted method
     * @return the target's own result when no splitting is needed, otherwise the merged result
     *         of all batches
     * @throws IllegalArgumentException if the batch size is invalid, the parameter index is out
     *                                  of range, or the argument is not a {@code Collection}
     * @throws Throwable                anything thrown by the target method
     */
    @Around("@annotation(batchInsert)")
    public Object aroundBatchInsert(ProceedingJoinPoint joinPoint, BatchInsert batchInsert) throws Throwable {
        log.debug("BatchInsertAspect>进入切面");
        int batchSize = parseBatchSize(batchInsert.batchSize());
        log.debug("BatchInsertAspect>批次大小{}", batchSize);

        Object[] args = joinPoint.getArgs();
        int paramIndex = batchInsert.paramIndex();
        validateParameters(paramIndex, args);

        Collection<?> collection = getCollectionParameter(args, paramIndex);
        // Null, empty, or fits in a single batch: invoke the target with the original arguments.
        if (CollectionUtils.isEmpty(collection) || collection.size() <= batchSize) {
            return joinPoint.proceed();
        }

        return processInBatches(joinPoint, args, paramIndex, collection, batchSize);
    }

    /**
     * Parses the batch size from a literal integer or a property placeholder.
     *
     * <p>Supported forms: {@code "500"}, {@code "${key}"}, and {@code "${key:500}"}. In the
     * last form the default is used only when {@code key} is not present in the environment.
     *
     * @param batchSizeExpression the raw {@code BatchInsert#batchSize()} value
     * @return the resolved batch size, always {@code > 0}
     * @throws IllegalArgumentException if the expression cannot be resolved or parsed, or if the
     *                                  resolved value is not positive
     */
    private int parseBatchSize(String batchSizeExpression) {
        int batchSize;
        try {
            batchSize = Integer.parseInt(resolveBatchSizeValue(batchSizeExpression).trim());
        } catch (Exception e) {
            throw new IllegalArgumentException("无效的批次大小配置: " + batchSizeExpression, e);
        }
        // A non-positive size would either loop one item per batch (0) or crash later
        // inside ArrayList's constructor (negative). Reject it here with a clear message.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("批次大小必须大于0: " + batchSizeExpression);
        }
        return batchSize;
    }

    /**
     * Turns a batch-size expression into its raw (unparsed) string value.
     *
     * @param expression a literal, {@code ${key}}, or {@code ${key:default}}
     * @return the literal itself, the property value, or the default
     * @throws IllegalArgumentException if {@code ${key}} has no default and the property is missing
     */
    private String resolveBatchSizeValue(String expression) {
        if (!(expression.startsWith("${") && expression.endsWith("}"))) {
            return expression;
        }
        String placeholder = expression.substring(2, expression.length() - 1);
        // Split on the FIRST ':' only, so an empty default ("${key:}") or a default that
        // itself contains ':' does not blow up the way String.split(":") would.
        int separator = placeholder.indexOf(':');
        String key = (separator >= 0 ? placeholder.substring(0, separator) : placeholder).trim();

        String value = environment.getProperty(key);
        if (value != null) {
            return value;
        }
        if (separator >= 0) {
            return placeholder.substring(separator + 1).trim();
        }
        throw new IllegalArgumentException("配置项不存在: " + placeholder);
    }

    /**
     * Checks that {@code paramIndex} addresses an existing argument.
     *
     * @throws IllegalArgumentException if the index is negative or past the last argument
     */
    private void validateParameters(int paramIndex, Object[] args) {
        if (paramIndex < 0 || paramIndex >= args.length) {
            throw new IllegalArgumentException("参数索引超出范围: paramIndex=" + paramIndex
                    + ", args.length=" + args.length);
        }
    }

    /**
     * Returns the argument at {@code paramIndex} as a collection.
     *
     * @return the collection, or {@code null} if the argument itself is {@code null} (the caller
     *         then invokes the target unchanged)
     * @throws IllegalArgumentException if the argument is non-null and not a {@code Collection}
     */
    private Collection<?> getCollectionParameter(Object[] args, int paramIndex) {
        Object param = args[paramIndex];
        if (param == null) {
            return null;
        }
        if (param instanceof Collection) {
            return (Collection<?>) param;
        }
        throw new IllegalArgumentException("被@BatchInsert注解的参数必须是Collection类型, 实际类型: "
                + param.getClass().getName());
    }

    /**
     * Invokes the target once per chunk of at most {@code batchSize} elements and merges the
     * results.
     *
     * @return the merged result, see {@code mergeResults}
     * @throws Throwable anything thrown by a batch invocation; later batches are not run
     */
    private Object processInBatches(ProceedingJoinPoint joinPoint, Object[] args,
                                    int paramIndex, Collection<?> collection, int batchSize) throws Throwable {
        List<Object> results = new ArrayList<>();
        List<Object> batch = new ArrayList<>(batchSize);
        log.debug("BatchInsertAspect>进入分批处理");
        for (Object item : collection) {
            batch.add(item);
            if (batch.size() >= batchSize) {
                log.debug("BatchInsertAspect>进入分批处理>for循环");
                collectResult(results, executeBatch(joinPoint, args, paramIndex, batch));
                // Allocate a new list instead of clear(): the target (or its return value)
                // may still reference the list it was given.
                batch = new ArrayList<>(batchSize);
            }
        }

        // Trailing partial batch.
        if (!batch.isEmpty()) {
            log.debug("BatchInsertAspect>进入分批处理>处理最后一批");
            collectResult(results, executeBatch(joinPoint, args, paramIndex, batch));
        }
        log.debug("BatchInsertAspect>处理结束");
        return mergeResults(results);
    }

    /**
     * Invokes the target with the original arguments, except that the argument at
     * {@code paramIndex} is replaced by {@code batch}.
     */
    private Object executeBatch(ProceedingJoinPoint joinPoint, Object[] originalArgs,
                                int paramIndex, List<Object> batch) throws Throwable {
        // Clone so the caller's argument array is never mutated between batches.
        Object[] newArgs = originalArgs.clone();
        newArgs[paramIndex] = batch;
        return joinPoint.proceed(newArgs);
    }

    /** Records a batch result; {@code null} results (e.g. from void methods) are skipped. */
    private void collectResult(List<Object> result, Object batchResult) {
        if (batchResult != null) {
            result.add(batchResult);
        }
    }

    /**
     * Combines per-batch results into a single return value.
     *
     * <p>The type of the first result decides the strategy:
     * <ul>
     *   <li>{@code Integer} / {@code Long}: summed (e.g. affected-row counts);</li>
     *   <li>{@code Boolean}: {@code true} only if every batch returned {@code true};</li>
     *   <li>anything else: the LAST batch's result is returned, and earlier results are
     *       discarded.</li>
     * </ul>
     * All results are assumed to share the first one's type.
     *
     * @return the merged value, or {@code null} if no batch produced a non-null result
     */
    private Object mergeResults(List<Object> results) {
        if (results.isEmpty()) {
            return null;
        }

        Object first = results.get(0);
        if (first instanceof Integer) {
            return results.stream().mapToInt(r -> (Integer) r).sum();
        }
        if (first instanceof Long) {
            return results.stream().mapToLong(r -> (Long) r).sum();
        }
        if (first instanceof Boolean) {
            return results.stream().allMatch(r -> (Boolean) r);
        }

        // NOTE(review): lossy for collection-typed returns (e.g. generated ids). Only the final
        // batch's value survives.
        return results.get(results.size() - 1);
    }
}