package com.ohaotian.plugin.uuid.security.filter;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.GenericFilterBean;

import javax.annotation.PostConstruct;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;

/**
 * Servlet filter that rejects requests whose {@code Referer} header does not start with one of
 * the configured trusted prefixes, as a lightweight CSRF defence.
 *
 * <p>Trusted prefixes are read from the comma-separated {@code verify.referer} property. When
 * that property is blank (or contains only blank entries) the check is disabled and every
 * request is passed through unchanged. Requests that carry no {@code Referer} header at all are
 * also passed through.
 *
 * <p>Rejected requests are not forwarded down the filter chain; they receive HTTP 403 with a
 * short HTML error page.
 *
 * <p>NOTE(review): matching is a plain string-prefix test, so a configured prefix such as
 * {@code http://example.com} also accepts {@code http://example.com.evil.org}. Configure entries
 * with a trailing {@code /} (or port/path) unless that looseness is intended.
 */
@Component
public class CSRFilter extends GenericFilterBean {

    private static final Logger log = LoggerFactory.getLogger(CSRFilter.class);

    /** Raw comma-separated list of trusted referer prefixes. */
    @Value("${verify.referer}")
    private String verifyReferer;

    /** Parsed, trimmed, non-empty prefixes; {@code null} means the referer check is disabled. */
    private String[] verifyReferers = null;

    /**
     * Parses {@link #verifyReferer} into {@link #verifyReferers}.
     *
     * <p>Each entry is trimmed and empty entries are discarded. An empty prefix would match every
     * referer via {@code startsWith("")} and silently disable the check. If no usable entries
     * remain, the check stays disabled.
     */
    @PostConstruct
    public void init() {
        if (StringUtils.isBlank(verifyReferer)) {
            return;
        }
        String[] parsed = Arrays.stream(verifyReferer.split(","))
                .map(String::trim)
                .filter(StringUtils::isNotEmpty)
                .toArray(String[]::new);
        this.verifyReferers = parsed.length == 0 ? null : parsed;
    }

    /**
     * Forwards the request when the referer check is disabled or the referer is trusted;
     * otherwise logs the suspected CSRF attempt and answers with HTTP 403.
     *
     * @param req   incoming request
     * @param res   outgoing response
     * @param chain remaining filter chain, invoked only for accepted requests
     * @throws IOException      if writing the rejection page or a downstream filter fails
     * @throws ServletException if a downstream filter fails
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {
        if (verifyReferers == null
                || !(req instanceof HttpServletRequest)
                || !(res instanceof HttpServletResponse)) {
            // Check disabled, or not an HTTP exchange (no Referer header to inspect).
            chain.doFilter(req, res);
            return;
        }
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        String referer = request.getHeader("Referer");
        if (isTrustedReferer(referer)) {
            chain.doFilter(request, response);
            return;
        }

        log.error("疑似CSRF攻击，referer:{}", referer);
        String method = request.getMethod();
        if (!"GET".equalsIgnoreCase(method)
                && !"POST".equalsIgnoreCase(method)
                && !"HEAD".equalsIgnoreCase(method)) {
            log.error("The request with Method[{}] was forbidden by server!", method);
        }
        // Every rejected request gets an explicit 403. Previously GET/POST/HEAD requests were
        // dropped with an empty 200, which looked like success to the client.
        reject(response);
    }

    /**
     * Returns whether a request with the given {@code Referer} value may proceed.
     *
     * <p>A missing header ({@code null}) is accepted, presumably to tolerate clients that omit
     * it. Otherwise the trimmed header must start with one of the configured prefixes.
     *
     * @param referer value of the {@code Referer} header, possibly {@code null}
     * @return {@code true} if the request should be forwarded
     */
    private boolean isTrustedReferer(String referer) {
        if (referer == null) {
            return true;
        }
        String trimmed = referer.trim();
        for (String prefix : verifyReferers) {
            if (trimmed.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Writes the HTTP 403 "request refused" page. The filter chain is not invoked.
     *
     * @param response response to write the error page to
     * @throws IOException if the response writer cannot be obtained
     */
    private void reject(HttpServletResponse response) throws IOException {
        response.setContentType("text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");
        response.setStatus(HttpStatus.SC_FORBIDDEN);
        response.getWriter().print("<font size=6 color=red>对不起，您的请求非法，系统拒绝响应!</font>");
    }

}