1. 概述

拦截器,也称为过滤器,是Spring框架中的一项功能,它允许我们拦截客户端请求。这使得我们能够在控制器处理请求或向客户端返回响应之前检查和转换请求。

在本教程中,我们将讨论在使用WebFlux框架时,以各种方式拦截客户端请求并添加自定义头。首先,我们将针对特定端点探讨如何操作。然后,我们将确定拦截所有入站请求的方法。

2. Maven依赖项

我们将使用以下Maven依赖项来支持Spring框架的Reactive Web支持:spring-boot-starter-webflux

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-webflux</artifactId>
    <version>3.1.5</version>
</dependency>

3. 服务器请求拦截与转换

Spring WebFlux过滤器可以分为WebFilterHandlerFilterFunction。我们将利用这些过滤器来拦截服务器网络请求,并添加新的自定义头或修改现有头。

3.1. 使用WebFilter

WebFilter是一个以链式、拦截式方式处理服务器网络请求的契约。**WebFilter全局作用,一旦启用,将拦截所有请求和响应。**

首先,我们需要定义基于注解的控制器:

@GetMapping(value= "/trace-annotated")
public Mono<String> trace(@RequestHeader(name = "traceId") final String traceId) {
    return Mono.just("TraceId: ".concat(traceId));
}

然后,我们使用TraceWebFilter实现拦截服务器网络请求,并添加一个新的头traceId

@Component
public class TraceWebFilter implements WebFilter {
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        exchange.getRequest().mutate()
          .header("traceId", "ANNOTATED-TRACE-ID");
        return chain.filter(exchange);
    }
}

现在,我们可以使用WebTestClient向带有注解的端点发送一个GET请求,验证响应中是否包含我们附加的traceId头值,如"TraceId: ANNOTATED-TRACE-ID":

@Test
void whenCallAnnotatedTraceEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-annotated")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: ANNOTATED-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

重要的是,由于请求头映射是只读的,我们不能直接修改请求头,就像修改响应头一样:

@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
    if (exchange.getRequest().getPath().toString().equals("/trace-exceptional")) {
        exchange.getRequest().getHeaders().add("traceId", "TRACE-ID");
    }
    return chain.filter(exchange);
 }

此实现会抛出UnsupportedOperationException

使用WebTestClient验证,在向trace-exceptional端点发送GET请求后,过滤器确实会抛出异常,导致服务器错误:

@GetMapping(value = "/trace-exceptional")
public Mono<String> traceExceptional() {
    return Mono.just("Traced");
}
@Test
void whenCallTraceExceptionalEndpoint_thenThrowsException() {
    EntityExchangeResult<Map> result = webTestClient.get()
      .uri("/trace-exceptional")
      .exchange()
      .expectStatus()
      .is5xxServerError()
      .expectBody(Map.class)
      .returnResult();

    assertNotNull(result.getResponseBody());
}

3.2. 使用HandlerFilterFunction

在函数式风格下,路由器函数会拦截请求并调用相应的处理函数。

我们可以启用零个或多个HandlerFilterFunction,它们作为过滤HandlerFunction的函数。**HandlerFilterFunction实现仅适用于基于路由器的处理。**

对于函数式端点,我们需要先创建一个处理器:

@Component
public class TraceRouterHandler {
    public Mono<ServerResponse> handle(final ServerRequest serverRequest) {
        String traceId = serverRequest.headers().firstHeader("traceId");
      
        assert traceId != null;
        Mono<String> body = Mono.just("TraceId: ".concat(traceId));
        return ok().body(body, String.class);
    }
}

配置好处理器后,我们使用TraceHandlerFilterFunction实现拦截服务器网络请求,并添加一个新的头traceId

public RouterFunction<ServerResponse> routes(TraceRouterHandler routerHandler) {
    return RouterFunctions
      .route(GET("/trace-functional-filter"), routerHandler::handle)
      .filter(new TraceHandlerFilterFunction());
}
public class TraceHandlerFilterFunction implements HandlerFilterFunction<ServerResponse, ServerResponse> {
    @Override
    public Mono<ServerResponse> filter(ServerRequest request, HandlerFunction<ServerResponse> handlerFunction) {
        ServerRequest serverRequest = ServerRequest.from(request)
          .header("traceId", "FUNCTIONAL-TRACE-ID")
          .build();
        return handlerFunction.handle(serverRequest);
    }
}

现在,当我们触发对trace-functional-filter端点的GET调用时,我们可以验证响应中包含我们附加的traceId头值,如"TraceId: FUNCTIONAL-TRACE-ID":

@Test
void whenCallTraceFunctionalEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-functional-filter")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: FUNCTIONAL-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

3.3. 使用自定义处理器Function

处理器函数类似于路由器函数,它会拦截请求并调用相应的处理函数。

函数式路由API使我们能够添加零个或多个自定义Function实例,这些实例会在调用HandlerFunction之前应用。

这个过滤函数会拦截由构建器创建的服务器网络请求,并添加一个新的头traceId

public RouterFunction<ServerResponse> routes(TraceRouterHandler routerHandler) {
    return route()
      .GET("/trace-functional-before", routerHandler::handle)
      .before(request -> ServerRequest.from(request)
        .header("traceId", "FUNCTIONAL-TRACE-ID")
        .build())
      .build());
}

trace-functional-before端点发送一个GET请求后,我们可以验证响应中包含我们附加的traceId头值,如"TraceId: FUNCTIONAL-TRACE-ID":

@Test
void whenCallTraceFunctionalBeforeEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-functional-before")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: FUNCTIONAL-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

4. 客户端请求拦截与转换

我们将使用ExchangeFilterFunction来拦截Spring WebClient的客户端请求。

4.1. 使用ExchangeFilterFunction

ExchangeFilterFunction是与Spring WebClient相关的术语。我们使用它来拦截客户端请求,并在发送或接收响应前对其进行变换。

首先,定义一个交换过滤器函数,用于拦截客户端请求并添加一个新的头traceId。我们将跟踪所有请求头以验证ExchangeFilterFunction

public ExchangeFilterFunction modifyRequestHeaders(MultiValueMap<String, String> changedMap) {
    return (request, next) -> {
        ClientRequest clientRequest = ClientRequest.from(request)
          .header("traceId", "TRACE-ID")
          .build();
        changedMap.addAll(clientRequest.headers());
        return next.exchange(clientRequest);
    };
}

既然定义了过滤器函数,我们就可以将其附加到WebClient实例上了。这只能在创建WebClient时进行:

public WebClient webclient() {
    return WebClient.builder()
      .filter(modifyRequestHeaders(new LinkedMultiValueMap<>()))
      .build();
}

现在,我们可以使用Wiremock来测试自定义的ExchangeFilterFunction

@RegisterExtension
static WireMockExtension extension = WireMockExtension.newInstance()
  .options(wireMockConfig().dynamicPort().dynamicHttpsPort())
  .build();
@Test
void whenCallEndpoint_thenRequestHeadersModified() {
    extension.stubFor(get("/test").willReturn(aResponse().withStatus(200)
      .withBody("SUCCESS")));

    MultiValueMap<String, String> map = new LinkedMultiValueMap<>();

    WebClient webClient = WebClient.builder()
      .filter(modifyRequestHeaders(map))
      .build();
    String receivedResponse = triggerGetRequest(webClient);

    String body = "SUCCESS";
    Assertions.assertEquals(receivedResponse, body);
    Assertions.assertEquals("TRACE-ID", map.getFirst("traceId"));
}

最后,通过Wiremock,我们已经验证了ExchangeFilterFunction,确认新的traceId头在MultivalueMap实例中可用。

5. 总结

在这篇文章中,我们探讨了在服务器网络请求和客户端请求上添加自定义头的不同方法。

首先,我们学习了如何使用WebFilterHandlerFilterFunction为服务器网络请求添加自定义头。接着,我们讨论了如何使用ExchangeFilterFunction为WebClient请求执行相同的操作。

如往常一样,完整的教程源代码可以在GitHub上找到。