package com.ohaotian.plugin.base.filter;

import com.ohaotian.plugin.base.exception.ZTBusinessException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.regex.Pattern;

/**
 * @Description 防止sql注入,xss攻击
 * liuzh
 * @Date 2020/5/20 9:48
 */
//@WebFilter("/*")
//@Component
@Slf4j
public class XssAndSqlFilter implements Filter {
    @Value("${security.xss.key:and|exec|insert|select|delete|update|count|%|chr|mid|master|truncate|char|declare|or|like|where|union|order|by|table|from|grant|use|group_concat|column_name|information_schema.columns|table_schema|}")
    private String securityXssKey;

    private boolean enable;

    private String regx;

    @Override
    public void init(FilterConfig config) throws ServletException {
        //去除末尾分隔符
        if (this.securityXssKey.lastIndexOf("|") == securityXssKey.length() - 1) {
            this.securityXssKey = securityXssKey.substring(0, securityXssKey.length() - 1);
        }
        this.regx = "\\b(" + securityXssKey + ")\\b";
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (enable) {
            request.setCharacterEncoding("utf-8");
            response.setContentType("text/html;charset=utf-8");

            HttpServletRequest httpRequest = (HttpServletRequest) request;
            XssAndSqlHttpServletRequestWrapper xssAndSqlHttpServletRequestWrapper = new XssAndSqlHttpServletRequestWrapper(httpRequest);
            String body = new String(xssAndSqlHttpServletRequestWrapper.body, request.getCharacterEncoding());
            log.info("CrosXssFilter..........doFilter url:{},body:{}", xssAndSqlHttpServletRequestWrapper.getRequestURI(), body);
            if (null != body && Pattern.compile(regx).matcher(body).find()) {
                log.error("[" + httpRequest.getRequestURI() + "]，请求参数中包含不允许sql的关键词");
                throw new ZTBusinessException("[" + httpRequest.getRequestURI() + "]，请求参数中包含不允许sql的关键词");
            }
            chain.doFilter(xssAndSqlHttpServletRequestWrapper, response);
        } else {
            chain.doFilter(request, response);
        }
    }

    @Override
    public void destroy() {

    }

    public boolean isEnable() {
        return enable;
    }

    public void setEnable(boolean enable) {
        this.enable = enable;
    }
}