diff --git a/das/src/main/java/com/das/common/interceptor/DecryptingOncePerRequestFilter.java b/das/src/main/java/com/das/common/interceptor/DecryptingOncePerRequestFilter.java index d714c9ca..ba0eb805 100644 --- a/das/src/main/java/com/das/common/interceptor/DecryptingOncePerRequestFilter.java +++ b/das/src/main/java/com/das/common/interceptor/DecryptingOncePerRequestFilter.java @@ -1,9 +1,10 @@ package com.das.common.interceptor; +import cn.hutool.core.io.IoUtil; +import cn.hutool.core.util.StrUtil; +import com.das.common.constant.HeaderConstant; import com.das.common.utils.AESUtil; -import com.das.common.utils.AdminRedisTemplate; -import io.micrometer.common.util.StringUtils; import jakarta.servlet.FilterChain; import jakarta.servlet.ReadListener; import jakarta.servlet.ServletException; @@ -11,17 +12,25 @@ import jakarta.servlet.ServletInputStream; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; -import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; import org.springframework.web.filter.OncePerRequestFilter; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Enumeration; +import java.util.HashSet; +import java.util.Set; + +import static com.das.common.constant.HeaderConstant.IV_ATTR_NAME; +import static com.das.common.constant.HeaderConstant.TOKEN_ATTR_NAME; +import static org.springframework.http.HttpMethod.POST; public class DecryptingOncePerRequestFilter extends OncePerRequestFilter { + private String aesKey; public DecryptingOncePerRequestFilter(String aesKey) { this.aesKey = aesKey; @@ -31,23 +40,24 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - String iv = request.getHeader("v"); - String contentType = request.getHeader("Content-Type"); + String method = request.getMethod(); - if ("POST".equals(method) && StringUtils.isNotBlank(contentType) && contentType.contains("application/json")) { - // 读取加密的请求体数据 - String encryptedData = readRequestBody(request); - // token解密 - String token = request.getHeader("token"); - if (StringUtils.isNotBlank(token)) { + String contentType = request.getContentType(); + + //当前只对Post的Application/json请求进行拦截处理 + if (POST.matches(method) && StrUtil.isNotBlank(contentType) && contentType.contains(MediaType.APPLICATION_JSON_VALUE)) { + String token = request.getHeader(TOKEN_ATTR_NAME); + String iv = request.getHeader(IV_ATTR_NAME); + //如果获取到token,则进行解密 + if (StrUtil.isNotBlank(token)){ token = AESUtil.decrypt(aesKey, token, iv); } - if (StringUtils.isNotBlank(encryptedData)) { - encryptedData = AESUtil.decrypt(aesKey, encryptedData, iv); + //如果读取到requestBody,则进行解密 + String bodyData = readRequestBody(request); + if (StrUtil.isNotBlank(bodyData)) { + bodyData = AESUtil.decrypt(aesKey, bodyData, iv); // 使用自定义的请求包装器替换原始请求 - filterChain.doFilter(new DecryptingHttpServletRequestWrapper(request, encryptedData, token), response); - } else { - filterChain.doFilter(new TokenUpdatingHttpServletRequestWrapper(request, token), response); + filterChain.doFilter(new DecryptingHttpServletRequestWrapper(request, bodyData, token), response); } } else { filterChain.doFilter(request, response); @@ -56,32 +66,25 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter { } private String readRequestBody(HttpServletRequest request) throws IOException { - StringBuilder stringBuilder = new StringBuilder(); - try (BufferedReader reader = request.getReader()) { - String line; - while ((line = reader.readLine()) != null) { - stringBuilder.append(line); - } - } - return stringBuilder.toString(); + return IoUtil.read(request.getInputStream(), StandardCharsets.UTF_8); } // 自定义的请求包装器 static class DecryptingHttpServletRequestWrapper extends HttpServletRequestWrapper { - private final String decryptedData; - private final String newTokenValue; + private final String bodyData; + private final String token; - public DecryptingHttpServletRequestWrapper(HttpServletRequest request, String decryptedData, String newTokenValue) { + public DecryptingHttpServletRequestWrapper(HttpServletRequest request, String bodayData, String token) { super(request); - this.decryptedData = decryptedData; - this.newTokenValue = newTokenValue; + this.bodyData = bodayData; + this.token = token; } @Override public ServletInputStream getInputStream() throws IOException { - final ByteArrayInputStream bais = new ByteArrayInputStream(decryptedData.getBytes("UTF-8")); + final ByteArrayInputStream bais = new ByteArrayInputStream(bodyData.getBytes("UTF-8")); return new ServletInputStream() { @Override public boolean isFinished() { @@ -112,19 +115,41 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter { @Override public String getHeader(String name) { - if ("token".equalsIgnoreCase(name)) { - return newTokenValue; // 返回新的token值 + if (TOKEN_ATTR_NAME.equals(name)){ + return token; } - return super.getHeader(name); // 对于其他header,委托给父类处理 + return super.getHeader(name); } @Override public Enumeration getHeaders(String name) { - if ("token".equalsIgnoreCase(name)) { - return Collections.enumeration(Collections.singletonList(newTokenValue)); // 返回包含新token值的枚举 + Set set = new HashSet<>(8); + if (TOKEN_ATTR_NAME.equals(name) && StrUtil.isNotBlank(token)){ + set.add(token); } - return super.getHeaders(name); // 对于其他header,委托给父类处理 + Enumeration e = super.getHeaders(name); + while (e.hasMoreElements()) { + String n = e.nextElement(); + set.add(n); + } + return Collections.enumeration(set); } + + @Override + public Enumeration getHeaderNames() { + Set set = new HashSet<>(8); + Enumeration e = super.getHeaderNames(); + while (e.hasMoreElements()) { + String n = e.nextElement(); + set.add(n); + } + if (StrUtil.isNotBlank(token)){ + set.add(token); + } + return Collections.enumeration(set); + } + + } }