diff --git a/src/Ocelot/Requester/SseDelegatingHandler.cs b/src/Ocelot/Requester/SseDelegatingHandler.cs new file mode 100644 index 000000000..23f2ce06e --- /dev/null +++ b/src/Ocelot/Requester/SseDelegatingHandler.cs @@ -0,0 +1,64 @@ +using Microsoft.AspNetCore.Http; + +namespace Ocelot.Requester +{ + public class SseDelegatingHandler : DelegatingHandler + { + private readonly IHttpContextAccessor _httpContextAccessor; + + public SseDelegatingHandler(IHttpContextAccessor httpContextAccessor) + { + _httpContextAccessor = httpContextAccessor; + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + var httpContext = _httpContextAccessor.HttpContext; + + var isSse = request.Headers.Accept.Any(h => h.MediaType == "text/event-stream"); + if (!isSse) + { + return await base.SendAsync(request, cancellationToken); + } + + // Correct overload: only 2 parameters + var response = await base.SendAsync(request, cancellationToken); + + httpContext.Response.StatusCode = (int)response.StatusCode; + httpContext.Response.ContentType = "text/event-stream"; + + // Forward response headers + foreach (var header in response.Headers) + { + httpContext.Response.Headers[header.Key] = header.Value.ToArray(); + } + foreach (var header in response.Content.Headers) + { + httpContext.Response.Headers[header.Key] = header.Value.ToArray(); + } + + httpContext.Response.Headers.Remove("transfer-encoding"); + + // Stream content + await using var downstreamStream = await response.Content.ReadAsStreamAsync(cancellationToken); + using var reader = new StreamReader(downstreamStream, Encoding.UTF8); + + while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested) + { + var line = await reader.ReadLineAsync(); + if (line != null) + { + var buffer = Encoding.UTF8.GetBytes(line + "\n"); + await httpContext.Response.Body.WriteAsync(buffer, 0, buffer.Length, cancellationToken); + await httpContext.Response.Body.FlushAsync(cancellationToken); + } + } + + // Dummy response to complete pipeline + return new HttpResponseMessage(response.StatusCode) + { + ReasonPhrase = "SSE stream has been forwarded" + }; + } + } +} diff --git a/test/Ocelot.UnitTests/Requester/SseDelegatingHandlerTests.cs b/test/Ocelot.UnitTests/Requester/SseDelegatingHandlerTests.cs new file mode 100644 index 000000000..8ad95dd4c --- /dev/null +++ b/test/Ocelot.UnitTests/Requester/SseDelegatingHandlerTests.cs @@ -0,0 +1,44 @@ +using Microsoft.AspNetCore.Http; +using Moq.Protected; +using Ocelot.Requester; + +namespace Ocelot.UnitTests.Requester +{ + public class SseDelegatingHandlerTests + { + [Fact] + public async Task SendAsync_ForNonSseRequest_CallsBaseHandler() + { + // Arrange + var mockHttpContext = new DefaultHttpContext(); + var mockAccessor = new Mock(); + mockAccessor.Setup(a => a.HttpContext).Returns(mockHttpContext); + + var mockInnerHandler = new Mock(); + mockInnerHandler + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK)); + + var handler = new SseDelegatingHandler(mockAccessor.Object) + { + InnerHandler = mockInnerHandler.Object + }; + + var client = new HttpClient(handler); + var request = new HttpRequestMessage(HttpMethod.Get, "http://example.com"); + request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/json")); + + // Act + var response = await client.SendAsync(request); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + mockInnerHandler.Protected().Verify("SendAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + } + } + +}