package com.ohaotian.plugin.base.filter;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.StreamUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Request wrapper intended to guard against SQL injection and XSS attacks.
 *
 * <p>The request body is read once in the constructor and cached in {@link #body}, so it can be
 * re-read any number of times via {@link #getInputStream()} / {@link #getReader()}. The cached
 * body is returned unmodified; only parameter names/values and single header values are passed
 * through the XSS cleaning routine.
 *
 * @Date 2020/5/20 9:48
 */
@Slf4j
public class XssAndSqlHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Matches a quoted {@code javascript:} URI such as {@code "javascript:alert..."}; compiled once. */
    private static final Pattern JAVASCRIPT_URI =
            Pattern.compile("[\\\"\\\'][\\s]*javascript:(.*)[\\\"\\\']");

    /**
     * Raw request body, cached so it can be read more than once. Left public for backward
     * compatibility; callers must treat the array as read-only.
     */
    public final byte[] body;

    /** Request URI captured at construction time; not read anywhere within this class. */
    private String currentUrl;

    /**
     * Wraps {@code request} and eagerly copies its entire input stream into {@link #body}.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public XssAndSqlHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        currentUrl = request.getRequestURI();
        // NOTE(review): this consumes the underlying stream. For application/x-www-form-urlencoded
        // requests, containers that parse form parameters lazily from the stream may then return no
        // body parameters from getParameter* -- confirm against the deployed container.
        body = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Overrides getParameter to run the parameter value through XSS filtering.
     *
     * @param name parameter name
     * @return the cleaned value, or {@code null} if the parameter is absent or empty
     * @Date 2020/5/20 9:57
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (StringUtils.isEmpty(value)) {
            return null;
        }
        return cleanXss(value);
    }

    /**
     * Returns a new, mutable map in which both parameter names and every value are XSS-cleaned.
     * Parameter order from the wrapped request is preserved.
     *
     * @return the cleaned parameter map, or {@code null} if the wrapped request returned {@code null}
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> values = super.getParameterMap();
        if (values == null) {
            return null;
        }
        Map<String, String[]> result = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : values.entrySet()) {
            // If two raw names clean to the same key, the later entry wins (same as before).
            result.put(cleanXss(entry.getKey()), cleanXssAll(entry.getValue()));
        }
        return result;
    }

    /**
     * Returns the XSS-cleaned value of a single header.
     *
     * <p>NOTE(review): {@code getHeaders(String)} and {@code getHeaderNames()} are not overridden,
     * so values read through them are not cleaned -- confirm this is intended.
     *
     * @param name header name
     * @return the cleaned value, or {@code null} if the header is absent or empty
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (StringUtils.isEmpty(value)) {
            return null;
        }
        return cleanXss(value);
    }

    /**
     * Returns the XSS-cleaned values of a parameter.
     *
     * @param name parameter name
     * @return a new array of cleaned values, or {@code null} if the parameter is absent
     */
    @Override
    public String[] getParameterValues(String name) {
        return cleanXssAll(super.getParameterValues(name));
    }

    /**
     * Returns a fresh stream over the cached, unmodified request body.
     *
     * @return a new {@link ServletInputStream} positioned at the start of {@link #body}
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return bais.read();
            }

            // Bulk read; the inherited implementation would loop over read() one byte at a time.
            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return bais.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return bais.available();
            }

            // All data is in memory: the stream is finished exactly when nothing is left.
            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            // In-memory data can always be read without blocking.
            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op (unchanged): non-blocking reads are not supported here.
                // NOTE(review): the Servlet 3.1 contract expects the listener to be notified --
                // confirm no caller relies on async I/O with this wrapper.
            }
        };
    }

    /**
     * Returns a reader over the cached request body, decoded with the request's declared
     * character encoding, or UTF-8 when none is declared. (Previously the JVM default charset was
     * used, which varies between deployments.)
     *
     * @return a new {@link BufferedReader} over {@link #body}
     * @throws java.io.UnsupportedEncodingException if the declared encoding is not supported
     * @throws IOException if the stream cannot be created
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = StandardCharsets.UTF_8.name();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /**
     * Applies {@link #cleanXss(String)} to every element of {@code values}.
     *
     * @param values values to clean; may be {@code null}
     * @return a new array of cleaned values, or {@code null} if {@code values} is {@code null}
     */
    private String[] cleanXssAll(String[] values) {
        if (values == null) {
            return null;
        }
        String[] cleaned = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            cleaned[i] = cleanXss(values[i]);
        }
        return cleaned;
    }

    /**
     * Escapes or strips XSS-prone content from a single value.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>{@code <}/{@code >} become {@code &lt;}/{@code &gt;}.</li>
     *   <li>{@code (}, {@code )} and {@code '} are replaced with the literal strings
     *       {@code "& #40;"}, {@code "& #41;"} and {@code "& #39;"}.</li>
     *   <li>A double-quoted {@code javascript:} URI is replaced with {@code ""}.</li>
     *   <li>Every occurrence of {@code script} is removed.</li>
     * </ol>
     *
     * @param valueP raw value; may be {@code null}
     * @return the cleaned value, or {@code null} if {@code valueP} is {@code null}
     * @Date 2020/5/20 10:01
     */
    private String cleanXss(String valueP) {
        if (valueP == null) {
            return null;
        }
        String value = valueP.replace("<", "&lt;").replace(">", "&gt;");
        // NOTE(review): "& #40;" etc. contain a space, so they are NOT valid HTML character
        // references and will render literally. Kept byte-for-byte because stored data and
        // consumers may depend on this exact output; the intended form is probably "&#40;".
        value = value.replace("(", "& #40;").replace(")", "& #41;");
        value = value.replace("'", "& #39;");
        // A regex that stripped eval(...) was removed here: every '(' is already encoded above,
        // so it could never match. A second, no-op "<"/">" pass was dropped for the same reason.
        value = JAVASCRIPT_URI.matcher(value).replaceAll("\"\"");
        // NOTE(review): this is a single case-sensitive pass. "SCRIPT" survives,
        // "scrscriptipt" collapses to "script", and legitimate words such as
        // "description" are mangled.
        value = value.replace("script", "");
        return value;
    }
}