Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ public static IHostApplicationBuilder AddTeamsMcp(this IHostApplicationBuilder b
return builder.AddTeamsPlugin<McpPlugin>();
}

public static IHostApplicationBuilder AddTeamsMcp(this IHostApplicationBuilder builder, Action<McpPluginOptions> configure)
{
var pluginOptions = new McpPluginOptions();
configure(pluginOptions);
builder.Services.AddTeamsPlugin(new McpPlugin(pluginOptions));
return builder;
}

public static IMcpServerBuilder AddTeamsMcp(this IHostApplicationBuilder builder, McpServerOptions options)
{
builder.AddTeamsPlugin<McpPlugin>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Teams.Apps;
using Microsoft.Teams.Apps.Events;
using Microsoft.Teams.Apps.Plugins;
Expand All @@ -19,16 +20,60 @@ namespace Microsoft.Teams.Plugins.External.Mcp;
)]
public class McpPlugin : IAspNetCorePlugin
{
internal const string McpPath = "/mcp";

[AllowNull]
[Dependency]
public ILogger Logger { get; set; }

public event EventFunction Events;

private readonly McpPluginOptions _options;

public McpPlugin() : this(new McpPluginOptions()) { }

public McpPlugin(McpPluginOptions options)
{
_options = options;
}

public IApplicationBuilder Configure(IApplicationBuilder builder)
{
builder.UseRouting();
return builder.UseEndpoints(endpoints => endpoints.MapMcp("mcp"));

if (_options.RequireAuth is not null)
{
Func<HttpContext, Task<bool>> requireAuth = _options.RequireAuth;
builder.Use(async (ctx, next) =>
{
if (!ctx.Request.Path.StartsWithSegments(McpPath))
{
await next();
return;
}

bool ok = false;
try
{
ok = await requireAuth(ctx);
}
catch (Exception ex)
{
Logger.Debug($"RequireAuth threw: {ex}");
}

if (!ok)
{
ctx.Response.StatusCode = 401;
await ctx.Response.WriteAsync("unauthorized");
return;
}

await next();
});
}

return builder.UseEndpoints(endpoints => endpoints.MapMcp(McpPath.TrimStart('/')));
}

public Task OnInit(App app, CancellationToken cancellationToken = default)
Expand All @@ -38,6 +83,13 @@ public Task OnInit(App app, CancellationToken cancellationToken = default)

public Task OnStart(App app, CancellationToken cancellationToken = default)
{
if (_options.RequireAuth is null)
{
Logger.Warn(
$"McpPlugin started without RequireAuth. All MCP requests at {McpPath} will be accepted. " +
"Pass RequireAuth via AddTeamsMcp(options => options.RequireAuth = ...) to enforce authentication."
);
}
Logger.Debug("OnStart");
return Task.CompletedTask;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.AspNetCore.Http;

namespace Microsoft.Teams.Plugins.External.Mcp;

public class McpPluginOptions
{
/// <summary>
/// Optional callback that gates inbound MCP requests. Return <c>true</c> to
/// allow the request; return <c>false</c> or throw to reject with HTTP 401.
/// When unset, all MCP requests are accepted and a warning is emitted at
/// plugin startup.
/// </summary>
public Func<HttpContext, Task<bool>>? RequireAuth { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ internal async Task FetchToolsIfNeeded()

internal async Task<List<McpToolDetails>> FetchToolsFromServer(Uri url, McpClientPluginParams pluginParams)
{
await UrlValidation.ValidateMcpServerUrlAsync(url, pluginParams.AllowPrivateNetwork, pluginParams.ValidateUrl);
IClientTransport transport = CreateTransport(url, pluginParams.Transport, pluginParams.HeadersFactory());
var client = await McpClientFactory.CreateAsync(transport);
var tools = await client.ListToolsAsync();
Expand Down Expand Up @@ -241,6 +242,7 @@ internal AI.Function CreateFunctionFromTool(Uri url, McpToolDetails tool, McpCli

internal async Task<string> CallMcpTool(Uri url, McpToolDetails tool, IReadOnlyDictionary<string, object?> args, McpClientPluginParams pluginParams)
{
await UrlValidation.ValidateMcpServerUrlAsync(url, pluginParams.AllowPrivateNetwork, pluginParams.ValidateUrl);
IClientTransport transport = CreateTransport(url, pluginParams.Transport, pluginParams.HeadersFactory());
var client = await McpClientFactory.CreateAsync(transport);
var response = await client.CallToolAsync(tool.Name, args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ public class McpClientPluginParams
/// Override default cache timeout of 1 day
/// </summary>
public int? RefetchTimeoutMs { get; set; }

/// <summary>
/// When true, skip the default private-network filter and allow MCP server
/// URLs that resolve to loopback, RFC1918, or link-local addresses. Use for
/// local development or intentional on-prem MCP servers.
/// </summary>
public bool AllowPrivateNetwork { get; set; } = false;

/// <summary>
/// Fully replace the default URL validation. When set, the callback decides
/// whether the URL is allowed; the default scheme and private-network checks
/// are skipped.
/// </summary>
public Func<Uri, Task<bool>>? ValidateUrl { get; set; }
}

public enum McpClientTransport
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Net;
using System.Net.Sockets;

namespace Microsoft.Teams.Plugins.External.McpClient;

public class UrlValidationException : Exception
{
public UrlValidationException(string message) : base(message) { }
public UrlValidationException(string message, Exception inner) : base(message, inner) { }
}

public static class UrlValidation
{
/// <summary>
/// Test seam: override to mock DNS lookups. Defaults to <see cref="Dns.GetHostAddressesAsync(string, CancellationToken)"/>.
/// </summary>
internal static Func<string, CancellationToken, Task<IPAddress[]>> HostResolver { get; set; } =
(host, ct) => Dns.GetHostAddressesAsync(host, ct);

/// <summary>
/// Validates a URL destined for an MCP server connection. When
/// <paramref name="validateUrl"/> is provided, it fully replaces the default
/// checks. Otherwise the default policy rejects non-http(s) schemes, and
/// (unless <paramref name="allowPrivateNetwork"/> is <c>true</c>) rejects
/// URLs whose hostname resolves to a private / loopback / link-local address.
/// </summary>
/// <exception cref="UrlValidationException">Thrown on rejection.</exception>
public static async Task<Uri> ValidateMcpServerUrlAsync(
Uri url,
bool allowPrivateNetwork = false,
Func<Uri, Task<bool>>? validateUrl = null,
CancellationToken cancellationToken = default)
{
if (validateUrl is not null)
{
bool allowed = await validateUrl(url);
if (!allowed)
{
throw new UrlValidationException($"URL rejected by ValidateUrl: {url}");
}
return url;
}

if (url.Scheme != Uri.UriSchemeHttp && url.Scheme != Uri.UriSchemeHttps)
{
throw new UrlValidationException(
$"URL scheme {url.Scheme} is not allowed; must be http or https"
);
}

if (allowPrivateNetwork)
{
return url;
}

IPAddress[] addresses;
if (IPAddress.TryParse(url.Host, out var literal))
{
addresses = new[] { literal };
}
else
{
try
{
addresses = await HostResolver(url.Host, cancellationToken);
}
catch (SocketException ex)
{
throw new UrlValidationException(
$"Could not resolve host {url.Host}: {ex.Message}", ex
);
}
}

if (addresses.Length == 0)
{
throw new UrlValidationException($"URL {url} did not resolve to any address");
}

foreach (var address in addresses)
{
if (IsPrivateAddress(address))
{
throw new UrlValidationException(
$"URL {url} resolves to private or loopback address {address}; " +
"set AllowPrivateNetwork to true to bypass"
);
}
}

return url;
}

/// <summary>
/// True if the address is loopback, RFC1918 private, link-local, or an
/// IPv6 unique-local / link-local address.
/// </summary>
public static bool IsPrivateAddress(IPAddress address)
{
if (IPAddress.IsLoopback(address)) return true;
if (IsUnspecified(address)) return true;

if (address.AddressFamily == AddressFamily.InterNetworkV6)
{
if (address.IsIPv6LinkLocal) return true;
if (address.IsIPv6SiteLocal) return true;
if (IsIPv6UniqueLocal(address)) return true;
if (address.IsIPv4MappedToIPv6)
{
return IsPrivateIpv4(address.MapToIPv4());
}
return false;
}

if (address.AddressFamily == AddressFamily.InterNetwork)
{
return IsPrivateIpv4(address);
}

// Unknown address family: fail closed.
return true;
}

private static bool IsPrivateIpv4(IPAddress address)
{
var bytes = address.GetAddressBytes();
if (bytes.Length != 4) return false;

// 10.0.0.0/8
if (bytes[0] == 10) return true;
// 172.16.0.0/12
if (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31) return true;
// 192.168.0.0/16
if (bytes[0] == 192 && bytes[1] == 168) return true;
// 169.254.0.0/16 link-local
if (bytes[0] == 169 && bytes[1] == 254) return true;
return false;
}

private static bool IsIPv6UniqueLocal(IPAddress address)
{
// fc00::/7 -> first byte is 0xfc or 0xfd
var bytes = address.GetAddressBytes();
return bytes.Length == 16 && (bytes[0] == 0xfc || bytes[0] == 0xfd);
}

private static bool IsUnspecified(IPAddress address)
{
// 0.0.0.0 (IPv4) or :: (IPv6) — no realistic MCP server binds here.
foreach (var b in address.GetAddressBytes())
{
if (b != 0) return false;
}
return true;
}
}
Loading
Loading