diff --git a/lib/stripe/api_requestor.rb b/lib/stripe/api_requestor.rb index c4d8a93f6..78c927b1a 100644 --- a/lib/stripe/api_requestor.rb +++ b/lib/stripe/api_requestor.rb @@ -195,23 +195,29 @@ def request def execute_request(method, path, base_address, params: {}, opts: {}, usage: []) - params = params.to_h if params.is_a?(RequestParams) - http_resp, req_opts = execute_request_internal( - method, path, base_address, params, opts, usage - ) - req_opts = RequestOptions.extract_opts_from_hash(req_opts) + old_requestor = self.class.current_thread_context.active_requestor + self.class.current_thread_context.active_requestor = self + begin + params = params.to_h if params.is_a?(RequestParams) + http_resp, req_opts = execute_request_internal( + method, path, base_address, params, opts, usage + ) + req_opts = RequestOptions.extract_opts_from_hash(req_opts) - resp = interpret_response(http_resp) + resp = interpret_response(http_resp) - # If being called from `APIRequestor#request`, put the last response in - # thread-local memory so that it can be returned to the user. Don't store - # anything otherwise so that we don't leak memory. - store_last_response(object_id, resp) + # If being called from `APIRequestor#request`, put the last response in + # thread-local memory so that it can be returned to the user. Don't store + # anything otherwise so that we don't leak memory. + store_last_response(object_id, resp) - api_mode = Util.get_api_mode(path) - Util.convert_to_stripe_object_with_params(resp.data, params, RequestOptions.persistable(req_opts), resp, - api_mode: api_mode, requestor: self, - v2_deleted_object: method == :delete && api_mode == :v2) + api_mode = Util.get_api_mode(path) + Util.convert_to_stripe_object_with_params(resp.data, params, RequestOptions.persistable(req_opts), resp, + api_mode: api_mode, requestor: self, + v2_deleted_object: method == :delete && api_mode == :v2) + ensure + self.class.current_thread_context.active_requestor = old_requestor + end end # Execute request without instantiating a new object if the relevant object's name matches the class @@ -274,17 +280,23 @@ def execute_request_stream(method, path, "execute_request_stream requires a read_body_chunk_block" end - params = params.to_h if params.is_a?(RequestParams) - http_resp, api_key = execute_request_internal( - method, path, base_address, params, opts, usage, &read_body_chunk_block - ) + old_requestor = self.class.current_thread_context.active_requestor + self.class.current_thread_context.active_requestor = self + begin + params = params.to_h if params.is_a?(RequestParams) + http_resp, api_key = execute_request_internal( + method, path, base_address, params, opts, usage, &read_body_chunk_block + ) - # When the read_body_chunk_block is given, we no longer have access to the - # response body at this point and so return a response object containing - # only the headers. This is because the body was consumed by the block. - resp = StripeHeadersOnlyResponse.from_net_http(http_resp) + # When the read_body_chunk_block is given, we no longer have access to the + # response body at this point and so return a response object containing + # only the headers. This is because the body was consumed by the block. + resp = StripeHeadersOnlyResponse.from_net_http(http_resp) - [resp, api_key] + [resp, api_key] + ensure + self.class.current_thread_context.active_requestor = old_requestor + end end def store_last_response(object_id, resp) @@ -297,6 +309,19 @@ def last_response_has_key?(object_id) self.class.current_thread_context.last_responses&.key?(object_id) end + # Subclass override point for custom HTTP transports. + # Must return an object compatible with Net::HTTPResponse + # (.code, .body, .to_hash, .[], .read_body). + protected def send_request(method, url, headers, body, query, &response_block) + APIRequestor + .default_connection_manager(config) + .execute_request(method, url, + body: body, + headers: headers, + query: query, + &response_block) + end + # # private # @@ -524,13 +549,7 @@ def self.maybe_gc_connection_managers http_resp = execute_request_with_rescues(base_url, headers, api_mode, usage, context) do - self.class - .default_connection_manager(config) - .execute_request(method, url, - body: body, - headers: headers, - query: query, - &response_block) + send_request(method, url, headers, body, query, &response_block) end [http_resp, opts] diff --git a/lib/stripe/stripe_client.rb b/lib/stripe/stripe_client.rb index fb6a89be5..0d0ecba03 100644 --- a/lib/stripe/stripe_client.rb +++ b/lib/stripe/stripe_client.rb @@ -24,7 +24,8 @@ def initialize(api_key, uploads_base: nil, connect_base: nil, meter_events_base: nil, - client_id: nil) + client_id: nil, + requestor: nil) unless api_key raise AuthenticationError, "No API key provided. " \ 'Set your API key using "client = Stripe::StripeClient.new()". ' \ @@ -46,7 +47,17 @@ def initialize(api_key, }.compact config = StripeConfiguration.client_init(config_opts) - @requestor = APIRequestor.new(config) + @requestor = if requestor + instance = requestor.call(config) + unless instance.is_a?(APIRequestor) + raise ArgumentError, + "requestor callable must return an APIRequestor instance, " \ + "got #{instance.class}" + end + instance + else + APIRequestor.new(config) + end # top-level services: The beginning of the section generated from our OpenAPI spec @v1 = Stripe::V1Services.new(@requestor) diff --git a/rbi/stripe/stripe_client.rbi b/rbi/stripe/stripe_client.rbi index 2dc68fa85..cf6ae88a5 100644 --- a/rbi/stripe/stripe_client.rbi +++ b/rbi/stripe/stripe_client.rbi @@ -3,6 +3,32 @@ module Stripe class StripeClient + sig do + params( + api_key: String, + stripe_account: T.nilable(String), + stripe_context: T.nilable(String), + stripe_version: T.nilable(String), + api_base: T.nilable(String), + uploads_base: T.nilable(String), + connect_base: T.nilable(String), + meter_events_base: T.nilable(String), + client_id: T.nilable(String), + requestor: T.nilable(T.proc.params(config: Stripe::StripeConfiguration).returns(Stripe::APIRequestor)) + ).void + end + def initialize( # rubocop:disable Metrics/ParameterLists + api_key, + stripe_account: nil, + stripe_context: nil, + stripe_version: nil, + api_base: nil, + uploads_base: nil, + connect_base: nil, + meter_events_base: nil, + client_id: nil, + requestor: nil + ); end sig do params( payload: String, diff --git a/test/stripe/api_requestor_test.rb b/test/stripe/api_requestor_test.rb index 3d0620081..048430d05 100644 --- a/test/stripe/api_requestor_test.rb +++ b/test/stripe/api_requestor_test.rb @@ -35,6 +35,91 @@ class RequestorTest < Test::Unit::TestCase assert_equal client, APIRequestor.active_requestor end end + + should "be set to self during execute_request" do + stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path") + .to_return(body: JSON.generate(object: "account")) + + observed_requestor = nil + custom_class = Class.new(APIRequestor) do + define_method(:send_request) do |method, url, headers, body, query, &block| + observed_requestor = self.class.active_requestor + super(method, url, headers, body, query, &block) + end + end + + requestor = custom_class.new("sk_test_123") + requestor.execute_request(:post, "/v1/path", :api) + + assert_equal requestor, observed_requestor + end + + should "be restored after execute_request completes" do + stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path") + .to_return(body: JSON.generate(object: "account")) + + before = APIRequestor.active_requestor + requestor = APIRequestor.new("sk_test_123") + requestor.execute_request(:post, "/v1/path", :api) + assert_equal before, APIRequestor.active_requestor + end + + should "be restored after execute_request raises" do + before = APIRequestor.active_requestor + requestor = APIRequestor.new("sk_test_123") + + assert_raises(Stripe::APIConnectionError) do + stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path") + .to_raise(Errno::ECONNREFUSED) + requestor.execute_request(:post, "/v1/path", :api) + end + + assert_equal before, APIRequestor.active_requestor + end + end + + context "#send_request" do + should "be called during normal request flow" do + stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path") + .to_return(body: JSON.generate(object: "account")) + + called = false + custom_class = Class.new(APIRequestor) do + define_method(:send_request) do |method, url, headers, body, query, &block| + called = true + super(method, url, headers, body, query, &block) + end + end + + requestor = custom_class.new("sk_test_123") + requestor.execute_request(:post, "/v1/path", :api) + + assert called + end + + should "allow subclass to override transport" do + mock_response = Struct.new(:code, :body, keyword_init: true) do + def to_hash + { "content-type" => ["application/json"] } + end + + def [](name) + to_hash[name.downcase]&.first + end + end + + response = mock_response.new(code: "200", body: JSON.generate(object: "account", id: "acct_mock")) + custom_class = Class.new(APIRequestor) do + define_method(:send_request) do |_method, _url, _headers, _body, _query, &_block| + response + end + end + + requestor = custom_class.new("sk_test_123") + result = requestor.execute_request(:get, "/v1/accounts/acct_mock", :api) + + assert_equal "acct_mock", result.id + end end context ".maybe_gc_connection_managers" do diff --git a/test/stripe/stripe_client_test.rb b/test/stripe/stripe_client_test.rb index 121f7fe68..b930b91a8 100644 --- a/test/stripe/stripe_client_test.rb +++ b/test/stripe/stripe_client_test.rb @@ -13,6 +13,73 @@ class StripeClientTest < Test::Unit::TestCase assert client.instance_variable_get(:@requestor).is_a?(APIRequestor) assert client.instance_variable_get(:@requestor).config.api_key == "sk_test_123" end + + should "use default APIRequestor when no requestor given" do + client = StripeClient.new("sk_test_123") + assert_equal APIRequestor, client.instance_variable_get(:@requestor).class + end + + should "accept a requestor factory callable" do + custom_class = Class.new(APIRequestor) + factory = ->(config) { custom_class.new(config) } + client = StripeClient.new("sk_test_123", requestor: factory) + assert_instance_of custom_class, client.instance_variable_get(:@requestor) + end + + should "pass config to requestor factory callable" do + received_config = nil + factory = lambda do |config| + received_config = config + APIRequestor.new(config) + end + StripeClient.new("sk_test_123", stripe_account: "acct_abc", requestor: factory) + assert_equal "sk_test_123", received_config.api_key + assert_equal "acct_abc", received_config.stripe_account + end + + should "reject requestor that does not return APIRequestor" do + factory = ->(_config) { Object.new } + assert_raises(ArgumentError) do + StripeClient.new("sk_test_123", requestor: factory) + end + end + end + + context "custom requestor" do + setup do + @send_request_calls = [] + calls = @send_request_calls + @custom_requestor_class = Class.new(APIRequestor) do + define_method(:send_request) do |method, url, headers, body, query, &response_block| + calls << { method: method, url: url } + super(method, url, headers, body, query, &response_block) + end + end + end + + should "route v1 service calls through custom transport" do + stub_request(:get, "#{Stripe::DEFAULT_API_BASE}/v1/customers") + .to_return(body: JSON.generate(object: "list", data: [])) + + factory = ->(config) { @custom_requestor_class.new(config) } + client = StripeClient.new("sk_test_123", requestor: factory) + client.v1.customers.list + + assert_equal 1, @send_request_calls.length + assert_equal :get, @send_request_calls[0][:method] + assert_includes @send_request_calls[0][:url], "/v1/customers" + end + + should "route raw_request through custom transport" do + stub_request(:get, "#{Stripe::DEFAULT_API_BASE}/v1/customers") + .to_return(body: JSON.generate(object: "list", data: [])) + + factory = ->(config) { @custom_requestor_class.new(config) } + client = StripeClient.new("sk_test_123", requestor: factory) + client.raw_request(:get, "/v1/customers") + + assert_equal 1, @send_request_calls.length + end end context "StripeClient config" do