0%

restTemplate拦截器ClientHttpRequestInterceptor

RestTemplate 拦截器(ClientHttpRequestInterceptor)详解与实践

RestTemplate 的拦截器机制允许我们在 HTTP 请求发送前和响应返回后进行自定义处理,非常适合实现日志记录、请求头添加、认证信息附加等横切关注点功能。

ClientHttpRequestInterceptor 核心原理

ClientHttpRequestInterceptor 是 RestTemplate 的拦截器接口,其核心方法 intercept 会在请求执行前后被调用:

1
2
3
4
5
6
7
public interface ClientHttpRequestInterceptor {
ClientHttpResponse intercept(
HttpRequest request,
byte[] body,
ClientHttpRequestExecution execution
) throws IOException;
}
  • request:即将发送的 HTTP 请求对象,可修改请求头、请求方法等
  • body:请求体内容
  • execution:执行器,调用其 execute 方法继续请求链的执行

拦截器的执行流程如下:

  1. 拦截器对请求进行预处理(如添加请求头)
  2. 调用 execution.execute(request, body) 执行实际请求
  3. 获取响应后可以进行后处理
  4. 返回响应(可包装或修改)

拦截器实践:分布式追踪实现

实现分布式追踪功能,通过添加 global-trace-idparent-trace-id 头信息,实现跨服务调用的链路追踪。

完善的拦截器实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.UnsupportedEncodingException;

/**
* 分布式追踪的RestTemplate拦截器,添加追踪ID到请求头
*/
public class TraceIdInterceptor implements ClientHttpRequestInterceptor {

@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
// 预处理:添加追踪ID到请求头
addTraceHeaders(request);

// 执行请求
ClientHttpResponse response = execution.execute(request, body);

// 后处理:可以记录响应信息,如响应状态码等
traceResponse(response);

return response;
}

/**
* 向请求添加追踪相关的头信息
*/
private void addTraceHeaders(HttpRequest request) {
// 从当前线程获取上下文请求(仅在Web环境有效)
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes == null) {
// 非Web环境,直接使用当前线程的追踪ID
addDefaultTraceHeaders(request);
return;
}

HttpServletRequest servletRequest = requestAttributes.getRequest();
HttpHeaders headers = request.getHeaders();

// 全局追踪ID:跨服务传递,保持不变
String globalTraceId = servletRequest.getHeader("global-trace-id");
if (globalTraceId == null) {
// 若不存在,则生成新的全局追踪ID
globalTraceId = TraceIdGenerator.generateGlobalTraceId();
}
headers.add("global-trace-id", globalTraceId);

// 父追踪ID:当前服务的本地ID,作为下游服务的父ID
String parentTraceId = servletRequest.getHeader("local-trace-id");
if (parentTraceId == null) {
parentTraceId = TraceIdGenerator.generateLocalTraceId();
}
headers.add("parent-trace-id", parentTraceId);

// 添加当前服务的本地追踪ID
String currentLocalTraceId = TraceIdGenerator.generateLocalTraceId();
headers.add("local-trace-id", currentLocalTraceId);

// 将当前追踪ID存入ThreadLocal,供业务逻辑使用
TraceContextHolder.setGlobalTraceId(globalTraceId);
TraceContextHolder.setLocalTraceId(currentLocalTraceId);
}

/**
* 非Web环境下添加默认追踪头
*/
private void addDefaultTraceHeaders(HttpRequest request) {
HttpHeaders headers = request.getHeaders();
String globalTraceId = TraceContextHolder.getGlobalTraceId();
if (globalTraceId == null) {
globalTraceId = TraceIdGenerator.generateGlobalTraceId();
TraceContextHolder.setGlobalTraceId(globalTraceId);
}
headers.add("global-trace-id", globalTraceId);

String localTraceId = TraceContextHolder.getLocalTraceId();
if (localTraceId == null) {
localTraceId = TraceIdGenerator.generateLocalTraceId();
TraceContextHolder.setLocalTraceId(localTraceId);
}
headers.add("local-trace-id", localTraceId);
headers.add("parent-trace-id", "N/A");
}

/**
* 记录响应信息
*/
private void traceResponse(ClientHttpResponse response) throws IOException {
// 可以记录响应状态码、响应时间等信息
int statusCode = response.getRawStatusCode();
String statusText = response.getStatusText();
// 实际应用中可以使用日志框架记录
// log.info("Response: {} {}", statusCode, statusText);
}
}

辅助类实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
* 追踪上下文持有类,使用ThreadLocal存储当前线程的追踪ID
*/
public class TraceContextHolder {
private static final ThreadLocal<String> GLOBAL_TRACE_ID = new ThreadLocal<>();
private static final ThreadLocal<String> LOCAL_TRACE_ID = new ThreadLocal<>();

public static String getGlobalTraceId() {
return GLOBAL_TRACE_ID.get();
}

public static void setGlobalTraceId(String globalTraceId) {
GLOBAL_TRACE_ID.set(globalTraceId);
}

public static String getLocalTraceId() {
return LOCAL_TRACE_ID.get();
}

public static void setLocalTraceId(String localTraceId) {
LOCAL_TRACE_ID.set(localTraceId);
}

/**
* 清除当前线程的追踪信息,防止内存泄漏
*/
public static void clear() {
GLOBAL_TRACE_ID.remove();
LOCAL_TRACE_ID.remove();
}
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import java.util.UUID;

/**
* 追踪ID生成器
*/
public class TraceIdGenerator {

/**
* 生成全局追踪ID,跨服务保持一致
*/
public static String generateGlobalTraceId() {
// 使用UUID作为全局追踪ID
return UUID.randomUUID().toString().replaceAll("-", "");
}

/**
* 生成本地追踪ID,每个服务调用生成一个
*/
public static String generateLocalTraceId() {
// 可以使用更短的ID,如UUID的前8位加上时间戳
return System.currentTimeMillis() + "-" +
UUID.randomUUID().toString().substring(0, 8);
}
}

拦截器配置与注册

将拦截器添加到 RestTemplate 有多种方式,推荐使用配置类进行集中配置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;

import java.util.ArrayList;
import java.util.List;

@Configuration
public class RestTemplateConfig {

/**
* 配置带有追踪拦截器的RestTemplate
*/
@Bean
public RestTemplate restTemplate() {
// 使用HttpClient作为请求工厂,解决401等错误解析问题
RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory());

// 获取已有的拦截器并添加自定义拦截器
List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>(restTemplate.getInterceptors());
interceptors.add(new TraceIdInterceptor());
// 可以添加更多拦截器
// interceptors.add(new LoggingInterceptor());
// interceptors.add(new AuthInterceptor());

restTemplate.setInterceptors(interceptors);

return restTemplate;
}

/**
* 或者,如果希望保留默认的RestTemplate bean,仅添加拦截器
*/
@Bean
public TraceIdInterceptor traceIdInterceptor(RestTemplate restTemplate) {
TraceIdInterceptor interceptor = new TraceIdInterceptor();
List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>(restTemplate.getInterceptors());
interceptors.add(interceptor);
restTemplate.setInterceptors(interceptors);
return interceptor;
}
}

拦截器链的执行顺序

当配置多个拦截器时,它们的执行顺序与添加到列表中的顺序一致:

  1. 第一个拦截器的 intercept 方法被调用
  2. 执行 execution.execute() 时,会调用第二个拦截器
  3. 以此类推,直到最后一个拦截器
  4. 实际请求发送
  5. 响应按相反顺序返回给各个拦截器进行后处理

因此,拦截器的添加顺序非常重要,例如:

  • 日志拦截器通常放在最前面,记录原始请求
  • 认证拦截器应在请求发送前添加认证信息
  • 追踪拦截器应在早期添加,确保所有后续操作都能获取到追踪 ID

常见使用场景

除了分布式追踪,拦截器还适用于以下场景:

  1. 统一认证:添加 Token、API Key 等认证信息

    1
    request.getHeaders().add("Authorization", "Bearer " + getAccessToken());
  2. 请求 / 响应日志:记录请求参数、响应结果和耗时

    1
    2
    3
    4
    long start = System.currentTimeMillis();
    ClientHttpResponse response = execution.execute(request, body);
    long end = System.currentTimeMillis();
    log.info("请求耗时: {}ms", end - start);
  3. 超时控制:为特定请求设置超时时间

    1
    2
    3
    if (request.getURI().getPath().contains("/slow-api")) {
    ((ClientHttpRequest) request).getHeaders().set("X-Timeout", "10000");
    }
  4. 异常处理:对特定响应码进行统一处理

    1
    2
    3
    4
    5
    ClientHttpResponse response = execution.execute(request, body);
    if (response.getRawStatusCode() == 401) {
    // 自动刷新Token并重试
    return retryWithNewToken(request, body, execution);
    }

通过合理使用 RestTemplate 拦截器,可以大幅提升代码复用性,将通用功能集中实现,使业务代码更加简洁

欢迎关注我的其它发布渠道