修正变量名

This commit is contained in:
谷成伟 2024-06-21 15:27:39 +08:00
parent d3de4162a7
commit 9e5af73463

View File

@ -1,9 +1,10 @@
package com.das.common.interceptor; package com.das.common.interceptor;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import com.das.common.constant.HeaderConstant;
import com.das.common.utils.AESUtil; import com.das.common.utils.AESUtil;
import com.das.common.utils.AdminRedisTemplate;
import io.micrometer.common.util.StringUtils;
import jakarta.servlet.FilterChain; import jakarta.servlet.FilterChain;
import jakarta.servlet.ReadListener; import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
@ -11,17 +12,25 @@ import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.MediaType;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections; import java.util.Collections;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import static com.das.common.constant.HeaderConstant.IV_ATTR_NAME;
import static com.das.common.constant.HeaderConstant.TOKEN_ATTR_NAME;
import static org.springframework.http.HttpMethod.POST;
public class DecryptingOncePerRequestFilter extends OncePerRequestFilter { public class DecryptingOncePerRequestFilter extends OncePerRequestFilter {
private String aesKey; private String aesKey;
public DecryptingOncePerRequestFilter(String aesKey) { public DecryptingOncePerRequestFilter(String aesKey) {
this.aesKey = aesKey; this.aesKey = aesKey;
@ -31,23 +40,24 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
String iv = request.getHeader("v");
String contentType = request.getHeader("Content-Type");
String method = request.getMethod(); String method = request.getMethod();
if ("POST".equals(method) && StringUtils.isNotBlank(contentType) && contentType.contains("application/json")) { String contentType = request.getContentType();
// 读取加密的请求体数据
String encryptedData = readRequestBody(request); //当前只对Post的Application/json请求进行拦截处理
// token解密 if (POST.matches(method) && StrUtil.isNotBlank(contentType) && contentType.contains(MediaType.APPLICATION_JSON_VALUE)) {
String token = request.getHeader("token"); String token = request.getHeader(TOKEN_ATTR_NAME);
if (StringUtils.isNotBlank(token)) { String iv = request.getHeader(IV_ATTR_NAME);
//如果获取到token则进行解密
if (StrUtil.isNotBlank(token)){
token = AESUtil.decrypt(aesKey, token, iv); token = AESUtil.decrypt(aesKey, token, iv);
} }
if (StringUtils.isNotBlank(encryptedData)) { //如果读取到requestBody则进行解密
encryptedData = AESUtil.decrypt(aesKey, encryptedData, iv); String bodyData = readRequestBody(request);
if (StrUtil.isNotBlank(bodyData)) {
bodyData = AESUtil.decrypt(aesKey, bodyData, iv);
// 使用自定义的请求包装器替换原始请求 // 使用自定义的请求包装器替换原始请求
filterChain.doFilter(new DecryptingHttpServletRequestWrapper(request, encryptedData, token), response); filterChain.doFilter(new DecryptingHttpServletRequestWrapper(request, bodyData, token), response);
} else {
filterChain.doFilter(new TokenUpdatingHttpServletRequestWrapper(request, token), response);
} }
} else { } else {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
@ -56,32 +66,25 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter {
} }
private String readRequestBody(HttpServletRequest request) throws IOException { private String readRequestBody(HttpServletRequest request) throws IOException {
StringBuilder stringBuilder = new StringBuilder(); return IoUtil.read(request.getInputStream(), StandardCharsets.UTF_8);
try (BufferedReader reader = request.getReader()) {
String line;
while ((line = reader.readLine()) != null) {
stringBuilder.append(line);
}
}
return stringBuilder.toString();
} }
// 自定义的请求包装器 // 自定义的请求包装器
static class DecryptingHttpServletRequestWrapper extends HttpServletRequestWrapper { static class DecryptingHttpServletRequestWrapper extends HttpServletRequestWrapper {
private final String decryptedData; private final String bodyData;
private final String newTokenValue; private final String token;
public DecryptingHttpServletRequestWrapper(HttpServletRequest request, String decryptedData, String newTokenValue) { public DecryptingHttpServletRequestWrapper(HttpServletRequest request, String bodayData, String token) {
super(request); super(request);
this.decryptedData = decryptedData; this.bodyData = bodayData;
this.newTokenValue = newTokenValue; this.token = token;
} }
@Override @Override
public ServletInputStream getInputStream() throws IOException { public ServletInputStream getInputStream() throws IOException {
final ByteArrayInputStream bais = new ByteArrayInputStream(decryptedData.getBytes("UTF-8")); final ByteArrayInputStream bais = new ByteArrayInputStream(bodyData.getBytes("UTF-8"));
return new ServletInputStream() { return new ServletInputStream() {
@Override @Override
public boolean isFinished() { public boolean isFinished() {
@ -112,19 +115,41 @@ public class DecryptingOncePerRequestFilter extends OncePerRequestFilter {
@Override @Override
public String getHeader(String name) { public String getHeader(String name) {
if ("token".equalsIgnoreCase(name)) { if (TOKEN_ATTR_NAME.equals(name)){
return newTokenValue; // 返回新的token值 return token;
} }
return super.getHeader(name); // 对于其他header委托给父类处理 return super.getHeader(name);
} }
@Override @Override
public Enumeration<String> getHeaders(String name) { public Enumeration<String> getHeaders(String name) {
if ("token".equalsIgnoreCase(name)) { Set<String> set = new HashSet<>(8);
return Collections.enumeration(Collections.singletonList(newTokenValue)); // 返回包含新token值的枚举 if (TOKEN_ATTR_NAME.equals(name) && StrUtil.isNotBlank(token)){
set.add(token);
} }
return super.getHeaders(name); // 对于其他header委托给父类处理 Enumeration<String> e = super.getHeaders(name);
while (e.hasMoreElements()) {
String n = e.nextElement();
set.add(n);
} }
return Collections.enumeration(set);
}
@Override
public Enumeration<String> getHeaderNames() {
Set<String> set = new HashSet<>(8);
Enumeration<String> e = super.getHeaderNames();
while (e.hasMoreElements()) {
String n = e.nextElement();
set.add(n);
}
if (StrUtil.isNotBlank(token)){
set.add(token);
}
return Collections.enumeration(set);
}
} }
} }