RestTemplate 拦截器(ClientHttpRequestInterceptor)详解与实践
RestTemplate 的拦截器机制允许我们在 HTTP 请求发送前和响应返回后进行自定义处理,非常适合实现日志记录、请求头添加、认证信息附加等横切关注点功能。
ClientHttpRequestInterceptor 核心原理
ClientHttpRequestInterceptor 是 RestTemplate 的拦截器接口,其核心方法 intercept 会在请求执行前后被调用:
1 2 3 4 5 6 7
| public interface ClientHttpRequestInterceptor { ClientHttpResponse intercept( HttpRequest request, byte[] body, ClientHttpRequestExecution execution ) throws IOException; }
|
request:即将发送的 HTTP 请求对象,可修改请求头、请求方法等
body:请求体内容
execution:执行器,调用其 execute 方法继续请求链的执行
拦截器的执行流程如下:
- 拦截器对请求进行预处理(如添加请求头)
- 调用
execution.execute(request, body) 执行实际请求
- 获取响应后可以进行后处理
- 返回响应(可包装或修改)
拦截器实践:分布式追踪实现
实现分布式追踪功能,通过添加 global-trace-id 和 parent-trace-id 头信息,实现跨服务调用的链路追踪。
完善的拦截器实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
| import org.springframework.http.HttpHeaders; import org.springframework.http.HttpRequest; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestExecution; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.io.UnsupportedEncodingException;
public class TraceIdInterceptor implements ClientHttpRequestInterceptor {
@Override public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { addTraceHeaders(request); ClientHttpResponse response = execution.execute(request, body); traceResponse(response); return response; }
private void addTraceHeaders(HttpRequest request) { ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); if (requestAttributes == null) { addDefaultTraceHeaders(request); return; } HttpServletRequest servletRequest = requestAttributes.getRequest(); HttpHeaders headers = request.getHeaders(); String globalTraceId = servletRequest.getHeader("global-trace-id"); if (globalTraceId == null) { globalTraceId = TraceIdGenerator.generateGlobalTraceId(); } headers.add("global-trace-id", globalTraceId); String parentTraceId = servletRequest.getHeader("local-trace-id"); if (parentTraceId == null) { parentTraceId = TraceIdGenerator.generateLocalTraceId(); } headers.add("parent-trace-id", parentTraceId); String currentLocalTraceId = TraceIdGenerator.generateLocalTraceId(); headers.add("local-trace-id", currentLocalTraceId); TraceContextHolder.setGlobalTraceId(globalTraceId); TraceContextHolder.setLocalTraceId(currentLocalTraceId); }
private void addDefaultTraceHeaders(HttpRequest request) { HttpHeaders headers = request.getHeaders(); String globalTraceId = TraceContextHolder.getGlobalTraceId(); if (globalTraceId == null) { globalTraceId = TraceIdGenerator.generateGlobalTraceId(); TraceContextHolder.setGlobalTraceId(globalTraceId); } headers.add("global-trace-id", globalTraceId); String localTraceId = TraceContextHolder.getLocalTraceId(); if (localTraceId == null) { localTraceId = TraceIdGenerator.generateLocalTraceId(); TraceContextHolder.setLocalTraceId(localTraceId); } headers.add("local-trace-id", localTraceId); headers.add("parent-trace-id", "N/A"); }
private void traceResponse(ClientHttpResponse response) throws IOException { int statusCode = response.getRawStatusCode(); String statusText = response.getStatusText(); } }
|
辅助类实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
public class TraceContextHolder { private static final ThreadLocal<String> GLOBAL_TRACE_ID = new ThreadLocal<>(); private static final ThreadLocal<String> LOCAL_TRACE_ID = new ThreadLocal<>(); public static String getGlobalTraceId() { return GLOBAL_TRACE_ID.get(); } public static void setGlobalTraceId(String globalTraceId) { GLOBAL_TRACE_ID.set(globalTraceId); } public static String getLocalTraceId() { return LOCAL_TRACE_ID.get(); } public static void setLocalTraceId(String localTraceId) { LOCAL_TRACE_ID.set(localTraceId); }
public static void clear() { GLOBAL_TRACE_ID.remove(); LOCAL_TRACE_ID.remove(); } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| import java.util.UUID;
public class TraceIdGenerator {
public static String generateGlobalTraceId() { return UUID.randomUUID().toString().replaceAll("-", ""); }
public static String generateLocalTraceId() { return System.currentTimeMillis() + "-" + UUID.randomUUID().toString().substring(0, 8); } }
|
拦截器配置与注册
将拦截器添加到 RestTemplate 有多种方式,推荐使用配置类进行集中配置:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.web.client.RestTemplate;
import java.util.ArrayList; import java.util.List;
@Configuration public class RestTemplateConfig {
@Bean public RestTemplate restTemplate() { RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>(restTemplate.getInterceptors()); interceptors.add(new TraceIdInterceptor()); restTemplate.setInterceptors(interceptors); return restTemplate; }
@Bean public TraceIdInterceptor traceIdInterceptor(RestTemplate restTemplate) { TraceIdInterceptor interceptor = new TraceIdInterceptor(); List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>(restTemplate.getInterceptors()); interceptors.add(interceptor); restTemplate.setInterceptors(interceptors); return interceptor; } }
|
拦截器链的执行顺序
当配置多个拦截器时,它们的执行顺序与添加到列表中的顺序一致:
- 第一个拦截器的
intercept 方法被调用
- 执行
execution.execute() 时,会调用第二个拦截器
- 以此类推,直到最后一个拦截器
- 实际请求发送
- 响应按相反顺序返回给各个拦截器进行后处理
因此,拦截器的添加顺序非常重要,例如:
- 日志拦截器通常放在最前面,记录原始请求
- 认证拦截器应在请求发送前添加认证信息
- 追踪拦截器应在早期添加,确保所有后续操作都能获取到追踪 ID
常见使用场景
除了分布式追踪,拦截器还适用于以下场景:
统一认证:添加 Token、API Key 等认证信息
1
| request.getHeaders().add("Authorization", "Bearer " + getAccessToken());
|
请求 / 响应日志:记录请求参数、响应结果和耗时
1 2 3 4
| long start = System.currentTimeMillis(); ClientHttpResponse response = execution.execute(request, body); long end = System.currentTimeMillis(); log.info("请求耗时: {}ms", end - start);
|
超时控制:为特定请求设置超时时间
1 2 3
| if (request.getURI().getPath().contains("/slow-api")) { ((ClientHttpRequest) request).getHeaders().set("X-Timeout", "10000"); }
|
异常处理:对特定响应码进行统一处理
1 2 3 4 5
| ClientHttpResponse response = execution.execute(request, body); if (response.getRawStatusCode() == 401) { return retryWithNewToken(request, body, execution); }
|
通过合理使用 RestTemplate 拦截器,可以大幅提升代码复用性,将通用功能集中实现,使业务代码更加简洁