Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 49 additions & 30 deletions lib/stripe/api_requestor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -195,23 +195,29 @@ def request

def execute_request(method, path, base_address,
params: {}, opts: {}, usage: [])
params = params.to_h if params.is_a?(RequestParams)
http_resp, req_opts = execute_request_internal(
method, path, base_address, params, opts, usage
)
req_opts = RequestOptions.extract_opts_from_hash(req_opts)
old_requestor = self.class.current_thread_context.active_requestor
self.class.current_thread_context.active_requestor = self
begin
params = params.to_h if params.is_a?(RequestParams)
http_resp, req_opts = execute_request_internal(
method, path, base_address, params, opts, usage
)
req_opts = RequestOptions.extract_opts_from_hash(req_opts)

resp = interpret_response(http_resp)
resp = interpret_response(http_resp)

# If being called from `APIRequestor#request`, put the last response in
# thread-local memory so that it can be returned to the user. Don't store
# anything otherwise so that we don't leak memory.
store_last_response(object_id, resp)
# If being called from `APIRequestor#request`, put the last response in
# thread-local memory so that it can be returned to the user. Don't store
# anything otherwise so that we don't leak memory.
store_last_response(object_id, resp)

api_mode = Util.get_api_mode(path)
Util.convert_to_stripe_object_with_params(resp.data, params, RequestOptions.persistable(req_opts), resp,
api_mode: api_mode, requestor: self,
v2_deleted_object: method == :delete && api_mode == :v2)
api_mode = Util.get_api_mode(path)
Util.convert_to_stripe_object_with_params(resp.data, params, RequestOptions.persistable(req_opts), resp,
api_mode: api_mode, requestor: self,
v2_deleted_object: method == :delete && api_mode == :v2)
ensure
self.class.current_thread_context.active_requestor = old_requestor
end
end

# Execute request without instantiating a new object if the relevant object's name matches the class
Expand Down Expand Up @@ -274,17 +280,23 @@ def execute_request_stream(method, path,
"execute_request_stream requires a read_body_chunk_block"
end

params = params.to_h if params.is_a?(RequestParams)
http_resp, api_key = execute_request_internal(
method, path, base_address, params, opts, usage, &read_body_chunk_block
)
old_requestor = self.class.current_thread_context.active_requestor
self.class.current_thread_context.active_requestor = self
begin
params = params.to_h if params.is_a?(RequestParams)
http_resp, api_key = execute_request_internal(
method, path, base_address, params, opts, usage, &read_body_chunk_block
)

# When the read_body_chunk_block is given, we no longer have access to the
# response body at this point and so return a response object containing
# only the headers. This is because the body was consumed by the block.
resp = StripeHeadersOnlyResponse.from_net_http(http_resp)
# When the read_body_chunk_block is given, we no longer have access to the
# response body at this point and so return a response object containing
# only the headers. This is because the body was consumed by the block.
resp = StripeHeadersOnlyResponse.from_net_http(http_resp)

[resp, api_key]
[resp, api_key]
ensure
self.class.current_thread_context.active_requestor = old_requestor
end
end

def store_last_response(object_id, resp)
Expand All @@ -297,6 +309,19 @@ def last_response_has_key?(object_id)
self.class.current_thread_context.last_responses&.key?(object_id)
end

# Subclass override point for custom HTTP transports.
# Must return an object compatible with Net::HTTPResponse
# (.code, .body, .to_hash, .[], .read_body).
protected def send_request(method, url, headers, body, query, &response_block)
APIRequestor
.default_connection_manager(config)
.execute_request(method, url,
body: body,
headers: headers,
query: query,
&response_block)
end

#
# private
#
Expand Down Expand Up @@ -524,13 +549,7 @@ def self.maybe_gc_connection_managers

http_resp =
execute_request_with_rescues(base_url, headers, api_mode, usage, context) do
self.class
.default_connection_manager(config)
.execute_request(method, url,
body: body,
headers: headers,
query: query,
&response_block)
send_request(method, url, headers, body, query, &response_block)
end

[http_resp, opts]
Expand Down
15 changes: 13 additions & 2 deletions lib/stripe/stripe_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def initialize(api_key,
uploads_base: nil,
connect_base: nil,
meter_events_base: nil,
client_id: nil)
client_id: nil,
requestor: nil)
unless api_key
raise AuthenticationError, "No API key provided. " \
'Set your API key using "client = Stripe::StripeClient.new(<API-KEY>)". ' \
Expand All @@ -46,7 +47,17 @@ def initialize(api_key,
}.compact

config = StripeConfiguration.client_init(config_opts)
@requestor = APIRequestor.new(config)
@requestor = if requestor
instance = requestor.call(config)
unless instance.is_a?(APIRequestor)
raise ArgumentError,
"requestor callable must return an APIRequestor instance, " \
"got #{instance.class}"
end
instance
else
APIRequestor.new(config)
end

# top-level services: The beginning of the section generated from our OpenAPI spec
@v1 = Stripe::V1Services.new(@requestor)
Expand Down
26 changes: 26 additions & 0 deletions rbi/stripe/stripe_client.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@

module Stripe
class StripeClient
sig do
params(
api_key: String,
stripe_account: T.nilable(String),
stripe_context: T.nilable(String),
stripe_version: T.nilable(String),
api_base: T.nilable(String),
uploads_base: T.nilable(String),
connect_base: T.nilable(String),
meter_events_base: T.nilable(String),
client_id: T.nilable(String),
requestor: T.nilable(T.proc.params(config: Stripe::StripeConfiguration).returns(Stripe::APIRequestor))
).void
end
def initialize( # rubocop:disable Metrics/ParameterLists
api_key,
stripe_account: nil,
stripe_context: nil,
stripe_version: nil,
api_base: nil,
uploads_base: nil,
connect_base: nil,
meter_events_base: nil,
client_id: nil,
requestor: nil
); end
sig do
params(
payload: String,
Expand Down
85 changes: 85 additions & 0 deletions test/stripe/api_requestor_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,91 @@ class RequestorTest < Test::Unit::TestCase
assert_equal client, APIRequestor.active_requestor
end
end

should "be set to self during execute_request" do
stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path")
.to_return(body: JSON.generate(object: "account"))

observed_requestor = nil
custom_class = Class.new(APIRequestor) do
define_method(:send_request) do |method, url, headers, body, query, &block|
observed_requestor = self.class.active_requestor
super(method, url, headers, body, query, &block)
end
end

requestor = custom_class.new("sk_test_123")
requestor.execute_request(:post, "/v1/path", :api)

assert_equal requestor, observed_requestor
end

should "be restored after execute_request completes" do
stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path")
.to_return(body: JSON.generate(object: "account"))

before = APIRequestor.active_requestor
requestor = APIRequestor.new("sk_test_123")
requestor.execute_request(:post, "/v1/path", :api)
assert_equal before, APIRequestor.active_requestor
end

should "be restored after execute_request raises" do
before = APIRequestor.active_requestor
requestor = APIRequestor.new("sk_test_123")

assert_raises(Stripe::APIConnectionError) do
stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path")
.to_raise(Errno::ECONNREFUSED)
requestor.execute_request(:post, "/v1/path", :api)
end

assert_equal before, APIRequestor.active_requestor
end
end

context "#send_request" do
should "be called during normal request flow" do
stub_request(:post, "#{Stripe::DEFAULT_API_BASE}/v1/path")
.to_return(body: JSON.generate(object: "account"))

called = false
custom_class = Class.new(APIRequestor) do
define_method(:send_request) do |method, url, headers, body, query, &block|
called = true
super(method, url, headers, body, query, &block)
end
end

requestor = custom_class.new("sk_test_123")
requestor.execute_request(:post, "/v1/path", :api)

assert called
end

should "allow subclass to override transport" do
mock_response = Struct.new(:code, :body, keyword_init: true) do
def to_hash
{ "content-type" => ["application/json"] }
end

def [](name)
to_hash[name.downcase]&.first
end
end

response = mock_response.new(code: "200", body: JSON.generate(object: "account", id: "acct_mock"))
custom_class = Class.new(APIRequestor) do
define_method(:send_request) do |_method, _url, _headers, _body, _query, &_block|
response
end
end

requestor = custom_class.new("sk_test_123")
result = requestor.execute_request(:get, "/v1/accounts/acct_mock", :api)

assert_equal "acct_mock", result.id
end
end

context ".maybe_gc_connection_managers" do
Expand Down
67 changes: 67 additions & 0 deletions test/stripe/stripe_client_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,73 @@ class StripeClientTest < Test::Unit::TestCase
assert client.instance_variable_get(:@requestor).is_a?(APIRequestor)
assert client.instance_variable_get(:@requestor).config.api_key == "sk_test_123"
end

should "use default APIRequestor when no requestor given" do
client = StripeClient.new("sk_test_123")
assert_equal APIRequestor, client.instance_variable_get(:@requestor).class
end

should "accept a requestor factory callable" do
custom_class = Class.new(APIRequestor)
factory = ->(config) { custom_class.new(config) }
client = StripeClient.new("sk_test_123", requestor: factory)
assert_instance_of custom_class, client.instance_variable_get(:@requestor)
end

should "pass config to requestor factory callable" do
received_config = nil
factory = lambda do |config|
received_config = config
APIRequestor.new(config)
end
StripeClient.new("sk_test_123", stripe_account: "acct_abc", requestor: factory)
assert_equal "sk_test_123", received_config.api_key
assert_equal "acct_abc", received_config.stripe_account
end

should "reject requestor that does not return APIRequestor" do
factory = ->(_config) { Object.new }
assert_raises(ArgumentError) do
StripeClient.new("sk_test_123", requestor: factory)
end
end
end

context "custom requestor" do
setup do
@send_request_calls = []
calls = @send_request_calls
@custom_requestor_class = Class.new(APIRequestor) do
define_method(:send_request) do |method, url, headers, body, query, &response_block|
calls << { method: method, url: url }
super(method, url, headers, body, query, &response_block)
end
end
end

should "route v1 service calls through custom transport" do
stub_request(:get, "#{Stripe::DEFAULT_API_BASE}/v1/customers")
.to_return(body: JSON.generate(object: "list", data: []))

factory = ->(config) { @custom_requestor_class.new(config) }
client = StripeClient.new("sk_test_123", requestor: factory)
client.v1.customers.list

assert_equal 1, @send_request_calls.length
assert_equal :get, @send_request_calls[0][:method]
assert_includes @send_request_calls[0][:url], "/v1/customers"
end

should "route raw_request through custom transport" do
stub_request(:get, "#{Stripe::DEFAULT_API_BASE}/v1/customers")
.to_return(body: JSON.generate(object: "list", data: []))

factory = ->(config) { @custom_requestor_class.new(config) }
client = StripeClient.new("sk_test_123", requestor: factory)
client.raw_request(:get, "/v1/customers")

assert_equal 1, @send_request_calls.length
end
end

context "StripeClient config" do
Expand Down
Loading