diff --git a/src/ReverseProxy/LoadBalancing/ILoadBalancingDestinationSelector.cs b/src/ReverseProxy/LoadBalancing/ILoadBalancingDestinationSelector.cs new file mode 100644 index 0000000000..ca70c7cdb0 --- /dev/null +++ b/src/ReverseProxy/LoadBalancing/ILoadBalancingDestinationSelector.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.Model; + +namespace Yarp.ReverseProxy.LoadBalancing; + +internal interface ILoadBalancingDestinationSelector +{ + DestinationState? PickDestination( + HttpContext? context, + ClusterState cluster, + IReadOnlyList availableDestinations, + string? loadBalancingPolicy = null); +} diff --git a/src/ReverseProxy/LoadBalancing/LoadBalancingDestinationSelector.cs b/src/ReverseProxy/LoadBalancing/LoadBalancingDestinationSelector.cs new file mode 100644 index 0000000000..0a1f43c611 --- /dev/null +++ b/src/ReverseProxy/LoadBalancing/LoadBalancingDestinationSelector.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.Model; +using Yarp.ReverseProxy.Utilities; + +namespace Yarp.ReverseProxy.LoadBalancing; + +internal sealed class LoadBalancingDestinationSelector : ILoadBalancingDestinationSelector +{ + private readonly FrozenDictionary _loadBalancingPolicies; + + public LoadBalancingDestinationSelector(IEnumerable loadBalancingPolicies) + { + ArgumentNullException.ThrowIfNull(loadBalancingPolicies); + _loadBalancingPolicies = loadBalancingPolicies.ToDictionaryByUniqueId(p => p.Name); + } + + public DestinationState? PickDestination( + HttpContext? context, + ClusterState cluster, + IReadOnlyList availableDestinations, + string? loadBalancingPolicy = null) + { + ArgumentNullException.ThrowIfNull(cluster); + ArgumentNullException.ThrowIfNull(availableDestinations); + + var destinationCount = availableDestinations.Count; + + if (destinationCount == 0) + { + return null; + } + + if (destinationCount == 1) + { + return availableDestinations[0]; + } + + var currentPolicy = _loadBalancingPolicies.GetRequiredServiceById( + loadBalancingPolicy ?? cluster.Model.Config.LoadBalancingPolicy, + LoadBalancingPolicies.PowerOfTwoChoices); + return currentPolicy.PickDestination(context ?? new DefaultHttpContext(), cluster, availableDestinations); + } +} diff --git a/src/ReverseProxy/LoadBalancing/LoadBalancingMiddleware.cs b/src/ReverseProxy/LoadBalancing/LoadBalancingMiddleware.cs index 450af11d19..fd8855288f 100644 --- a/src/ReverseProxy/LoadBalancing/LoadBalancingMiddleware.cs +++ b/src/ReverseProxy/LoadBalancing/LoadBalancingMiddleware.cs @@ -2,13 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Frozen; -using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using Yarp.ReverseProxy.Model; -using Yarp.ReverseProxy.Utilities; namespace Yarp.ReverseProxy.LoadBalancing; @@ -18,44 +15,31 @@ namespace Yarp.ReverseProxy.LoadBalancing; internal sealed class LoadBalancingMiddleware { private readonly ILogger _logger; - private readonly FrozenDictionary _loadBalancingPolicies; + private readonly ILoadBalancingDestinationSelector _destinationSelector; private readonly RequestDelegate _next; public LoadBalancingMiddleware( RequestDelegate next, ILogger logger, - IEnumerable loadBalancingPolicies) + ILoadBalancingDestinationSelector destinationSelector) { ArgumentNullException.ThrowIfNull(next); ArgumentNullException.ThrowIfNull(logger); - ArgumentNullException.ThrowIfNull(loadBalancingPolicies); + ArgumentNullException.ThrowIfNull(destinationSelector); _next = next; _logger = logger; - _loadBalancingPolicies = loadBalancingPolicies.ToDictionaryByUniqueId(p => p.Name); + _destinationSelector = destinationSelector; } public Task Invoke(HttpContext context) { var proxyFeature = context.GetReverseProxyFeature(); - var destinations = proxyFeature.AvailableDestinations; - var destinationCount = destinations.Count; - - DestinationState? destination; - - if (destinationCount == 0) - { - destination = null; - } - else if (destinationCount == 1) - { - destination = destinations[0]; - } - else - { - var currentPolicy = _loadBalancingPolicies.GetRequiredServiceById(proxyFeature.Cluster.Config.LoadBalancingPolicy, LoadBalancingPolicies.PowerOfTwoChoices); - destination = currentPolicy.PickDestination(context, proxyFeature.Route.Cluster!, destinations); - } + var destination = _destinationSelector.PickDestination( + context, + proxyFeature.Route.Cluster!, + proxyFeature.AvailableDestinations, + proxyFeature.Cluster.Config.LoadBalancingPolicy); if (destination is null) { diff --git a/src/ReverseProxy/Management/ClusterDestinationResolver.cs b/src/ReverseProxy/Management/ClusterDestinationResolver.cs new file mode 100644 index 0000000000..76ef3e50e0 --- /dev/null +++ b/src/ReverseProxy/Management/ClusterDestinationResolver.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.LoadBalancing; +using Yarp.ReverseProxy.Model; + +namespace Yarp.ReverseProxy.Management; + +internal sealed class ClusterDestinationResolver : IClusterDestinationResolver +{ + private readonly IProxyStateLookup _proxyStateLookup; + private readonly ILoadBalancingDestinationSelector _destinationSelector; + + public ClusterDestinationResolver( + IProxyStateLookup proxyStateLookup, + ILoadBalancingDestinationSelector destinationSelector) + { + ArgumentNullException.ThrowIfNull(proxyStateLookup); + ArgumentNullException.ThrowIfNull(destinationSelector); + + _proxyStateLookup = proxyStateLookup; + _destinationSelector = destinationSelector; + } + + public ValueTask GetDestinationAsync( + string clusterId, + HttpContext? context = null, + CancellationToken cancellationToken = default) + { + ArgumentException.ThrowIfNullOrEmpty(clusterId); + cancellationToken.ThrowIfCancellationRequested(); + + if (!_proxyStateLookup.TryGetCluster(clusterId, out var cluster)) + { + throw new KeyNotFoundException($"No cluster was found for the id '{clusterId}'."); + } + + return ValueTask.FromResult( + _destinationSelector.PickDestination(context, cluster, cluster.DestinationsState.AvailableDestinations)); + } +} diff --git a/src/ReverseProxy/Management/IClusterDestinationResolver.cs b/src/ReverseProxy/Management/IClusterDestinationResolver.cs new file mode 100644 index 0000000000..7bd346445d --- /dev/null +++ b/src/ReverseProxy/Management/IClusterDestinationResolver.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.Model; + +namespace Yarp.ReverseProxy; + +/// +/// Resolves a destination for a cluster using the current runtime state and configured load balancing policy. +/// +public interface IClusterDestinationResolver +{ + /// + /// Resolves a destination for the given cluster. + /// + /// The cluster id. + /// Optional request context used by load balancing policies. + /// Cancellation token. + /// The selected destination, or if no destinations are currently available. + ValueTask GetDestinationAsync( + string clusterId, + HttpContext? context = null, + CancellationToken cancellationToken = default); +} diff --git a/src/ReverseProxy/Management/IClusterDestinationResolverExtensions.cs b/src/ReverseProxy/Management/IClusterDestinationResolverExtensions.cs new file mode 100644 index 0000000000..5bb648a072 --- /dev/null +++ b/src/ReverseProxy/Management/IClusterDestinationResolverExtensions.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Yarp.ReverseProxy; + +/// +/// Extension methods for . +/// +public static class IClusterDestinationResolverExtensions +{ + /// + /// Resolves the destination URI for the given cluster. + /// + /// The destination resolver. + /// The cluster id. + /// Optional request context used by load balancing policies. + /// Cancellation token. + /// The selected destination URI, or if no destinations are currently available. + public static async ValueTask GetDestinationUriAsync( + this IClusterDestinationResolver resolver, + string clusterId, + HttpContext? context = null, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(resolver); + + var destination = await resolver.GetDestinationAsync(clusterId, context, cancellationToken); + if (destination is null) + { + return null; + } + + return new Uri(destination.Model.Config.Address, UriKind.Absolute); + } +} diff --git a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs index 05be7bd781..371a7588f6 100644 --- a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs +++ b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs @@ -72,6 +72,7 @@ public static IReverseProxyBuilder AddConfigManager(this IReverseProxyBuilder bu { builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(sp => sp.GetRequiredService()); + builder.Services.TryAddSingleton(); return builder; } @@ -86,6 +87,7 @@ public static IReverseProxyBuilder AddProxy(this IReverseProxyBuilder builder) public static IReverseProxyBuilder AddLoadBalancingPolicies(this IReverseProxyBuilder builder) { builder.Services.TryAddSingleton(); + builder.Services.TryAddSingleton(); builder.Services.TryAddEnumerable(new[] { ServiceDescriptor.Singleton(), diff --git a/test/ReverseProxy.Tests/LoadBalancing/LoadBalancerMiddlewareTests.cs b/test/ReverseProxy.Tests/LoadBalancing/LoadBalancerMiddlewareTests.cs index 79f90f1f75..a869ac96ad 100644 --- a/test/ReverseProxy.Tests/LoadBalancing/LoadBalancerMiddlewareTests.cs +++ b/test/ReverseProxy.Tests/LoadBalancing/LoadBalancerMiddlewareTests.cs @@ -24,10 +24,12 @@ private static LoadBalancingMiddleware CreateMiddleware(RequestDelegate next, pa .Setup(l => l.IsEnabled(It.IsAny())) .Returns(true); + var destinationSelector = new LoadBalancingDestinationSelector(loadBalancingPolicies); + return new LoadBalancingMiddleware( next, logger.Object, - loadBalancingPolicies); + destinationSelector); } [Fact] diff --git a/test/ReverseProxy.Tests/Management/ClusterDestinationResolverTests.cs b/test/ReverseProxy.Tests/Management/ClusterDestinationResolverTests.cs new file mode 100644 index 0000000000..8fd518955a --- /dev/null +++ b/test/ReverseProxy.Tests/Management/ClusterDestinationResolverTests.cs @@ -0,0 +1,189 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Xunit; +using Yarp.ReverseProxy.Configuration; +using Yarp.ReverseProxy.LoadBalancing; +using Yarp.ReverseProxy.Management; +using Yarp.ReverseProxy.Model; + +namespace Yarp.ReverseProxy.Tests; + +public class ClusterDestinationResolverTests +{ + [Fact] + public async Task GetDestinationAsync_MissingCluster_Throws() + { + var resolver = CreateResolver(Enumerable.Empty(), new FirstLoadBalancingPolicy()); + + var ex = await Assert.ThrowsAsync( + async () => await resolver.GetDestinationAsync("missing")); + + Assert.Equal("No cluster was found for the id 'missing'.", ex.Message); + } + + [Fact] + public async Task GetDestinationAsync_WithoutAvailableDestinations_ReturnsNull() + { + var cluster = CreateCluster("cluster1", loadBalancingPolicy: null, Array.Empty()); + var resolver = CreateResolver(new[] { cluster }, new FirstLoadBalancingPolicy()); + + var destination = await resolver.GetDestinationAsync("cluster1"); + + Assert.Null(destination); + } + + [Fact] + public async Task GetDestinationAsync_SingleAvailableDestination_ReturnsDestination() + { + var destination1 = CreateDestination("destination1", "https://localhost:10001/"); + var cluster = CreateCluster("cluster1", loadBalancingPolicy: null, new[] { destination1 }); + var resolver = CreateResolver(new[] { cluster }, new FirstLoadBalancingPolicy()); + + var destination = await resolver.GetDestinationAsync("cluster1"); + + Assert.Same(destination1, destination); + } + + [Fact] + public async Task GetDestinationAsync_MultipleAvailableDestinations_UsesConfiguredPolicy() + { + var destination1 = CreateDestination("destination2", "https://localhost:10002/"); + var destination2 = CreateDestination("destination1", "https://localhost:10001/"); + var cluster = CreateCluster("cluster1", LoadBalancingPolicies.FirstAlphabetical, new[] { destination1, destination2 }); + var resolver = CreateResolver(new[] { cluster }, new FirstLoadBalancingPolicy()); + + var destination = await resolver.GetDestinationAsync("cluster1"); + + Assert.Same(destination2, destination); + } + + [Fact] + public async Task GetDestinationAsync_PassesHttpContextToPolicy() + { + var destination1 = CreateDestination("destination1", "https://localhost:10001/"); + var destination2 = CreateDestination("destination2", "https://localhost:10002/"); + var context = new DefaultHttpContext(); + var policy = new ContextAwarePolicy(); + var cluster = CreateCluster("cluster1", policy.Name, new[] { destination1, destination2 }); + var resolver = CreateResolver(new[] { cluster }, policy); + + var destination = await resolver.GetDestinationAsync("cluster1", context); + + Assert.Same(destination2, destination); + Assert.Same(context, policy.LastContext); + } + + [Fact] + public async Task GetDestinationUriAsync_ReturnsDestinationUri() + { + var destination1 = CreateDestination("destination1", "https://localhost:10001/base/"); + var cluster = CreateCluster("cluster1", loadBalancingPolicy: null, new[] { destination1 }); + var resolver = CreateResolver(new[] { cluster }, new FirstLoadBalancingPolicy()); + + var uri = await resolver.GetDestinationUriAsync("cluster1"); + + Assert.Equal(new Uri("https://localhost:10001/base/"), uri); + } + + [Fact] + public async Task GetDestinationAsync_CancellationRequested_Throws() + { + var destination1 = CreateDestination("destination1", "https://localhost:10001/"); + var cluster = CreateCluster("cluster1", loadBalancingPolicy: null, new[] { destination1 }); + var resolver = CreateResolver(new[] { cluster }, new FirstLoadBalancingPolicy()); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync( + async () => await resolver.GetDestinationAsync("cluster1", cancellationToken: cts.Token)); + } + + private static IClusterDestinationResolver CreateResolver(IEnumerable clusters, params ILoadBalancingPolicy[] policies) + { + return new ClusterDestinationResolver( + new TestProxyStateLookup(clusters), + new LoadBalancingDestinationSelector(policies)); + } + + private static ClusterState CreateCluster(string clusterId, string? loadBalancingPolicy, IReadOnlyList destinations) + { + var cluster = new ClusterState(clusterId) + { + Model = new ClusterModel( + new ClusterConfig + { + ClusterId = clusterId, + LoadBalancingPolicy = loadBalancingPolicy + }, + new HttpMessageInvoker(new HttpClientHandler())), + DestinationsState = new ClusterDestinationsState(destinations, destinations) + }; + + foreach (var destination in destinations) + { + cluster.Destinations.TryAdd(destination.DestinationId, destination); + } + + return cluster; + } + + private static DestinationState CreateDestination(string destinationId, string address) + { + return new DestinationState(destinationId) + { + Model = new DestinationModel(new DestinationConfig + { + Address = address + }) + }; + } + + private sealed class TestProxyStateLookup : IProxyStateLookup + { + private readonly Dictionary _clusters; + + public TestProxyStateLookup(IEnumerable clusters) + { + _clusters = clusters.ToDictionary(cluster => cluster.ClusterId, StringComparer.OrdinalIgnoreCase); + } + + public IEnumerable GetRoutes() => Array.Empty(); + + public IEnumerable GetClusters() => _clusters.Values; + + public bool TryGetRoute(string id, [NotNullWhen(true)] out RouteModel? route) + { + route = null; + return false; + } + + public bool TryGetCluster(string id, [NotNullWhen(true)] out ClusterState? cluster) + { + return _clusters.TryGetValue(id, out cluster); + } + } + + private sealed class ContextAwarePolicy : ILoadBalancingPolicy + { + public string Name => "ContextAware"; + + public HttpContext? LastContext { get; private set; } + + public DestinationState? PickDestination(HttpContext context, ClusterState cluster, IReadOnlyList availableDestinations) + { + LastContext = context; + return availableDestinations[^1]; + } + } +}