diff --git a/providers/aws/recommendations/client.go b/providers/aws/recommendations/client.go index c78c449a..78b6c42c 100644 --- a/providers/aws/recommendations/client.go +++ b/providers/aws/recommendations/client.go @@ -33,7 +33,9 @@ type CostExplorerAPI interface { type Client struct { costExplorerClient CostExplorerAPI region string - rateLimiter *RateLimiter + // newRateLimiter is called once per API call (not shared across goroutines). + // Tests can replace it with a factory returning a faster limiter. + newRateLimiter func() *RateLimiter } // NewClient creates a new recommendations client @@ -46,7 +48,7 @@ func NewClient(cfg aws.Config) *Client { return &Client{ costExplorerClient: costexplorer.NewFromConfig(ceConfig), region: cfg.Region, - rateLimiter: NewRateLimiter(), + newRateLimiter: NewRateLimiter, } } @@ -55,7 +57,7 @@ func NewClientWithAPI(api CostExplorerAPI, region string) *Client { return &Client{ costExplorerClient: api, region: region, - rateLimiter: NewRateLimiter(), + newRateLimiter: NewRateLimiter, } } @@ -133,12 +135,12 @@ func (c *Client) fetchRIPageWithRetry( ctx context.Context, input *costexplorer.GetReservationPurchaseRecommendationInput, ) (*costexplorer.GetReservationPurchaseRecommendationOutput, error) { - c.rateLimiter.Reset() + rl := c.newRateLimiter() var result *costexplorer.GetReservationPurchaseRecommendationOutput var err error for { - if waitErr := c.rateLimiter.Wait(ctx); waitErr != nil { + if waitErr := rl.Wait(ctx); waitErr != nil { return nil, fmt.Errorf("rate limiter wait failed: %w", waitErr) } @@ -147,13 +149,13 @@ func (c *Client) fetchRIPageWithRetry( } result, err = c.costExplorerClient.GetReservationPurchaseRecommendation(ctx, input) concurrency.Release(ctx) - if !c.rateLimiter.ShouldRetry(err) { + if !rl.ShouldRetry(err) { break } } if err != nil { - return nil, fmt.Errorf("failed to get RI recommendations after %d retries: %w", c.rateLimiter.GetRetryCount(), err) + return nil, fmt.Errorf("failed to get RI recommendations after %d retries: %w", rl.GetRetryCount(), err) } return result, nil diff --git a/providers/aws/recommendations/client_test.go b/providers/aws/recommendations/client_test.go index cc700c7a..0e0c976c 100644 --- a/providers/aws/recommendations/client_test.go +++ b/providers/aws/recommendations/client_test.go @@ -3,6 +3,7 @@ package recommendations import ( "context" "fmt" + "sync" "testing" "time" @@ -17,6 +18,7 @@ import ( // Mock CostExplorerAPI for testing type mockCostExplorerAPI struct { + mu sync.Mutex riRecommendations *costexplorer.GetReservationPurchaseRecommendationOutput spRecommendations *costexplorer.GetSavingsPlansPurchaseRecommendationOutput riError error @@ -28,20 +30,28 @@ type mockCostExplorerAPI struct { } func (m *mockCostExplorerAPI) GetReservationPurchaseRecommendation(ctx context.Context, params *costexplorer.GetReservationPurchaseRecommendationInput, optFns ...func(*costexplorer.Options)) (*costexplorer.GetReservationPurchaseRecommendationOutput, error) { + m.mu.Lock() m.callCount++ m.riCalls = append(m.riCalls, params) - if m.riError != nil { - return nil, m.riError + riErr := m.riError + riRecs := m.riRecommendations + m.mu.Unlock() + if riErr != nil { + return nil, riErr } - return m.riRecommendations, nil + return riRecs, nil } func (m *mockCostExplorerAPI) GetSavingsPlansPurchaseRecommendation(ctx context.Context, params *costexplorer.GetSavingsPlansPurchaseRecommendationInput, optFns ...func(*costexplorer.Options)) (*costexplorer.GetSavingsPlansPurchaseRecommendationOutput, error) { + m.mu.Lock() m.callCount++ - if m.spError != nil { - return nil, m.spError + spErr := m.spError + spRecs := m.spRecommendations + m.mu.Unlock() + if spErr != nil { + return nil, spErr } - return m.spRecommendations, nil + return spRecs, nil } func (m *mockCostExplorerAPI) GetReservationUtilization(ctx context.Context, params *costexplorer.GetReservationUtilizationInput, optFns ...func(*costexplorer.Options)) (*costexplorer.GetReservationUtilizationOutput, error) { @@ -61,7 +71,7 @@ func TestNewClient(t *testing.T) { assert.NotNil(t, client) assert.NotNil(t, client.costExplorerClient) - assert.NotNil(t, client.rateLimiter) + assert.NotNil(t, client.newRateLimiter) assert.Equal(t, "us-west-2", client.region) } @@ -74,7 +84,7 @@ func TestNewClientWithAPI(t *testing.T) { assert.NotNil(t, client) assert.Equal(t, mockAPI, client.costExplorerClient) assert.Equal(t, region, client.region) - assert.NotNil(t, client.rateLimiter) + assert.NotNil(t, client.newRateLimiter) } func TestGetRecommendations_EC2_Success(t *testing.T) { @@ -262,9 +272,11 @@ func TestGetRecommendations_Error(t *testing.T) { riError: newThrottleError(), } - // Use custom rate limiter to speed up test + // Use custom rate limiter factory to speed up test client := NewClientWithAPI(mockAPI, "us-east-1") - client.rateLimiter = NewRateLimiterWithOptions(1*time.Millisecond, 10*time.Millisecond, 2) + client.newRateLimiter = func() *RateLimiter { + return NewRateLimiterWithOptions(1*time.Millisecond, 10*time.Millisecond, 2) + } params := common.RecommendationParams{ Service: common.ServiceEC2, @@ -506,7 +518,9 @@ func TestGetRecommendations_ContextCancellation(t *testing.T) { } client := NewClientWithAPI(mockAPI, "us-east-1") - client.rateLimiter = NewRateLimiterWithOptions(100*time.Millisecond, 1*time.Second, 5) + client.newRateLimiter = func() *RateLimiter { + return NewRateLimiterWithOptions(100*time.Millisecond, 1*time.Second, 5) + } ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately diff --git a/providers/aws/recommendations/coverage.go b/providers/aws/recommendations/coverage.go index 415a7101..2796d072 100644 --- a/providers/aws/recommendations/coverage.go +++ b/providers/aws/recommendations/coverage.go @@ -318,13 +318,13 @@ func serviceRegionFilter(service, region string) *types.Expression { // Mirrors fetchUtilizationPage so the two paths fail and back off the // same way. func (c *Client) fetchCoveragePage(ctx context.Context, input *costexplorer.GetReservationCoverageInput) (*costexplorer.GetReservationCoverageOutput, error) { - c.rateLimiter.Reset() + rl := c.newRateLimiter() for { - if waitErr := c.rateLimiter.Wait(ctx); waitErr != nil { + if waitErr := rl.Wait(ctx); waitErr != nil { return nil, fmt.Errorf("rate limiter wait failed: %w", waitErr) } result, err := c.costExplorerClient.GetReservationCoverage(ctx, input) - if !c.rateLimiter.ShouldRetry(err) { + if !rl.ShouldRetry(err) { if err != nil { return nil, fmt.Errorf("failed to get reservation coverage: %w", err) } diff --git a/providers/aws/recommendations/parser_sp.go b/providers/aws/recommendations/parser_sp.go index 6eebc471..b662a925 100644 --- a/providers/aws/recommendations/parser_sp.go +++ b/providers/aws/recommendations/parser_sp.go @@ -120,12 +120,12 @@ func (c *Client) fetchSPPageWithRetry( ctx context.Context, input *costexplorer.GetSavingsPlansPurchaseRecommendationInput, ) (*costexplorer.GetSavingsPlansPurchaseRecommendationOutput, error) { - c.rateLimiter.Reset() + rl := c.newRateLimiter() var result *costexplorer.GetSavingsPlansPurchaseRecommendationOutput var err error for { - if waitErr := c.rateLimiter.Wait(ctx); waitErr != nil { + if waitErr := rl.Wait(ctx); waitErr != nil { return nil, fmt.Errorf("rate limiter wait failed: %w", waitErr) } @@ -134,7 +134,7 @@ func (c *Client) fetchSPPageWithRetry( } result, err = c.costExplorerClient.GetSavingsPlansPurchaseRecommendation(ctx, input) concurrency.Release(ctx) - if !c.rateLimiter.ShouldRetry(err) { + if !rl.ShouldRetry(err) { break } } diff --git a/providers/aws/recommendations/utilization.go b/providers/aws/recommendations/utilization.go index b5be494e..f0998765 100644 --- a/providers/aws/recommendations/utilization.go +++ b/providers/aws/recommendations/utilization.go @@ -92,14 +92,14 @@ func (c *Client) GetRIUtilization(ctx context.Context, lookbackDays int) ([]RIUt // fetchUtilizationPage calls the Cost Explorer API with rate-limit retry. func (c *Client) fetchUtilizationPage(ctx context.Context, input *costexplorer.GetReservationUtilizationInput) (*costexplorer.GetReservationUtilizationOutput, error) { - c.rateLimiter.Reset() + rl := c.newRateLimiter() for { - if waitErr := c.rateLimiter.Wait(ctx); waitErr != nil { + if waitErr := rl.Wait(ctx); waitErr != nil { return nil, fmt.Errorf("rate limiter wait failed: %w", waitErr) } result, err := c.costExplorerClient.GetReservationUtilization(ctx, input) - if !c.rateLimiter.ShouldRetry(err) { + if !rl.ShouldRetry(err) { if err != nil { return nil, fmt.Errorf("failed to get reservation utilization: %w", err) }