diff --git a/src/Ocelot/Authentication/AuthenticationMiddleware.cs b/src/Ocelot/Authentication/AuthenticationMiddleware.cs index d58e18b76..6de8206f7 100644 --- a/src/Ocelot/Authentication/AuthenticationMiddleware.cs +++ b/src/Ocelot/Authentication/AuthenticationMiddleware.cs @@ -7,7 +7,7 @@ namespace Ocelot.Authentication; -public sealed class AuthenticationMiddleware : OcelotMiddleware +public class AuthenticationMiddleware : OcelotMiddleware { private readonly RequestDelegate _next; @@ -35,6 +35,7 @@ public async Task Invoke(HttpContext context) var result = await AuthenticateAsync(context, route); if (result.Principal?.Identity == null) { + await ChallengeAsync(context, route, result); SetUnauthenticatedError(context, path, null); return; } @@ -47,6 +48,7 @@ public async Task Invoke(HttpContext context) return; } + await ChallengeAsync(context, route, result); SetUnauthenticatedError(context, path, context.User.Identity.Name); } @@ -58,7 +60,7 @@ private void SetUnauthenticatedError(HttpContext httpContext, string path, strin httpContext.Items.SetError(error); } - private async Task AuthenticateAsync(HttpContext context, DownstreamRoute route) + protected virtual async Task AuthenticateAsync(HttpContext context, DownstreamRoute route) { var notEmptySchemes = route.AuthenticationOptions.AuthenticationProviderKeys .Where(s => !string.IsNullOrWhiteSpace(s)); @@ -82,4 +84,16 @@ private async Task AuthenticateAsync(HttpContext context, Do return result ?? AuthenticateResult.NoResult(); } + + protected virtual async Task ChallengeAsync(HttpContext context, DownstreamRoute route, AuthenticateResult status) + { + // Perform a challenge. This populates the WWW-Authenticate header on the response + await context.ChallengeAsync(route.AuthenticationOptions.AuthenticationProviderKeys[0]); // TODO Read failed scheme from auth result + + // Since the response gets re-created down the pipeline, we store the challenge in the Items, so we can re-apply it when sending the response + if (context.Response.Headers.TryGetValue("WWW-Authenticate", out var authenticateHeader)) + { + context.Items.SetAuthChallenge(authenticateHeader); + } + } } diff --git a/src/Ocelot/Middleware/HttpItemsExtensions.cs b/src/Ocelot/Middleware/HttpItemsExtensions.cs index b9fbf4bbe..437192ff2 100644 --- a/src/Ocelot/Middleware/HttpItemsExtensions.cs +++ b/src/Ocelot/Middleware/HttpItemsExtensions.cs @@ -43,6 +43,11 @@ public static void SetError(this IDictionary input, Error error) input.Upsert("Errors", errors); } + public static void SetAuthChallenge(this IDictionary input, string challengeString) => + input.Upsert("AuthChallenge", challengeString); + + public static string AuthChallenge(this IDictionary input) => + input.Get("AuthChallenge"); public static void SetIInternalConfiguration(this IDictionary input, IInternalConfiguration config) { input.Upsert("IInternalConfiguration", config); diff --git a/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs b/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs index c148eb9b9..7af79b5c6 100644 --- a/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs +++ b/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs @@ -1,14 +1,14 @@ -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; using Newtonsoft.Json.Linq; -using Ocelot.Configuration; +using Ocelot.Configuration; using Ocelot.Configuration.File; -using Ocelot.DownstreamRouteFinder.UrlMatcher; -using Ocelot.Logging; -using Ocelot.Middleware; +using Ocelot.DownstreamRouteFinder.UrlMatcher; +using Ocelot.Logging; +using Ocelot.Middleware; using System.Collections; using Route = Ocelot.Configuration.Route; - + namespace Ocelot.Multiplexer; public class MultiplexingMiddleware : OcelotMiddleware @@ -16,7 +16,7 @@ public class MultiplexingMiddleware : OcelotMiddleware private readonly RequestDelegate _next; private readonly IResponseAggregatorFactory _factory; private const string RequestIdString = "RequestId"; - + public MultiplexingMiddleware(RequestDelegate next, IOcelotLoggerFactory loggerFactory, IResponseAggregatorFactory factory) @@ -25,7 +25,7 @@ public MultiplexingMiddleware(RequestDelegate next, _factory = factory; _next = next; } - + public async Task Invoke(HttpContext httpContext) { var downstreamRouteHolder = httpContext.Items.DownstreamRouteHolder(); @@ -38,13 +38,13 @@ public async Task Invoke(HttpContext httpContext) await ProcessSingleRouteAsync(httpContext, downstreamRoutes[0]); return; } - + // Case 2: if no downstream routes if (downstreamRoutes.Count == 0) { return; } - + // Case 3: if multiple downstream routes var routeKeysConfigs = route.DownstreamRouteConfig; if (routeKeysConfigs == null || routeKeysConfigs.Count == 0) @@ -52,23 +52,23 @@ public async Task Invoke(HttpContext httpContext) await ProcessRoutesAsync(httpContext, route); return; } - + // Case 4: if multiple downstream routes with route keys var mainResponseContext = await ProcessMainRouteAsync(httpContext, downstreamRoutes[0]); if (mainResponseContext == null) { return; } - + var responsesContexts = await ProcessRoutesWithRouteKeysAsync(httpContext, downstreamRoutes, routeKeysConfigs, mainResponseContext); if (responsesContexts.Length == 0) { return; } - + await MapResponsesAsync(httpContext, route, mainResponseContext, responsesContexts); } - + /// /// Helper method to determine if only the first downstream route should be processed. /// It is the case if the request is a websocket request or if there is only one downstream route. @@ -78,7 +78,7 @@ public async Task Invoke(HttpContext httpContext) /// True if only the first downstream route should be processed. private static bool ShouldProcessSingleRoute(HttpContext context, ICollection routes) => context.WebSockets.IsWebSocketRequest || routes.Count == 1; - + /// /// Processing a single downstream route (no route keys). /// In that case, no need to make copies of the http context. @@ -89,9 +89,10 @@ private static bool ShouldProcessSingleRoute(HttpContext context, ICollection ro protected virtual Task ProcessSingleRouteAsync(HttpContext context, DownstreamRoute route) { context.Items.UpsertDownstreamRoute(route); + context.Items.SetAuthChallenge(/*finished*/context.Items.AuthChallenge()); return _next.Invoke(context); } - + /// /// Processing the downstream routes (no route keys). /// @@ -105,7 +106,7 @@ private async Task ProcessRoutesAsync(HttpContext context, Route route) var contexts = await Task.WhenAll(tasks); await MapAsync(context, route, new(contexts)); } - + /// /// When using route keys, the first route is the main route and the rest are additional routes. /// Since we need to break if the main route response is null, we must process the main route first. @@ -119,7 +120,7 @@ private async Task ProcessMainRouteAsync(HttpContext context, Downs await _next.Invoke(context); return context; } - + /// /// Processing the downstream routes with route keys except the main route that has already been processed. /// @@ -133,7 +134,7 @@ protected virtual async Task ProcessRoutesWithRouteKeysAsync(Http var processing = new List>(); var content = await mainResponse.Items.DownstreamResponse().Content.ReadAsStringAsync(); var jObject = JToken.Parse(content); - + foreach (var downstreamRoute in routes.Skip(1)) { var matchAdvancedAgg = routeKeysConfigs.FirstOrDefault(q => q.RouteKey == downstreamRoute.Key); @@ -142,13 +143,13 @@ protected virtual async Task ProcessRoutesWithRouteKeysAsync(Http processing.AddRange(ProcessRouteWithComplexAggregation(matchAdvancedAgg, jObject, context, downstreamRoute)); continue; } - + processing.Add(ProcessRouteAsync(context, downstreamRoute)); } - + return await Task.WhenAll(processing); } - + /// /// Mapping responses. /// @@ -158,7 +159,7 @@ private Task MapResponsesAsync(HttpContext context, Route route, HttpContext mai contexts.AddRange(responsesContexts); return MapAsync(context, route, contexts); } - + /// /// Processing a route with aggregation. /// @@ -173,7 +174,7 @@ private IEnumerable> ProcessRouteWithComplexAggregation(Aggreg tPnv.Add(new PlaceholderNameAndValue('{' + matchAdvancedAgg.Parameter + '}', value)); processing.Add(ProcessRouteAsync(httpContext, downstreamRoute, tPnv)); } - + return processing; } @@ -186,11 +187,11 @@ private async Task ProcessRouteAsync(HttpContext sourceContext, Dow var newHttpContext = await CreateThreadContextAsync(sourceContext, route); CopyItemsToNewContext(newHttpContext, sourceContext, placeholders); newHttpContext.Items.UpsertDownstreamRoute(route); - + await _next.Invoke(newHttpContext); return newHttpContext; } - + /// /// Copying some needed parameters to the Http context items. /// @@ -247,7 +248,7 @@ protected virtual async Task CreateThreadContextAsync(HttpContext s target.Response.RegisterForDisposeAsync(bodyStream); // manage Stream lifetime by HttpResponse object return target; } - + protected virtual Task MapAsync(HttpContext httpContext, Route route, List contexts) { if (route.DownstreamRoute.Count == 1) @@ -282,4 +283,4 @@ protected virtual async Task CloneRequestBodyAsync(HttpRequest request, return targetBuffer; } -} +} diff --git a/src/Ocelot/Responder/HttpContextResponder.cs b/src/Ocelot/Responder/HttpContextResponder.cs index d636c6b05..8aefbde40 100644 --- a/src/Ocelot/Responder/HttpContextResponder.cs +++ b/src/Ocelot/Responder/HttpContextResponder.cs @@ -81,6 +81,9 @@ protected virtual async Task WriteToUpstreamAsync(HttpContext context, Downstrea await content.CopyToAsync(context.Response.Body, context.RequestAborted); } + public void SetAuthChallengeOnContext(HttpContext context, string challenge) + => AddHeaderIfDoesntExist(context, new Header("WWW-Authenticate", new[] { challenge })); + private static void SetStatusCode(HttpContext context, int statusCode) { if (!context.Response.HasStarted) diff --git a/src/Ocelot/Responder/IHttpResponder.cs b/src/Ocelot/Responder/IHttpResponder.cs index 6dac2368d..1ada31cfb 100644 --- a/src/Ocelot/Responder/IHttpResponder.cs +++ b/src/Ocelot/Responder/IHttpResponder.cs @@ -10,4 +10,6 @@ public interface IHttpResponder void SetErrorResponseOnContext(HttpContext context, int statusCode); Task SetErrorResponseOnContext(HttpContext context, DownstreamResponse response); + + void SetAuthChallengeOnContext(HttpContext context, string challenge); } diff --git a/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs b/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs index d0e2b380c..b46ebb689 100644 --- a/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs +++ b/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs @@ -57,12 +57,19 @@ private async Task SetErrorResponse(HttpContext context, List errors) var statusCode = _codeMapper.Map(errors); _responder.SetErrorResponseOnContext(context, statusCode); - if (errors.All(e => e.Code != OcelotErrorCode.QuotaExceededError)) + if (errors.Any(e => e.Code == OcelotErrorCode.QuotaExceededError)) { - return; + var downstreamResponse = context.Items.DownstreamResponse(); + await _responder.SetErrorResponseOnContext(context, downstreamResponse); } - var downstreamResponse = context.Items.DownstreamResponse(); - await _responder.SetErrorResponseOnContext(context, downstreamResponse); + if (errors.Any(e => e.Code == OcelotErrorCode.UnauthenticatedError)) + { + var challenge = context.Items.AuthChallenge(); + if (!string.IsNullOrEmpty(challenge)) + { + _responder.SetAuthChallengeOnContext(context, challenge); + } + } } } diff --git a/test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs b/test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs index 0c5395ebc..0f5776154 100644 --- a/test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs +++ b/test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs @@ -237,4 +237,26 @@ public async Task ShouldApplyGlobalGroupAuthenticationOptions_ForStaticRoutes_Wh ThenTheStatusCodeShouldBe(HttpStatusCode.Forbidden); await ThenTheResponseBodyShouldBeEmpty(); } + + [Fact] + [Trait("Feat", "1387")] // https://github.com/ThreeMammals/Ocelot/pull/1387 + public void Should_return_www_authenticate_header_on_401() + { + var port = PortFinder.GetRandomPort(); + var route = GivenAuthRoute(port); + var configuration = GivenConfiguration(route); + this.Given(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning(WithJwtBearerAuthentication)) + .And(x => GivenIHaveNoTokenForMyRequest()) + .When(x => WhenIGetUrlOnTheApiGateway("/")) + .Then(x => ThenTheStatusCodeShouldBe(HttpStatusCode.Unauthorized)) + .And(x => ThenTheResponseShouldContainAuthChallenge()) + .BDDfy(); + } + private void GivenIHaveNoTokenForMyRequest() => ocelotClient.DefaultRequestHeaders.Authorization = null; + private void ThenTheResponseShouldContainAuthChallenge() + { + response.Headers.TryGetValues("WWW-Authenticate", out var headerValue).ShouldBeTrue(); + headerValue.ShouldNotBeEmpty(); + } } diff --git a/test/Ocelot.AcceptanceTests/Steps.cs b/test/Ocelot.AcceptanceTests/Steps.cs index a8af40c4a..2e8cbacfc 100644 --- a/test/Ocelot.AcceptanceTests/Steps.cs +++ b/test/Ocelot.AcceptanceTests/Steps.cs @@ -16,7 +16,7 @@ public class Steps : AcceptanceSteps public Steps() : base() { BddfyConfig.Configure(); - } + } public static bool IsCiCd() => IsRunningInGitHubActions(); public static bool IsRunningInGitHubActions() => Environment.GetEnvironmentVariable("GITHUB_ACTIONS") == "true"; @@ -36,57 +36,57 @@ public void GivenOcelotIsRunning(OcelotPipelineConfiguration pipelineConfig) ocelotServer = new TestServer(builder); ocelotClient = ocelotServer.CreateClient(); } - + protected virtual void GivenThereIsAServiceRunningOn(int port, [CallerMemberName] string responseBody = "") => GivenThereIsAServiceRunningOn(port, HttpStatusCode.OK, responseBody); - + protected virtual HttpStatusCode MapStatus_StatusCode { get; set; } = HttpStatusCode.OK; protected virtual Func MapStatus_ResponseBody { get; set; } protected virtual Task MapStatus(HttpContext context) - { + { context.Response.StatusCode = (int)MapStatus_StatusCode; return context.Response.WriteAsync(MapStatus_ResponseBody?.Invoke() ?? string.Empty); - } + } protected virtual void GivenThereIsAServiceRunningOn(int port, HttpStatusCode statusCode, [CallerMemberName] string responseBody = "") - { + { MapStatus_StatusCode = statusCode; MapStatus_ResponseBody = () => responseBody; handler.GivenThereIsAServiceRunningOn(port, MapStatus); } - + protected Func pMapOK_ResponseBody; protected virtual Task MapOK(HttpContext context) - { + { context.Response.StatusCode = StatusCodes.Status200OK; return context.Response.WriteAsync(pMapOK_ResponseBody?.Invoke() ?? string.Empty); - } + } public virtual void GivenThereIsAServiceRunningOnPath(int port, string basePath, [CallerMemberName] string responseBody = "") - { + { pMapOK_ResponseBody = () => responseBody; handler.GivenThereIsAServiceRunningOn(port, basePath, MapOK); - } + } public virtual void GivenThereIsAServiceRunningOn(int port, string basePath, RequestDelegate requestDelegate) - { + { handler.GivenThereIsAServiceRunningOn(port, basePath, requestDelegate); - } - + } + protected override FileHostAndPort Localhost(int port) => base.Localhost(port) as FileHostAndPort; protected override FileConfiguration GivenConfiguration(params object[] routes) => base.GivenConfiguration(routes) as FileConfiguration; protected override FileRoute GivenDefaultRoute(int port) => base.GivenDefaultRoute(port) as FileRoute; protected override FileRoute GivenCatchAllRoute(int port) => base.GivenCatchAllRoute(port) as FileRoute; protected override FileRoute GivenRoute(int port, string upstream = null, string downstream = null) => base.GivenRoute(port, upstream, downstream) as FileRoute; - + protected static FileRouteBox Box(FileRoute route) => new(route); - + #region TODO: Move to Ocelot.Testing package public virtual string Body([CallerMemberName] string responseBody = null) => responseBody ?? GetType().Name; public virtual string TestName([CallerMemberName] string testName = null) => testName ?? GetType().Name; public static Task GivenIWaitAsync(int wait) => Task.Delay(wait); public Task ThenTheResponseShouldBeAsync(HttpStatusCode expected, [CallerMemberName] string expectedBody = null) - { + { ThenTheStatusCodeShouldBe(expected); return ThenTheResponseBodyShouldBeAsync(expectedBody ?? Body(expectedBody)); - } + } public Task ThenTheResponseBodyShouldBeEmpty() => ThenTheResponseBodyShouldBeAsync(string.Empty); public Task GivenOcelotIsRunningAsync(Action configureServices) => Task.Run(() => GivenOcelotIsRunning(configureServices)); // TODO Need async version in the lib