diff --git a/InterfaceStubGenerator.Shared/Emitter.cs b/InterfaceStubGenerator.Shared/Emitter.cs index f8364a2fc..e6daadbfe 100644 --- a/InterfaceStubGenerator.Shared/Emitter.cs +++ b/InterfaceStubGenerator.Shared/Emitter.cs @@ -188,6 +188,7 @@ UniqueNameBuilder uniqueNames ReturnTypeInfo.AsyncVoid => (true, "await (", ").ConfigureAwait(false)"), ReturnTypeInfo.AsyncResult => (true, "return await (", ").ConfigureAwait(false)"), ReturnTypeInfo.Return => (false, "return ", ""), + ReturnTypeInfo.SyncVoid => (false, "", ""), _ => throw new ArgumentOutOfRangeException( nameof(methodModel.ReturnTypeMetadata), methodModel.ReturnTypeMetadata, @@ -228,12 +229,16 @@ UniqueNameBuilder uniqueNames lookupName = lookupName.Substring(lastDotIndex + 1); } + var callExpression = methodModel.ReturnTypeMetadata == ReturnTypeInfo.SyncVoid + ? $"______func(this.Client, ______arguments);" + : $"{@return}({returnType})______func(this.Client, ______arguments){configureAwait};"; + source.WriteLine( $""" var ______arguments = {argumentsArrayString}; var ______func = requestBuilder.BuildRestResultFuncForMethod("{lookupName}", {parameterTypesExpression}{genericString} ); - {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; + {callExpression} """ ); diff --git a/InterfaceStubGenerator.Shared/Models/MethodModel.cs b/InterfaceStubGenerator.Shared/Models/MethodModel.cs index 6f6170c10..3a8163f00 100644 --- a/InterfaceStubGenerator.Shared/Models/MethodModel.cs +++ b/InterfaceStubGenerator.Shared/Models/MethodModel.cs @@ -17,5 +17,6 @@ internal enum ReturnTypeInfo : byte { Return, AsyncVoid, - AsyncResult + AsyncResult, + SyncVoid } diff --git a/InterfaceStubGenerator.Shared/Parser.cs b/InterfaceStubGenerator.Shared/Parser.cs index ebd7e6dee..3b7cda4c2 100644 --- a/InterfaceStubGenerator.Shared/Parser.cs +++ b/InterfaceStubGenerator.Shared/Parser.cs @@ -462,6 +462,7 @@ bool isDerived { "Task" => ReturnTypeInfo.AsyncVoid, "Task`1" or "ValueTask`1" => ReturnTypeInfo.AsyncResult, + "Void" => ReturnTypeInfo.SyncVoid, _ => ReturnTypeInfo.Return, }; @@ -623,6 +624,7 @@ private static MethodModel ParseMethod(IMethodSymbol methodSymbol, bool isImplic { "Task" => ReturnTypeInfo.AsyncVoid, "Task`1" or "ValueTask`1" => ReturnTypeInfo.AsyncResult, + "Void" => ReturnTypeInfo.SyncVoid, _ => ReturnTypeInfo.Return, }; diff --git a/Refit.GeneratorTests/Refit.GeneratorTests.csproj b/Refit.GeneratorTests/Refit.GeneratorTests.csproj index 9464d6af6..bfadad0d0 100644 --- a/Refit.GeneratorTests/Refit.GeneratorTests.csproj +++ b/Refit.GeneratorTests/Refit.GeneratorTests.csproj @@ -11,6 +11,7 @@ true $(NoWarn);CS1591;CA1819;CA2000;CA2007;CA1056;CA1707;CA1861;xUnit1031 + @@ -18,14 +19,12 @@ + - - - @@ -35,7 +34,7 @@ - + diff --git a/Refit.GeneratorTests/ReturnTypeTests.cs b/Refit.GeneratorTests/ReturnTypeTests.cs index 640c1c7e2..0a07ca61d 100644 --- a/Refit.GeneratorTests/ReturnTypeTests.cs +++ b/Refit.GeneratorTests/ReturnTypeTests.cs @@ -42,6 +42,26 @@ public Task VoidTaskShouldWork() """); } + [Fact] + public Task GenericValueTaskShouldWork() + { + return Fixture.VerifyForBody( + """ + [Get("/users")] + ValueTask Get(); + """); + } + + [Fact] + public Task ValueTaskApiResponseShouldWork() + { + return Fixture.VerifyForBody( + """ + [Get("/users")] + ValueTask> Get(); + """); + } + [Fact] public Task GenericConstraintReturnTask() { diff --git a/Refit.GeneratorTests/_snapshots/ReturnTypeTests.GenericValueTaskShouldWork#IGeneratedClient.g.verified.cs b/Refit.GeneratorTests/_snapshots/ReturnTypeTests.GenericValueTaskShouldWork#IGeneratedClient.g.verified.cs new file mode 100644 index 000000000..a54c0e44c --- /dev/null +++ b/Refit.GeneratorTests/_snapshots/ReturnTypeTests.GenericValueTaskShouldWork#IGeneratedClient.g.verified.cs @@ -0,0 +1,43 @@ +//HintName: IGeneratedClient.g.cs +#nullable disable +#pragma warning disable +namespace Refit.Implementation +{ + + partial class Generated + { + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [global::RefitInternalGenerated.PreserveAttribute] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + partial class RefitGeneratorTestIGeneratedClient + : global::RefitGeneratorTest.IGeneratedClient + { + /// + public global::System.Net.Http.HttpClient Client { get; } + readonly global::Refit.IRequestBuilder requestBuilder; + + /// + public RefitGeneratorTestIGeneratedClient(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) + { + Client = client; + this.requestBuilder = requestBuilder; + } + + + /// + public async global::System.Threading.Tasks.ValueTask Get() + { + var ______arguments = global::System.Array.Empty(); + var ______func = requestBuilder.BuildRestResultFuncForMethod("Get", global::System.Array.Empty() ); + + return await ((global::System.Threading.Tasks.ValueTask)______func(this.Client, ______arguments)).ConfigureAwait(false); + } + } + } +} + +#pragma warning restore diff --git a/Refit.GeneratorTests/_snapshots/ReturnTypeTests.ValueTaskApiResponseShouldWork#IGeneratedClient.g.verified.cs b/Refit.GeneratorTests/_snapshots/ReturnTypeTests.ValueTaskApiResponseShouldWork#IGeneratedClient.g.verified.cs new file mode 100644 index 000000000..97ff01a0f --- /dev/null +++ b/Refit.GeneratorTests/_snapshots/ReturnTypeTests.ValueTaskApiResponseShouldWork#IGeneratedClient.g.verified.cs @@ -0,0 +1,43 @@ +//HintName: IGeneratedClient.g.cs +#nullable disable +#pragma warning disable +namespace Refit.Implementation +{ + + partial class Generated + { + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [global::RefitInternalGenerated.PreserveAttribute] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + partial class RefitGeneratorTestIGeneratedClient + : global::RefitGeneratorTest.IGeneratedClient + { + /// + public global::System.Net.Http.HttpClient Client { get; } + readonly global::Refit.IRequestBuilder requestBuilder; + + /// + public RefitGeneratorTestIGeneratedClient(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) + { + Client = client; + this.requestBuilder = requestBuilder; + } + + + /// + public async global::System.Threading.Tasks.ValueTask> Get() + { + var ______arguments = global::System.Array.Empty(); + var ______func = requestBuilder.BuildRestResultFuncForMethod("Get", global::System.Array.Empty() ); + + return await ((global::System.Threading.Tasks.ValueTask>)______func(this.Client, ______arguments)).ConfigureAwait(false); + } + } + } +} + +#pragma warning restore diff --git a/Refit.Tests/AuthenticatedClientHandlerTests.cs b/Refit.Tests/AuthenticatedClientHandlerTests.cs index 3f4476273..263285590 100644 --- a/Refit.Tests/AuthenticatedClientHandlerTests.cs +++ b/Refit.Tests/AuthenticatedClientHandlerTests.cs @@ -415,6 +415,34 @@ public async Task AuthorizationHeaderValueGetterIsUsedWhenSupplyingHttpClient() Assert.Equal("Ok", result); } + [Fact] + public async Task AuthorizationHeaderValueGetterCanAwaitWhenSupplyingHttpClient() + { + var handler = new MockHttpMessageHandler(); + var httpClient = new HttpClient(handler) { BaseAddress = new Uri("http://api") }; + + var settings = new RefitSettings + { + AuthorizationHeaderValueGetter = async (_, __) => + { + await Task.Yield(); + return "tokenValue"; + } + }; + + handler + .Expect(HttpMethod.Get, "http://api/auth") + .WithHeaders("Authorization", "Bearer tokenValue") + .Respond("text/plain", "Ok"); + + var fixture = RestService.For(httpClient, settings); + + var result = await fixture.GetAuthenticated(); + + handler.VerifyNoOutstandingExpectation(); + Assert.Equal("Ok", result); + } + [Fact] public async Task AuthorizationHeaderValueGetterDoesNotOverrideExplicitTokenWhenSupplyingHttpClient() { diff --git a/Refit.Tests/CachedRequestBuilder.cs b/Refit.Tests/CachedRequestBuilder.cs index e9b88c8dd..fabe1e4c4 100644 --- a/Refit.Tests/CachedRequestBuilder.cs +++ b/Refit.Tests/CachedRequestBuilder.cs @@ -34,6 +34,26 @@ public interface IDuplicateNames public class CachedRequestBuilderTests { + [Fact] + public void CachedBuilder_ThrowsForNullInnerBuilder() + { + Assert.Throws(() => new CachedRequestBuilderImplementation(null!)); + } + + [Fact] + public void MethodTableKey_ObjectEquals_And_GenericArgumentDifference_AreCovered() + { + var key = new MethodTableKey("Foo", [typeof(string)], [typeof(int)]); + object same = new MethodTableKey("Foo", [typeof(string)], [typeof(int)]); + object different = new MethodTableKey("Foo", [typeof(string)], [typeof(long)]); + var differentParameter = new MethodTableKey("Foo", [typeof(int)], [typeof(int)]); + + Assert.True(key.Equals(same)); + Assert.False(key.Equals(different)); + Assert.False(key.Equals(differentParameter)); + Assert.False(key.Equals(new object())); + } + [Fact] public async Task CacheHasCorrectNumberOfElementsTest() { diff --git a/Refit.Tests/ExplicitInterfaceRefitTests.cs b/Refit.Tests/ExplicitInterfaceRefitTests.cs index 1e1e30e32..903de5e75 100644 --- a/Refit.Tests/ExplicitInterfaceRefitTests.cs +++ b/Refit.Tests/ExplicitInterfaceRefitTests.cs @@ -1,4 +1,5 @@ -using System.Net.Http; +using System.Net; +using System.Net.Http; using System.Threading.Tasks; using Refit; using RichardSzalay.MockHttp; @@ -8,6 +9,12 @@ namespace Refit.Tests; public class ExplicitInterfaceRefitTests { + sealed class SyncCapableMockHttpMessageHandler : MockHttpMessageHandler + { + protected override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken) => + SendAsync(request, cancellationToken).GetAwaiter().GetResult(); + } + public interface IFoo { int Bar(); @@ -29,10 +36,35 @@ public interface IRemoteFoo2 : IFoo abstract int IFoo.Bar(); } + // Interfaces used to test the full sync pipeline + public interface ISyncPipelineApi + { + [Get("/resource")] + internal string GetString(); + + [Get("/resource")] + internal HttpResponseMessage GetHttpResponseMessage(); + + [Get("/resource")] + internal HttpContent GetHttpContent(); + + [Get("/resource")] + internal Stream GetStream(); + + [Get("/resource")] + internal IApiResponse GetApiResponse(); + + [Get("/resource")] + internal IApiResponse GetRawApiResponse(); + + [Get("/resource")] + internal void DoVoid(); + } + [Fact] public void DefaultInterfaceImplementation_calls_internal_refit_method() { - var mockHttp = new MockHttpMessageHandler(); + var mockHttp = new SyncCapableMockHttpMessageHandler(); var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; mockHttp @@ -50,7 +82,7 @@ public void DefaultInterfaceImplementation_calls_internal_refit_method() [Fact] public void Explicit_interface_member_with_refit_attribute_is_invoked() { - var mockHttp = new MockHttpMessageHandler(); + var mockHttp = new SyncCapableMockHttpMessageHandler(); var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; mockHttp @@ -64,4 +96,181 @@ public void Explicit_interface_member_with_refit_attribute_is_invoked() mockHttp.VerifyNoOutstandingExpectation(); } + + [Fact] + public void Sync_method_throws_ApiException_on_error_response() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond(HttpStatusCode.NotFound); + + var fixture = RestService.For("http://foo", settings); + + var ex = Assert.Throws(() => fixture.GetString()); + Assert.Equal(HttpStatusCode.NotFound, ex.StatusCode); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_HttpResponseMessage_without_running_ExceptionFactory() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond(HttpStatusCode.NotFound); + + var fixture = RestService.For("http://foo", settings); + + // Should not throw even for a 404 – caller owns the response + using var resp = fixture.GetHttpResponseMessage(); + Assert.Equal(HttpStatusCode.NotFound, resp.StatusCode); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_HttpContent_without_disposing_response() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond("text/plain", "hello"); + + var fixture = RestService.For("http://foo", settings); + + var content = fixture.GetHttpContent(); + Assert.NotNull(content); + var text = content.ReadAsStringAsync().GetAwaiter().GetResult(); + Assert.Equal("hello", text); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_Stream_without_disposing_response() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond("text/plain", "hello"); + + var fixture = RestService.For("http://foo", settings); + + using var stream = fixture.GetStream(); + Assert.NotNull(stream); + using var reader = new StreamReader(stream); + Assert.Equal("hello", reader.ReadToEnd()); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_IApiResponse_with_error_on_bad_status() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond(HttpStatusCode.InternalServerError); + + var fixture = RestService.For("http://foo", settings); + + using var apiResp = fixture.GetApiResponse(); + Assert.False(apiResp.IsSuccessStatusCode); + Assert.NotNull(apiResp.Error); + Assert.True(apiResp.HasResponseError(out var error)); + Assert.Equal(HttpStatusCode.InternalServerError, error.StatusCode); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_IApiResponse_with_content_on_success() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond("application/json", "\"hello\""); + + var fixture = RestService.For("http://foo", settings); + + using var apiResp = fixture.GetApiResponse(); + Assert.True(apiResp.IsSuccessStatusCode); + Assert.Null(apiResp.Error); + Assert.Equal(HttpMethod.Get, apiResp.RequestMessage.Method); + Assert.Equal("http://foo/resource", apiResp.RequestMessage.RequestUri?.ToString()); + // The string branch reads the raw stream (no JSON unwrapping), same as the async path + Assert.Equal("\"hello\"", apiResp.Content); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_method_returns_raw_IApiResponse_on_success() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond("text/plain", "hello"); + + var fixture = RestService.For("http://foo", settings); + + using var apiResp = fixture.GetRawApiResponse(); + Assert.True(apiResp.IsSuccessStatusCode); + Assert.Null(apiResp.Error); + Assert.Equal(HttpMethod.Get, apiResp.RequestMessage.Method); + Assert.Equal("http://foo/resource", apiResp.RequestMessage.RequestUri?.ToString()); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_void_method_throws_ApiException_on_error_response() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond(HttpStatusCode.BadRequest); + + var fixture = RestService.For("http://foo", settings); + + var ex = Assert.Throws(() => fixture.DoVoid()); + Assert.Equal(HttpStatusCode.BadRequest, ex.StatusCode); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public void Sync_void_method_succeeds_on_ok_response() + { + var mockHttp = new SyncCapableMockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp + .Expect(HttpMethod.Get, "http://foo/resource") + .Respond(HttpStatusCode.OK); + + var fixture = RestService.For("http://foo", settings); + + fixture.DoVoid(); // should not throw + + mockHttp.VerifyNoOutstandingExpectation(); + } } diff --git a/Refit.Tests/RequestBuilder.cs b/Refit.Tests/RequestBuilder.cs index 901043144..aeb1aad35 100644 --- a/Refit.Tests/RequestBuilder.cs +++ b/Refit.Tests/RequestBuilder.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Reflection; using System.Runtime.Serialization; +using System.Reactive.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.WebUtilities; @@ -1753,6 +1754,38 @@ public void GenericReturnTypeIsNotTaskOrObservableShouldThrow() ) ); } + + [Fact] + public void InternalSyncGenericReturnTypeSetsDeserializedTypeToReturnType() + { + var input = typeof(IInternalSyncGenericReturnTypeApi); + var fixture = new RestMethodInfoInternal( + input, + input + .GetMethods(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public) + .First(x => x.Name == nameof(IInternalSyncGenericReturnTypeApi.GetValues)) + ); + + Assert.Equal(typeof(List), fixture.ReturnType); + Assert.Equal(typeof(List), fixture.ReturnResultType); + Assert.Equal(typeof(List), fixture.DeserializedResultType); + } + + [Fact] + public void InternalSyncIApiResponseGenericReturnTypeSetsDeserializedTypeToGenericArgument() + { + var input = typeof(IInternalSyncGenericApiResponseReturnTypeApi); + var fixture = new RestMethodInfoInternal( + input, + input + .GetMethods(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public) + .First(x => x.Name == nameof(IInternalSyncGenericApiResponseReturnTypeApi.GetResponse)) + ); + + Assert.Equal(typeof(IApiResponse), fixture.ReturnType); + Assert.Equal(typeof(IApiResponse), fixture.ReturnResultType); + Assert.Equal(typeof(string), fixture.DeserializedResultType); + } } [Headers("User-Agent: RefitTestClient", "Api-Version: 1")] @@ -2154,7 +2187,25 @@ interface IAuthenticatedCancellableMethods { [Headers("Authorization: Bearer")] [Get("/foo")] - Task GetWithAuthorizationAndCancellation(CancellationToken token = default); + Task GetWithAuthorizationAndCancellation(CancellationToken token=default); + } + + interface IObservableCancellableMethods + { + [Get("/value/{value}")] + IObservable GetWithCancellation(string value, CancellationToken token = default); + } + + interface IInternalSyncGenericReturnTypeApi + { + [Get("/values")] + internal List GetValues(); + } + + interface IInternalSyncGenericApiResponseReturnTypeApi + { + [Get("/response")] + internal IApiResponse GetResponse(); } public enum FooWithEnumMember @@ -2443,6 +2494,37 @@ public void StreamResponseAsApiResponseTest() Assert.Equal(reponseContent, reader.ReadToEnd()); } + [Fact] + public void GeneratedSyncApiResponseShouldPreserveRequestMessage() + { + var fixture = new RequestBuilderImplementation(); + var restMethod = new RestMethodInfoInternal( + typeof(IDummyHttpApi), + typeof(IDummyHttpApi) + .GetMethods() + .First(x => x.Name == nameof(IDummyHttpApi.FetchSomeStringWithMetadata)) + ); + var buildGeneratedSyncFuncForMethod = typeof(RequestBuilderImplementation).GetMethod( + "BuildGeneratedSyncFuncForMethod", + BindingFlags.Instance | BindingFlags.NonPublic + ); + var factory = (Func) + buildGeneratedSyncFuncForMethod!.Invoke(fixture, [restMethod])!; + var testHttpMessageHandler = new TestHttpMessageHandler(); + + var response = (ApiResponse) + factory( + new HttpClient(testHttpMessageHandler) + { + BaseAddress = new Uri("http://api/") + }, + [42] + )!; + + Assert.Same(testHttpMessageHandler.RequestMessage, response.RequestMessage); + Assert.Equal(testHttpMessageHandler.RequestMessage.RequestUri, response.RequestMessage.RequestUri); + } + [Fact] public void StreamResponseTest() { @@ -2473,6 +2555,131 @@ public void StreamResponseTest() Assert.Equal(reponseContent, reader.ReadToEnd()); } + [Fact] + public async Task ValueTaskMethodsShouldWork() + { + var fixture = new RequestBuilderImplementation(); + var factory = fixture.BuildRestResultFuncForMethod("GetValue"); + var testHttpMessageHandler = new TestHttpMessageHandler(); + + var valueTask = (ValueTask) + factory( + new HttpClient(testHttpMessageHandler) + { + BaseAddress = new Uri("http://api/") + }, + new object[] { "value" } + )!; + + var result = await valueTask; + + Assert.Equal("test", result); + Assert.Equal( + "http://api/value", + testHttpMessageHandler.RequestMessage.RequestUri.ToString() + ); + } + + [Fact] + public async Task ValueTaskApiResponseMethodsShouldWork() + { + var fixture = new RequestBuilderImplementation(); + var factory = fixture.BuildRestResultFuncForMethod("GetValue"); + var testHttpMessageHandler = new TestHttpMessageHandler(); + + var valueTask = (ValueTask>) + factory( + new HttpClient(testHttpMessageHandler) + { + BaseAddress = new Uri("http://api/") + }, + new object[] { "value" } + )!; + + using var response = await valueTask; + + Assert.True(response.IsSuccessStatusCode); + Assert.Equal("test", response.Content); + Assert.Same(testHttpMessageHandler.RequestMessage, response.RequestMessage); + Assert.Equal( + "http://api/value", + response.RequestMessage.RequestUri.ToString() + ); + } + + [Fact] + public async Task ObservableMethodsWithCancellationTokenShouldWork() + { + var fixture = new RequestBuilderImplementation(); + var factory = fixture.BuildRestResultFuncForMethod("GetWithCancellation"); + var testHttpMessageHandler = new TestHttpMessageHandler(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + var observable = (IObservable) + factory( + new HttpClient(testHttpMessageHandler) + { + BaseAddress = new Uri("http://api/") + }, + new object[] { "value", cts.Token } + )!; + + var result = observable.Wait(); + + Assert.Equal("test", result); + Assert.Equal( + "http://api/value/value", + testHttpMessageHandler.RequestMessage.RequestUri.ToString() + ); + Assert.True(testHttpMessageHandler.CancellationToken.IsCancellationRequested); + } + + [Fact] + public void BuildRestResultFuncForMethodThrowsForInvalidPublicSyncMethodFromInjectedMetadata() + { + var fixture = new RequestBuilderImplementation(); + var interfaceHttpMethodsField = typeof(RequestBuilderImplementation).GetField( + "interfaceHttpMethods", + BindingFlags.Instance | BindingFlags.NonPublic + ); + var interfaceHttpMethods = + (Dictionary>)interfaceHttpMethodsField!.GetValue(fixture)!; + + var restMethod = (RestMethodInfoInternal)FormatterServices.GetUninitializedObject( + typeof(RestMethodInfoInternal) + ); + typeof(RestMethodInfoInternal) + .GetField("k__BackingField", BindingFlags.Instance | BindingFlags.NonPublic)! + .SetValue( + restMethod, + typeof(IInvalidReturnTypeIApiResponse).GetMethod(nameof(IInvalidReturnTypeIApiResponse.GetValue)) + ); + typeof(RestMethodInfoInternal) + .GetField("k__BackingField", BindingFlags.Instance | BindingFlags.NonPublic)! + .SetValue(restMethod, typeof(IApiResponse)); + typeof(RestMethodInfoInternal) + .GetField("k__BackingField", BindingFlags.Instance | BindingFlags.NonPublic)! + .SetValue(restMethod, typeof(IApiResponse)); + typeof(RestMethodInfoInternal) + .GetField( + "k__BackingField", + BindingFlags.Instance | BindingFlags.NonPublic + )! + .SetValue(restMethod, typeof(HttpContent)); + + interfaceHttpMethods["GetValue"] = [restMethod]; + + var exception = Assert.Throws( + () => fixture.BuildRestResultFuncForMethod("GetValue") + ); + + Assert.Contains( + "All REST Methods must return either Task or ValueTask or IObservable", + exception.Message + ); + } + [Fact] public void MethodsThatDontHaveAnHttpMethodShouldFail() { diff --git a/Refit.Tests/ResponseTests.cs b/Refit.Tests/ResponseTests.cs index 790babf13..42ab1f888 100644 --- a/Refit.Tests/ResponseTests.cs +++ b/Refit.Tests/ResponseTests.cs @@ -52,6 +52,9 @@ public interface IMyAliasService [Get("/GetIApiResponse")] Task GetIApiResponse(); + + [Get("/GetValueTaskIApiResponse")] + ValueTask GetValueTaskIApiResponse(); } [Fact] @@ -429,6 +432,28 @@ public async Task BadRequestWithStringContent_ShouldReturnIApiResponse() Assert.Equal("Hello world", error.Content); } + [Fact] + public async Task BadRequestWithStringContent_ShouldReturnValueTaskIApiResponse() + { + var expectedResponse = new HttpResponseMessage(HttpStatusCode.BadRequest) + { + Content = new StringContent("Hello world") + }; + expectedResponse.Content.Headers.Clear(); + + mockHandler + .Expect(HttpMethod.Get, $"http://api/{nameof(fixture.GetValueTaskIApiResponse)}") + .Respond(req => expectedResponse); + + var apiResponse = await fixture.GetValueTaskIApiResponse(); + + Assert.NotNull(apiResponse); + Assert.NotNull(apiResponse.Error); + Assert.True(apiResponse.HasResponseError(out var error)); + Assert.NotNull(error.Content); + Assert.Equal("Hello world", error.Content); + } + [Fact] public async Task ValidationApiException_HydratesBaseContent() { diff --git a/Refit.Tests/RestService.cs b/Refit.Tests/RestService.cs index 6af3c7bf3..21770ba9e 100644 --- a/Refit.Tests/RestService.cs +++ b/Refit.Tests/RestService.cs @@ -238,6 +238,18 @@ public interface IStreamApi Task> GetRemoteFileWithMetadata(string filename); } +public interface IValueTaskApi +{ + [Get("/{value}")] + ValueTask GetValue(string value); +} + +public interface IValueTaskApiResponseApi +{ + [Get("/{value}")] + ValueTask> GetValue(string value); +} + public interface IApiWithDecimal { [Get("/withDecimal")] @@ -356,6 +368,41 @@ public void CanCreateInstanceUsingStaticMethod() } #endif + [Fact] + public async Task ValueTaskMethodsShouldWork() + { + var mockHttp = new MockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp.Expect(HttpMethod.Get, "http://foo/value").Respond("text/plain", "test"); + + var fixture = RestService.For("http://foo", settings); + + var result = await fixture.GetValue("value"); + + Assert.Equal("test", result); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task ValueTaskApiResponseMethodsShouldWork() + { + var mockHttp = new MockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + mockHttp.Expect(HttpMethod.Get, "http://foo/value").Respond("text/plain", "test"); + + var fixture = RestService.For("http://foo", settings); + + using var response = await fixture.GetValue("value"); + + Assert.True(response.IsSuccessStatusCode); + Assert.Equal("test", response.Content); + Assert.Equal(HttpMethod.Get, response.RequestMessage.Method); + Assert.Equal("http://foo/value", response.RequestMessage.RequestUri?.ToString()); + mockHttp.VerifyNoOutstandingExpectation(); + } + [Fact] public async Task CanAddContentHeadersToPostWithoutBody() { diff --git a/Refit.Tests/RestServiceExceptions.cs b/Refit.Tests/RestServiceExceptions.cs index 3554c418b..8c38f639e 100644 --- a/Refit.Tests/RestServiceExceptions.cs +++ b/Refit.Tests/RestServiceExceptions.cs @@ -105,6 +105,12 @@ public interface IInvalidReturnType string GetValue(); } +public interface IInvalidReturnTypeIApiResponse +{ + [Get("/")] + IApiResponse GetValue(); +} + public class RestServiceExceptionTests { [Fact] @@ -221,6 +227,13 @@ public void InvalidReturnTypeShouldThrow() AssertExceptionContains("is invalid. All REST Methods must return either Task or ValueTask or IObservable", exception); } + [Fact] + public void InvalidRawApiResponseReturnTypeShouldThrow() + { + var exception = Assert.Throws(() => RestService.For("https://api.github.com")); + AssertExceptionContains("is invalid. All REST Methods must return either Task or ValueTask or IObservable", exception); + } + private static void AssertExceptionContains(string expectedSubstring, Exception exception) { Assert.Contains(expectedSubstring, exception.Message!, StringComparison.Ordinal); diff --git a/Refit.Tests/SerializedContentTests.cs b/Refit.Tests/SerializedContentTests.cs index b09f50887..afaba9639 100644 --- a/Refit.Tests/SerializedContentTests.cs +++ b/Refit.Tests/SerializedContentTests.cs @@ -520,6 +520,54 @@ public void SystemTextJsonContentSerializer_DefaultOptions_SerializeLowercaseEnu Assert.Equal("\"alreadyLowercase\"", json); } + [Theory] + [InlineData("vAlUeOnE")] + [InlineData("ValueOne")] + [InlineData("VALUEONE")] + [InlineData("valueone")] + public void SystemTextJsonContentSerializer_DefaultOptions_DeserializesEnumValuesWithVariousCasings( + string jsonValue + ) + { + var result = SystemTextJsonSerializer.Deserialize( + $"\"{jsonValue}\"", + SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions() + ); + + Assert.Equal(CamelCaseEnum.ValueOne, result); + } + + [Fact] + public void SystemTextJsonContentSerializer_DefaultOptions_ExactCaseMatchTakesPriorityOverCaseInsensitiveWhenMembersDifferByCase() + { + // When enum has members whose names differ only by case, the exact serialized form + // (camelCase) should be used first (case-sensitive), falling back to case-insensitive only + // for inputs that do not exactly match any known serialized form. + + // CaseDifferentMembers.Alpha serializes to "alpha" (camelCase), + // CaseDifferentMembers.ALPHA serializes to "aLPHA" (camelCase). + // Exact-match lookups must correctly disambiguate these. + var options = SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions(); + + Assert.Equal( + CaseDifferentMembers.Alpha, + SystemTextJsonSerializer.Deserialize("\"alpha\"", options) + ); + Assert.Equal( + CaseDifferentMembers.ALPHA, + SystemTextJsonSerializer.Deserialize("\"aLPHA\"", options) + ); + // Field names are also accepted via exact match + Assert.Equal( + CaseDifferentMembers.Alpha, + SystemTextJsonSerializer.Deserialize("\"Alpha\"", options) + ); + Assert.Equal( + CaseDifferentMembers.ALPHA, + SystemTextJsonSerializer.Deserialize("\"ALPHA\"", options) + ); + } + [Fact] public async Task SystemTextJsonContentSerializer_UsesSourceGeneratedMetadataWhenProvided() { @@ -614,6 +662,57 @@ public async Task RestService_SerializesBodyUsingDeclaredPolymorphicBaseType() Assert.Contains("\"name\":\"Photon\"", serializedBody, StringComparison.Ordinal); } +#if NET9_0_OR_GREATER + [Fact] + public async Task SystemTextJsonContentSerializer_SupportsJsonStringEnumMemberName() + { + var serializer = new SystemTextJsonContentSerializer( + SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions() + ); + + var content = serializer.ToHttpContent( + new EnumMemberNameEnvelope { Status = EnumMemberNameStatus.TotallyReady } + ); + var serialized = await content.ReadAsStringAsync(); + var roundTrip = await serializer.FromHttpContentAsync( + new StringContent("{\"status\":\"totally-ready\"}", Encoding.UTF8, "application/json") + ); + + Assert.Contains("totally-ready", serialized, StringComparison.Ordinal); + Assert.NotNull(roundTrip); + Assert.Equal(EnumMemberNameStatus.TotallyReady, roundTrip.Status); + } + + [Fact] + public async Task RestService_UsesDefaultEnumConverterWithJsonStringEnumMemberName() + { + var settings = new RefitSettings( + new SystemTextJsonContentSerializer( + SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions() + ) + ) + { + HttpMessageHandlerFactory = () => new StubHttpMessageHandler(_ => + Task.FromResult( + new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent( + "{\"status\":\"totally-ready\"}", + Encoding.UTF8, + "application/json" + ) + } + ) + ) + }; + + var api = RestService.For(BaseAddress, settings); + var result = await api.GetStatusAsync(); + + Assert.Equal(EnumMemberNameStatus.TotallyReady, result.Status); + } +#endif + [JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] [JsonDerivedType(typeof(LaserWeaponRequest), "laser")] public abstract class CreateWeaponRequest @@ -629,6 +728,27 @@ public interface IPolymorphicRequestApi Task CreateWeapon(CreateWeaponRequest request); } +#if NET9_0_OR_GREATER + public enum EnumMemberNameStatus + { + [JsonStringEnumMemberName("totally-ready")] + TotallyReady, + + NeedsReview + } + + public sealed class EnumMemberNameEnvelope + { + public EnumMemberNameStatus Status { get; set; } + } + + public interface IIssue2067StatusApi + { + [Get("/status")] + Task GetStatusAsync(); + } +#endif + [JsonSerializable(typeof(User))] internal sealed partial class SerializedContentJsonSerializerContext : JsonSerializerContext { } @@ -687,6 +807,14 @@ enum CamelCaseEnum alreadyLowercase = 2 } + // Members Alpha and ALPHA differ only by case; this enum is used to verify that + // the case-sensitive lookup takes priority and the correct member is chosen. + enum CaseDifferentMembers + { + Alpha = 1, + ALPHA = 2, + } + sealed class AsyncOnlyJsonContent(string json) : HttpContent { readonly byte[] _bytes = Encoding.UTF8.GetBytes(json); diff --git a/Refit.sln b/Refit.sln index ba2afea5c..07be27e10 100644 --- a/Refit.sln +++ b/Refit.sln @@ -50,6 +50,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ConsoleSampleUsingLocalApi" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RestApiForTest", "examples\SampleUsingLocalApi\RestApiforTest\RestApiForTest.csproj", "{23305490-94C3-7131-087F-54EDF910A7ED}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BlazorWasmIssue2065", "examples\BlazorWasmIssue2065\BlazorWasmIssue2065.csproj", "{89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -302,6 +304,22 @@ Global {23305490-94C3-7131-087F-54EDF910A7ED}.Release|x64.Build.0 = Release|Any CPU {23305490-94C3-7131-087F-54EDF910A7ED}.Release|x86.ActiveCfg = Release|Any CPU {23305490-94C3-7131-087F-54EDF910A7ED}.Release|x86.Build.0 = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|Any CPU.Build.0 = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|ARM.ActiveCfg = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|ARM.Build.0 = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|x64.ActiveCfg = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|x64.Build.0 = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|x86.ActiveCfg = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Debug|x86.Build.0 = Debug|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|Any CPU.ActiveCfg = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|Any CPU.Build.0 = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|ARM.ActiveCfg = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|ARM.Build.0 = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|x64.ActiveCfg = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|x64.Build.0 = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|x86.ActiveCfg = Release|Any CPU + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -315,6 +333,7 @@ Global {55ED7170-CB15-0A50-9C4F-C3A0188E150B} = {FA3CFAFC-3218-487F-837B-E8755672EA27} {BE9DED31-FAE3-F798-BD79-7BA96883D07A} = {FA3CFAFC-3218-487F-837B-E8755672EA27} {23305490-94C3-7131-087F-54EDF910A7ED} = {FA3CFAFC-3218-487F-837B-E8755672EA27} + {89F9BB4A-1D8C-4A9C-986F-40E5757C3D51} = {FA3CFAFC-3218-487F-837B-E8755672EA27} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {6E9C2873-AFF9-4D32-A784-1BA094814054} diff --git a/Refit/Refit.csproj b/Refit/Refit.csproj index 5c40375c7..9c1c24c21 100644 --- a/Refit/Refit.csproj +++ b/Refit/Refit.csproj @@ -24,8 +24,8 @@ - - + + diff --git a/Refit/RequestBuilderImplementation.cs b/Refit/RequestBuilderImplementation.cs index 042ddfae6..85ac1650b 100644 --- a/Refit/RequestBuilderImplementation.cs +++ b/Refit/RequestBuilderImplementation.cs @@ -247,6 +247,24 @@ RestMethodInfoInternal CloseGenericMethodIfNeeded( return (client, args) => taskFunc!.DynamicInvoke(client, args); } + // ValueTask + if (restMethod.ReturnType.GetTypeInfo().IsGenericType && restMethod.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + var valueTaskFuncMi = typeof(RequestBuilderImplementation).GetMethod( + nameof(BuildValueTaskFuncForMethod), + BindingFlags.NonPublic | BindingFlags.Instance + ); + var valueTaskFunc = (MulticastDelegate?) + ( + valueTaskFuncMi!.MakeGenericMethod( + restMethod.ReturnResultType, + restMethod.DeserializedResultType + ) + ).Invoke(this, [restMethod]); + + return (client, args) => valueTaskFunc!.DynamicInvoke(client, args); + } + // IObservable if (restMethod.ReturnType.GetTypeInfo().IsGenericType && restMethod.ReturnType.GetGenericTypeDefinition() == typeof(IObservable<>)) { @@ -265,30 +283,258 @@ RestMethodInfoInternal CloseGenericMethodIfNeeded( return (client, args) => rxFunc!.DynamicInvoke(client, args); } - // Synchronous return types: build a sync wrapper that awaits internally and returns the value + var isExplicitInterfaceMember = restMethod.MethodInfo.Name.IndexOf('.') >= 0; + var isNonPublic = !restMethod.MethodInfo.IsPublic; + if (isExplicitInterfaceMember || isNonPublic) + { + return BuildGeneratedSyncFuncForMethod(restMethod); + } + + throw new ArgumentException( + $"Method \"{restMethod.MethodInfo.Name}\" is invalid. All REST Methods must return either Task or ValueTask or IObservable" + ); + } + + Func BuildGeneratedSyncFuncForMethod( + RestMethodInfoInternal restMethod + ) + { + if (restMethod.ReturnResultType == typeof(void)) + { + return (client, paramList) => + { + RunSynchronous(() => + ExecuteVoidRequestAsync( + client, + restMethod, + paramList, + CancellationToken.None, + paramsContainsCancellationToken: false + ) + ); + return null; + }; + } + var syncFuncMi = typeof(RequestBuilderImplementation).GetMethod( - nameof(BuildSyncFuncForMethod), + nameof(BuildGeneratedSyncFuncForMethodGeneric), BindingFlags.NonPublic | BindingFlags.Instance ); - var syncFunc = (MulticastDelegate?) - ( - syncFuncMi!.MakeGenericMethod( + return (Func) + syncFuncMi + .MakeGenericMethod( restMethod.ReturnResultType, restMethod.DeserializedResultType ) - ).Invoke(this, [restMethod]); - - return (client, args) => syncFunc!.DynamicInvoke(client, args); + .Invoke(this, [restMethod]); } - private Func BuildSyncFuncForMethod(RestMethodInfoInternal restMethod) + Func BuildGeneratedSyncFuncForMethodGeneric( + RestMethodInfoInternal restMethod + ) { - var taskFunc = BuildTaskFuncForMethod(restMethod); return (client, paramList) => + RunSynchronous(() => + ExecuteRequestAsync( + client, + restMethod, + paramList, + CancellationToken.None, + paramsContainsCancellationToken: false + ) + ); + } + + static void RunSynchronous(Func taskFactory) => + Task.Run(taskFactory).GetAwaiter().GetResult(); + + static T? RunSynchronous(Func> taskFactory) => + Task.Run(taskFactory).GetAwaiter().GetResult(); + + async Task ExecuteVoidRequestAsync( + HttpClient client, + RestMethodInfoInternal restMethod, + object[] paramList, + CancellationToken cancellationToken, + bool paramsContainsCancellationToken + ) + { + if (client.BaseAddress == null) + throw new InvalidOperationException( + "BaseAddress must be set on the HttpClient instance" + ); + + using var rq = await BuildRequestMessageForMethodAsync( + restMethod, + client.BaseAddress.AbsolutePath, + paramsContainsCancellationToken, + paramList + ) + .ConfigureAwait(false); + + if (IsBodyBuffered(restMethod, rq)) { - var task = taskFunc(client, paramList); - return (object?)task.GetAwaiter().GetResult(); - }; + await rq.Content!.LoadIntoBufferAsync().ConfigureAwait(false); + } + + using var resp = await client + .SendAsync(rq, cancellationToken) + .ConfigureAwait(false); + + var exception = await settings.ExceptionFactory(resp).ConfigureAwait(false); + if (exception != null) + { + throw exception; + } + } + + async Task ExecuteRequestAsync( + HttpClient client, + RestMethodInfoInternal restMethod, + object[] paramList, + CancellationToken cancellationToken, + bool paramsContainsCancellationToken + ) + { + if (client.BaseAddress == null) + throw new InvalidOperationException( + "BaseAddress must be set on the HttpClient instance" + ); + + using var rq = await BuildRequestMessageForMethodAsync( + restMethod, + client.BaseAddress.AbsolutePath, + paramsContainsCancellationToken, + paramList + ) + .ConfigureAwait(false); + + HttpResponseMessage? resp = null; + HttpContent? content = null; + var disposeResponse = true; + try + { + if (IsBodyBuffered(restMethod, rq)) + { + await rq.Content!.LoadIntoBufferAsync().ConfigureAwait(false); + } + + try + { + resp = await client + .SendAsync(rq, HttpCompletionOption.ResponseHeadersRead, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception ex) + { + if (!restMethod.IsApiResponse) + throw new ApiRequestException(rq, rq.Method, settings, ex); + + return ApiResponse.Create( + rq, + resp, + default, + settings, + new ApiRequestException(rq, rq.Method, settings, ex) + ); + } + + content = resp.Content ?? new StringContent(string.Empty); + Exception? e = null; + disposeResponse = restMethod.ShouldDisposeResponse; + + if (typeof(T) != typeof(HttpResponseMessage)) + { + e = await settings.ExceptionFactory(resp).ConfigureAwait(false); + } + + if (restMethod.IsApiResponse) + { + var body = default(TBody); + + try + { + body = + e == null + ? await DeserializeContentAsync( + resp, + content, + cancellationToken + ).ConfigureAwait(false) + : default; + } + catch (Exception ex) + { + if (settings.DeserializationExceptionFactory != null) + e = await settings + .DeserializationExceptionFactory(resp, ex) + .ConfigureAwait(false); + else + { + e = await ApiException.Create( + "An error occured deserializing the response.", + resp.RequestMessage!, + resp.RequestMessage!.Method, + resp, + settings, + ex + ).ConfigureAwait(false); + } + } + + return ApiResponse.Create( + rq, + resp, + body, + settings, + e as ApiException + ); + } + else if (e != null) + { + disposeResponse = false; // caller has to dispose + throw e; + } + else + { + try + { + return await DeserializeContentAsync(resp, content, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception ex) + { + if (settings.DeserializationExceptionFactory != null) + { + var customEx = await settings + .DeserializationExceptionFactory(resp, ex) + .ConfigureAwait(false); + if (customEx != null) + throw customEx; + return default; + } + else + { + throw await ApiException.Create( + "An error occured deserializing the response.", + resp.RequestMessage!, + resp.RequestMessage!.Method, + resp, + settings, + ex + ).ConfigureAwait(false); + } + } + } + } + finally + { + if (disposeResponse) + { + resp?.Dispose(); + content?.Dispose(); + } + } } void AddMultipartItem( @@ -724,28 +970,43 @@ static bool ShouldIgnorePropertyInQueryMap(PropertyInfo propertyInfo) return false; } - Func BuildRequestFactoryForMethod( + Func BuildRequestFactoryForMethod( RestMethodInfoInternal restMethod, string basePath, bool paramsContainsCancellationToken ) { return paramList => - { - var cancellationToken = CancellationToken.None; + RunSynchronous(() => + BuildRequestMessageForMethodAsync( + restMethod, + basePath, + paramsContainsCancellationToken, + paramList + ) + ); + } - // make sure we strip out any cancellation tokens - if (paramsContainsCancellationToken) - { - cancellationToken = paramList.OfType().FirstOrDefault(); - paramList = paramList - .Where(o => o == null || o.GetType() != typeof(CancellationToken)) - .ToArray(); - } + async Task BuildRequestMessageForMethodAsync( + RestMethodInfoInternal restMethod, + string basePath, + bool paramsContainsCancellationToken, + object[] paramList + ) + { + var cancellationToken = CancellationToken.None; - var ret = new HttpRequestMessage { Method = restMethod.HttpMethod }; + if (paramsContainsCancellationToken) + { + cancellationToken = paramList.OfType().FirstOrDefault(); + paramList = paramList + .Where(o => o == null || o.GetType() != typeof(CancellationToken)) + .ToArray(); + } - // set up multipart content + var ret = new HttpRequestMessage { Method = restMethod.HttpMethod }; + try + { MultipartFormDataContent? multiPartContent = null; if (restMethod.IsMultipart) { @@ -754,8 +1015,8 @@ bool paramsContainsCancellationToken } List>? queryParamsToAdd = null; - var headersToAdd = restMethod.Headers.Count > 0 ? - new Dictionary(restMethod.Headers) + var headersToAdd = restMethod.Headers.Count > 0 + ? new Dictionary(restMethod.Headers) : null; RestMethodParameterInfo? parameterInfo = null; @@ -764,19 +1025,15 @@ bool paramsContainsCancellationToken { var isParameterMappedToRequest = false; var param = paramList[i]; - // if part of REST resource URL, substitute it in if (restMethod.ParameterMap.TryGetValue(i, out var parameterMapValue)) { parameterInfo = parameterMapValue; if (!parameterInfo.IsObjectPropertyParameter) { - // mark parameter mapped if not an object - // we want objects to fall through so any parameters on this object not bound here get passed as query parameters isParameterMappedToRequest = true; } } - // if marked as body, add to content if ( restMethod.BodyParameterInfo != null && restMethod.BodyParameterInfo.Item3 == i @@ -786,7 +1043,6 @@ bool paramsContainsCancellationToken isParameterMappedToRequest = true; } - // if header, add to request headers if (restMethod.HeaderParameterMap.TryGetValue(i, out var headerParameterValue)) { headersToAdd ??= []; @@ -794,7 +1050,6 @@ bool paramsContainsCancellationToken isParameterMappedToRequest = true; } - //if header collection, add to request headers if (restMethod.HeaderCollectionAt(i)) { if (param is IDictionary headerCollection) @@ -809,7 +1064,6 @@ bool paramsContainsCancellationToken isParameterMappedToRequest = true; } - //if authorize, add to request headers with scheme if ( restMethod.AuthorizeParameterInfo != null && restMethod.AuthorizeParameterInfo.Item2 == i @@ -821,19 +1075,14 @@ bool paramsContainsCancellationToken isParameterMappedToRequest = true; } - //if property, add to populate into HttpRequestMessage.Properties if (restMethod.PropertyParameterMap.ContainsKey(i)) { isParameterMappedToRequest = true; } - // ignore nulls and already processed parameters if (isParameterMappedToRequest || param == null) continue; - // for anything that fell through to here, if this is not a multipart method add the parameter to the query string - // or if is an object bound to the path add any non-path bound properties to query string - // or if it's an object with a query attribute var queryAttribute = restMethod .ParameterInfoArray[i] .GetCustomAttribute(); @@ -845,7 +1094,14 @@ bool paramsContainsCancellationToken ) { queryParamsToAdd ??= []; - AddQueryParameters(restMethod, queryAttribute, param, queryParamsToAdd, i, parameterInfo); + AddQueryParameters( + restMethod, + queryAttribute, + param, + queryParamsToAdd, + i, + parameterInfo + ); continue; } @@ -853,17 +1109,13 @@ bool paramsContainsCancellationToken } AddHeadersToRequest(headersToAdd, ret); - AddAuthorizationHeadersFromGetterAsync(ret, cancellationToken) - .GetAwaiter() - .GetResult(); + await AddAuthorizationHeadersFromGetterAsync(ret, cancellationToken) + .ConfigureAwait(false); AddPropertiesToRequest(restMethod, ret, paramList); #if NET6_0_OR_GREATER AddVersionToRequest(ret); #endif - // NB: The URI methods in .NET are dumb. Also, we do this - // UriBuilder business so that we preserve any hardcoded query - // parameters as well as add the parameterized ones. var urlTarget = BuildRelativePath(basePath, restMethod, paramList); var uri = new UriBuilder(new Uri(BaseUri, urlTarget)); @@ -883,7 +1135,12 @@ bool paramsContainsCancellationToken UriKind.Relative ); return ret; - }; + } + catch + { + ret.Dispose(); + throw; + } } string BuildRelativePath(string basePath, RestMethodInfoInternal restMethod, object[] paramList) @@ -1430,24 +1687,21 @@ RestMethodInfoInternal restMethod }; } - Func BuildVoidTaskFuncForMethod( + Func> BuildValueTaskFuncForMethod( RestMethodInfoInternal restMethod ) { - return async (client, paramList) => - { - if (client.BaseAddress == null) - throw new InvalidOperationException( - "BaseAddress must be set on the HttpClient instance" - ); + var ret = BuildTaskFuncForMethod(restMethod); - var factory = BuildRequestFactoryForMethod( - restMethod, - client.BaseAddress.AbsolutePath, - restMethod.CancellationToken != null - ); - var rq = factory(paramList); + return (client, paramList) => new ValueTask(ret(client, paramList)); + } + Func BuildVoidTaskFuncForMethod( + RestMethodInfoInternal restMethod + ) + { + return (client, paramList) => + { var ct = CancellationToken.None; if (restMethod.CancellationToken != null) @@ -1455,18 +1709,13 @@ RestMethodInfoInternal restMethod ct = paramList.OfType().FirstOrDefault(); } - // Load the data into buffer when body should be buffered. - if (IsBodyBuffered(restMethod, rq)) - { - await rq.Content!.LoadIntoBufferAsync().ConfigureAwait(false); - } - using var resp = await client.SendAsync(rq, ct).ConfigureAwait(false); - - var exception = await settings.ExceptionFactory(resp).ConfigureAwait(false); - if (exception != null) - { - throw exception; - } + return ExecuteVoidRequestAsync( + client, + restMethod, + paramList, + ct, + restMethod.CancellationToken != null + ); }; } diff --git a/Refit/RestMethodInfo.cs b/Refit/RestMethodInfo.cs index 3c502e808..c2e2a1d88 100644 --- a/Refit/RestMethodInfo.cs +++ b/Refit/RestMethodInfo.cs @@ -647,23 +647,7 @@ void DetermineReturnTypeInfo(MethodInfo methodInfo) { ReturnType = returnType; ReturnResultType = returnType.GetGenericArguments()[0]; - - if ( - ReturnResultType.IsGenericType - && ( - ReturnResultType.GetGenericTypeDefinition() == typeof(ApiResponse<>) - || ReturnResultType.GetGenericTypeDefinition() == typeof(IApiResponse<>) - ) - ) - { - DeserializedResultType = ReturnResultType.GetGenericArguments()[0]; - } - else if (ReturnResultType == typeof(IApiResponse)) - { - DeserializedResultType = typeof(HttpContent); - } - else - DeserializedResultType = ReturnResultType; + DeserializedResultType = DetermineDeserializedResultType(ReturnResultType); } else if (returnType == typeof(Task)) { @@ -673,11 +657,13 @@ void DetermineReturnTypeInfo(MethodInfo methodInfo) } else { - // Allow synchronous return types only for non-public or explicit interface members. - // This supports internal Refit methods and explicit interface members annotated with Refit attributes. + // Allow synchronous return types only for methods that are implemented by generated stubs + // (for example explicit/default interface implementations). Public top-level Refit methods must + // still use async-compatible return shapes. var isExplicitInterfaceMember = methodInfo.Name.IndexOf('.') >= 0; - var isNonPublic = !(methodInfo.IsPublic); - if (!(isExplicitInterfaceMember || isNonPublic)) + var isNonPublic = !methodInfo.IsPublic; + + if (!isExplicitInterfaceMember && !isNonPublic) { throw new ArgumentException( $"Method \"{methodInfo.Name}\" is invalid. All REST Methods must return either Task or ValueTask or IObservable" @@ -686,12 +672,28 @@ void DetermineReturnTypeInfo(MethodInfo methodInfo) ReturnType = methodInfo.ReturnType; ReturnResultType = methodInfo.ReturnType; - DeserializedResultType = methodInfo.ReturnType == typeof(IApiResponse) - ? typeof(HttpContent) - : methodInfo.ReturnType; + DeserializedResultType = DetermineDeserializedResultType(ReturnResultType); } } + static Type DetermineDeserializedResultType(Type returnResultType) + { + if ( + returnResultType.IsGenericType + && ( + returnResultType.GetGenericTypeDefinition() == typeof(ApiResponse<>) + || returnResultType.GetGenericTypeDefinition() == typeof(IApiResponse<>) + ) + ) + { + return returnResultType.GetGenericArguments()[0]; + } + + return returnResultType == typeof(IApiResponse) + ? typeof(HttpContent) + : returnResultType; + } + void DetermineIfResponseMustBeDisposed() { // Rest method caller will have to dispose if it's one of those 3 diff --git a/Refit/SystemTextJsonContentSerializer.cs b/Refit/SystemTextJsonContentSerializer.cs index e0dd9e4fb..c62516c22 100644 --- a/Refit/SystemTextJsonContentSerializer.cs +++ b/Refit/SystemTextJsonContentSerializer.cs @@ -176,6 +176,16 @@ public override JsonConverter CreateConverter(Type typeToConvert, JsonSerializer sealed class NonGenericEnumConverter(Type targetType, Type enumType, bool isNullable) : JsonConverter { + readonly Dictionary namesToValues = GetNamesToValues( + enumType, + StringComparer.Ordinal + ); + readonly Dictionary namesToValuesIgnoreCase = GetNamesToValues( + enumType, + StringComparer.OrdinalIgnoreCase + ); + readonly Dictionary valuesToNames = GetValuesToNames(enumType); + public override bool CanConvert(Type typeToConvert) => typeToConvert == targetType; public override object? Read( @@ -203,20 +213,13 @@ JsonSerializerOptions options throw new JsonException($"Cannot convert an empty value to {targetType}."); } - foreach (var name in Enum.GetNames(enumType)) - { - if (string.Equals(ToCamelCase(name), value, StringComparison.Ordinal)) - return Enum.Parse(enumType, name, ignoreCase: false); - } + if (namesToValues.TryGetValue(value, out var namedValue)) + return namedValue; - try - { - return Enum.Parse(enumType, value, ignoreCase: true); - } - catch (ArgumentException) - { - throw new JsonException($"Unable to convert '{value}' to {targetType}."); - } + if (namesToValuesIgnoreCase.TryGetValue(value, out var namedValueIgnoreCase)) + return namedValueIgnoreCase; + + throw new JsonException($"Unable to convert '{value}' to {targetType}."); } if (reader.TokenType == JsonTokenType.Number) @@ -240,14 +243,64 @@ JsonSerializerOptions options return; } - var name = Enum.GetName(enumType, value); - if (name is null) + if (!valuesToNames.TryGetValue(value, out var name)) { writer.WriteNumberValue(Convert.ToInt64(value)); return; } - writer.WriteStringValue(ToCamelCase(name)); + writer.WriteStringValue(name); + } + + static Dictionary GetNamesToValues( + Type enumType, + StringComparer comparer + ) + { + var map = new Dictionary(comparer); + + foreach (var field in enumType.GetFields(BindingFlags.Public | BindingFlags.Static)) + { + var value = Enum.Parse(enumType, field.Name, ignoreCase: false); + foreach (var name in GetSerializedNames(field)) + { + map[name] = value; + } + } + + return map; + } + + static Dictionary GetValuesToNames(Type enumType) + { + var map = new Dictionary(); + + foreach (var field in enumType.GetFields(BindingFlags.Public | BindingFlags.Static)) + { + var value = Enum.Parse(enumType, field.Name, ignoreCase: false); + map[value] = GetPreferredSerializedName(field); + } + + return map; + } + + static IEnumerable GetSerializedNames(FieldInfo field) + { + var preferredName = GetPreferredSerializedName(field); + yield return preferredName; + + if (!string.Equals(field.Name, preferredName, StringComparison.Ordinal)) + yield return field.Name; + } + + static string GetPreferredSerializedName(FieldInfo field) + { +#if NET9_0_OR_GREATER + var enumMemberNameAttribute = field.GetCustomAttribute(); + if (enumMemberNameAttribute is not null) + return enumMemberNameAttribute.Name; +#endif + return ToCamelCase(field.Name); } static string ToCamelCase(string value) => diff --git a/examples/BlazorWasmIssue2065/App.razor b/examples/BlazorWasmIssue2065/App.razor new file mode 100644 index 000000000..14c35973b --- /dev/null +++ b/examples/BlazorWasmIssue2065/App.razor @@ -0,0 +1,13 @@ +@using BlazorWasmIssue2065.Shared + + + + + + + + +

Not found

+
+
+
diff --git a/examples/BlazorWasmIssue2065/BlazorWasmIssue2065.csproj b/examples/BlazorWasmIssue2065/BlazorWasmIssue2065.csproj new file mode 100644 index 000000000..035301dd6 --- /dev/null +++ b/examples/BlazorWasmIssue2065/BlazorWasmIssue2065.csproj @@ -0,0 +1,25 @@ + + + + net10.0 + enable + enable + true + Default + false + + + + + + + + + + + + + diff --git a/examples/BlazorWasmIssue2065/IIssue2065Api.cs b/examples/BlazorWasmIssue2065/IIssue2065Api.cs new file mode 100644 index 000000000..3b21d057e --- /dev/null +++ b/examples/BlazorWasmIssue2065/IIssue2065Api.cs @@ -0,0 +1,9 @@ +using Refit; + +namespace BlazorWasmIssue2065; + +public interface IIssue2065Api +{ + [Get("/sample-data/weather.json")] + Task GetPayload(); +} diff --git a/examples/BlazorWasmIssue2065/IIssue2067Api.cs b/examples/BlazorWasmIssue2065/IIssue2067Api.cs new file mode 100644 index 000000000..2efd90305 --- /dev/null +++ b/examples/BlazorWasmIssue2065/IIssue2067Api.cs @@ -0,0 +1,22 @@ +using Refit; + +namespace BlazorWasmIssue2065; + +internal interface IIssue2067Api +{ + [Get("/sample-data/status.json")] + Task GetStatusAsync(); +} + +internal sealed class Issue2067Response +{ + public Issue2067Status Status { get; set; } +} + +internal enum Issue2067Status +{ + [System.Text.Json.Serialization.JsonStringEnumMemberName("totally-ready")] + TotallyReady, + + NeedsReview +} diff --git a/examples/BlazorWasmIssue2065/Pages/Index.razor b/examples/BlazorWasmIssue2065/Pages/Index.razor new file mode 100644 index 000000000..98b18ff7a --- /dev/null +++ b/examples/BlazorWasmIssue2065/Pages/Index.razor @@ -0,0 +1,31 @@ +@page "/" +@inject IIssue2065Api Api + +Issue 2065 + +

Issue 2065

+ +

This sample proves Refit works inside Blazor WebAssembly without blocking sync waits.

+ + + +

@status

+ +@code { + private string status = "Ready"; + + private async Task CallApi() + { + status = "Calling..."; + + try + { + var payload = await Api.GetPayload(); + status = $"Success: {payload}"; + } + catch (Exception ex) + { + status = $"Failure: {ex.GetType().Name} - {ex.Message}"; + } + } +} diff --git a/examples/BlazorWasmIssue2065/Pages/Issue2067.razor b/examples/BlazorWasmIssue2065/Pages/Issue2067.razor new file mode 100644 index 000000000..589b74f8c --- /dev/null +++ b/examples/BlazorWasmIssue2065/Pages/Issue2067.razor @@ -0,0 +1,31 @@ +@page "/issue2067" +@inject IIssue2067Api Api + +Issue 2067 + +

Issue 2067

+ +

This sample proves Refit's default System.Text.Json enum converter supports JsonStringEnumMemberName.

+ + + +

@status

+ +@code { + private string status = "Ready"; + + private async Task CallApi() + { + status = "Calling..."; + + try + { + var payload = await Api.GetStatusAsync(); + status = $"Success: {payload.Status}"; + } + catch (Exception ex) + { + status = $"Failure: {ex.GetType().Name} - {ex.Message}"; + } + } +} diff --git a/examples/BlazorWasmIssue2065/Program.cs b/examples/BlazorWasmIssue2065/Program.cs new file mode 100644 index 000000000..89c3260dc --- /dev/null +++ b/examples/BlazorWasmIssue2065/Program.cs @@ -0,0 +1,29 @@ +using BlazorWasmIssue2065; +using Microsoft.AspNetCore.Components.Web; +using Microsoft.AspNetCore.Components.WebAssembly.Hosting; +using Refit; + +var builder = WebAssemblyHostBuilder.CreateDefault(args); +builder.RootComponents.Add("#app"); +builder.RootComponents.Add("head::after"); + +builder.Services.AddScoped(_ => new HttpClient +{ + BaseAddress = new Uri(builder.HostEnvironment.BaseAddress) +}); + +builder.Services.AddScoped(sp => + RestService.For(sp.GetRequiredService()) +); +builder.Services.AddScoped(sp => + RestService.For( + sp.GetRequiredService(), + new RefitSettings( + new SystemTextJsonContentSerializer( + SystemTextJsonContentSerializer.GetDefaultJsonSerializerOptions() + ) + ) + ) +); + +await builder.Build().RunAsync(); diff --git a/examples/BlazorWasmIssue2065/Shared/MainLayout.razor b/examples/BlazorWasmIssue2065/Shared/MainLayout.razor new file mode 100644 index 000000000..dcac09d8e --- /dev/null +++ b/examples/BlazorWasmIssue2065/Shared/MainLayout.razor @@ -0,0 +1,7 @@ +@inherits LayoutComponentBase + +
+
+ @Body +
+
diff --git a/examples/BlazorWasmIssue2065/Shared/_Imports.razor b/examples/BlazorWasmIssue2065/Shared/_Imports.razor new file mode 100644 index 000000000..d2616004e --- /dev/null +++ b/examples/BlazorWasmIssue2065/Shared/_Imports.razor @@ -0,0 +1 @@ +@namespace BlazorWasmIssue2065.Shared diff --git a/examples/BlazorWasmIssue2065/_Imports.razor b/examples/BlazorWasmIssue2065/_Imports.razor new file mode 100644 index 000000000..e2dff0fd7 --- /dev/null +++ b/examples/BlazorWasmIssue2065/_Imports.razor @@ -0,0 +1,9 @@ +@using System.Net.Http +@using System.Net.Http.Json +@using Microsoft.AspNetCore.Components.Forms +@using Microsoft.AspNetCore.Components.Routing +@using Microsoft.AspNetCore.Components.Web +@using Microsoft.AspNetCore.Components.Web.Virtualization +@using Microsoft.AspNetCore.Components.WebAssembly.Http +@using Microsoft.JSInterop +@using BlazorWasmIssue2065 diff --git a/examples/BlazorWasmIssue2065/wwwroot/index.html b/examples/BlazorWasmIssue2065/wwwroot/index.html new file mode 100644 index 000000000..d13941ef6 --- /dev/null +++ b/examples/BlazorWasmIssue2065/wwwroot/index.html @@ -0,0 +1,13 @@ + + + + + + BlazorWasmIssue2065 + + + +
Loading...
+ + + diff --git a/examples/BlazorWasmIssue2065/wwwroot/sample-data/status.json b/examples/BlazorWasmIssue2065/wwwroot/sample-data/status.json new file mode 100644 index 000000000..ce6aa6f4a --- /dev/null +++ b/examples/BlazorWasmIssue2065/wwwroot/sample-data/status.json @@ -0,0 +1,3 @@ +{ + "status": "totally-ready" +} diff --git a/examples/BlazorWasmIssue2065/wwwroot/sample-data/weather.json b/examples/BlazorWasmIssue2065/wwwroot/sample-data/weather.json new file mode 100644 index 000000000..90ef3cff0 --- /dev/null +++ b/examples/BlazorWasmIssue2065/wwwroot/sample-data/weather.json @@ -0,0 +1 @@ +"Blazor WASM Refit call completed" diff --git a/examples/Meow.Common/Meow.Common.csproj b/examples/Meow.Common/Meow.Common.csproj index 2728a84b7..5b52a9bf6 100644 --- a/examples/Meow.Common/Meow.Common.csproj +++ b/examples/Meow.Common/Meow.Common.csproj @@ -1,10 +1,11 @@ - + net8.0 enable enable false + $(NoWarn);CS1591;IDE1006;CA1819;CS8618;CA1707;CA1056 diff --git a/examples/Meow/Meow.csproj b/examples/Meow/Meow.csproj index b428fe9b2..4ee710237 100644 --- a/examples/Meow/Meow.csproj +++ b/examples/Meow/Meow.csproj @@ -6,6 +6,7 @@ enable enable false + $(NoWarn);CS1591