diff --git a/docs/should_block_request.md b/docs/should_block_request.md index 755352bb..11c9b0dd 100644 --- a/docs/should_block_request.md +++ b/docs/should_block_request.md @@ -76,9 +76,14 @@ class AikidoMiddleware implements MiddlewareInterface else if ($decision->trigger == "group") { $message = "Your group exceeded the rate limit for this endpoint!"; } - return new Response([ + $response = new Response([ 'message' => $message, ], 429); + // Set Retry-After header to inform the client how long to wait (in seconds) + if ($decision->retry_after_seconds > 0) { + $response = $response->withHeader('Retry-After', (string) $decision->retry_after_seconds); + } + return $response; } // Aikido decided to block but decision type is not implemented @@ -145,15 +150,21 @@ class AikidoMiddleware } } else if ($decision->type == "ratelimited") { + $message = ''; if ($decision->trigger == "user") { - return response('Your user exceeded the rate limit for this endpoint!', 429); + $message = 'Your user exceeded the rate limit for this endpoint!'; } else if ($decision->trigger == "ip") { - return response("Your IP ({$decision->ip}) exceeded the rate limit for this endpoint!", 429); + $message = "Your IP ({$decision->ip}) exceeded the rate limit for this endpoint!"; } else if ($decision->trigger == "group") { - return response("Your group exceeded the rate limit for this endpoint!", 429); + $message = "Your group exceeded the rate limit for this endpoint!"; } + $resp = response($message, 429); + if ($decision->retry_after_seconds > 0) { + $resp->header('Retry-After', $decision->retry_after_seconds); + } + return $resp; } } @@ -266,10 +277,14 @@ class AikidoEventSubscriber implements EventSubscriberInterface $message = "Your group exceeded the rate limit for this endpoint!"; } - $event->setResponse(new JsonResponse( + $response = new JsonResponse( ['message' => $message], 429 - )); + ); + if ($decision->retry_after_seconds > 0) { + $response->headers->set('Retry-After', (string) 
$decision->retry_after_seconds); + } + $event->setResponse($response); return; } } diff --git a/lib/agent/aikido_types/sliding_window.go b/lib/agent/aikido_types/sliding_window.go index 2ecc93a6..a2f39d07 100644 --- a/lib/agent/aikido_types/sliding_window.go +++ b/lib/agent/aikido_types/sliding_window.go @@ -1,5 +1,7 @@ package aikido_types +import "time" + type SuspiciousRequest struct { Method string `json:"method"` Url string `json:"url"` @@ -8,15 +10,17 @@ type SuspiciousRequest struct { // SlidingWindow represents a time-based sliding window counter. // It maintains a queue of counts per time bucket and a running total. type SlidingWindow struct { - Total int // Running total of all counts in the window - Queue Queue[int] // Queue of counts per time bucket - Samples []SuspiciousRequest // Sample requests collected for attack wave detection (max MaxSamplesPerIP) + Total int // Running total of all counts in the window + Queue Queue[int] // Queue of counts per time bucket + Samples []SuspiciousRequest // Sample requests collected for attack wave detection (max MaxSamplesPerIP) + CreatedAt time.Time // Timestamp when this window was first created } // NewSlidingWindow creates a new sliding window with the specified size. 
func NewSlidingWindow() *SlidingWindow { sw := &SlidingWindow{ - Queue: NewQueue[int](0), // no max size, we handle it manually + Queue: NewQueue[int](0), // no max size, we handle it manually + CreatedAt: time.Now(), } // Ensure there is a current bucket sw.Queue.Push(0) @@ -32,6 +36,7 @@ func (sw *SlidingWindow) Advance(windowSize int) { if sw.Total >= dropped { // safety check to avoid negative total sw.Total -= dropped } + sw.CreatedAt = sw.CreatedAt.Add(time.Minute) } // Add a new bucket for the current time period sw.Queue.Push(0) diff --git a/lib/agent/aikido_types/sliding_window_test.go b/lib/agent/aikido_types/sliding_window_test.go index c904ee37..891d0c4e 100644 --- a/lib/agent/aikido_types/sliding_window_test.go +++ b/lib/agent/aikido_types/sliding_window_test.go @@ -2,6 +2,7 @@ package aikido_types import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -107,6 +108,49 @@ func TestSlidingWindowAdvance(t *testing.T) { sw.Advance(2) // evict bucket with value 10, but total is only 5 assert.Equal(t, 5, sw.Total) // should not go negative }) + + t.Run("does not advance CreatedAt when below capacity", func(t *testing.T) { + sw := NewSlidingWindow() + originalCreatedAt := sw.CreatedAt + sw.Increment() + + sw.Advance(5) // queue length 1 < window size 5, no eviction + assert.Equal(t, originalCreatedAt, sw.CreatedAt) + }) + + t.Run("advances CreatedAt by one minute when evicting a bucket", func(t *testing.T) { + sw := NewSlidingWindow() + originalCreatedAt := sw.CreatedAt + sw.Increment() + + sw.Advance(2) // queue grows to 2 (capacity) + sw.Increment() + + sw.Advance(2) // evicts oldest bucket, CreatedAt should shift by 1 minute + assert.Equal(t, originalCreatedAt.Add(time.Minute), sw.CreatedAt) + }) + + t.Run("advances CreatedAt by one minute per eviction over multiple advances", func(t *testing.T) { + sw := NewSlidingWindow() + originalCreatedAt := sw.CreatedAt + sw.Increment() + + // Fill up the window (size 3) + sw.Advance(3) + sw.Increment() 
+ sw.Advance(3) + sw.Increment() + // Queue is now at capacity (3 buckets) + + sw.Advance(3) // eviction 1 + assert.Equal(t, originalCreatedAt.Add(1*time.Minute), sw.CreatedAt) + + sw.Advance(3) // eviction 2 + assert.Equal(t, originalCreatedAt.Add(2*time.Minute), sw.CreatedAt) + + sw.Advance(3) // eviction 3 + assert.Equal(t, originalCreatedAt.Add(3*time.Minute), sw.CreatedAt) + }) } func TestSlidingWindowIsEmpty(t *testing.T) { diff --git a/lib/agent/grpc/request.go b/lib/agent/grpc/request.go index 6ee732df..366bd3b8 100644 --- a/lib/agent/grpc/request.go +++ b/lib/agent/grpc/request.go @@ -10,6 +10,7 @@ import ( "main/utils" "slices" "strings" + "time" ) func storeTotalStats(server *ServerData, rateLimited bool) { @@ -304,20 +305,38 @@ func getRateLimitingDataForEndpoint(server *ServerData, method, route, routePars return wildcardMatches[0] } +// computeRetryAfterSeconds returns how many seconds a rate-limited client should wait before retrying. +// It subtracts the time elapsed since sw.CreatedAt from the window size and clamps the result to +// [1, window size in seconds]; a nil window yields the full window size. +func computeRetryAfterSeconds(sw *SlidingWindow, windowSizeInMinutes int) int32 { + windowSizeInSeconds := int32(windowSizeInMinutes * 60) + if sw == nil { + return windowSizeInSeconds + } + elapsed := int32(time.Since(sw.CreatedAt).Seconds()) + // CreatedAt may be ahead of now (Advance shifts it forward on eviction), so cap at the window size. + retryAfter := min(windowSizeInSeconds-elapsed, windowSizeInSeconds) + if retryAfter < 1 { + retryAfter = 1 + } + return retryAfter +} + func isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch *RateLimitingValue, countsMap map[string]*SlidingWindow, key string, -) bool { +) (bool, int32) { rateLimitingDataMatch.Mutex.Lock() defer rateLimitingDataMatch.Mutex.Unlock() if isRateLimitingThresholdExceeded(&rateLimitingDataMatch.Config, countsMap, key) { - return true + retryAfter := computeRetryAfterSeconds(countsMap[key], rateLimitingDataMatch.Config.WindowSizeInMinutes) + return true, retryAfter } incrementSlidingWindowEntry(countsMap, key) - return false + return false, 0 } func getRateLimitingStatus(server *ServerData, method, route, routeParsed, user, ip, rateLimitGroup string) *protos.RateLimitingStatus { @@ -333,21 +352,21 @@ func getRateLimitingStatus(server *ServerData, method, route, 
routeParsed, user, if rateLimitGroup != "" { // If the rate limit group exists, we only try to rate limit by rate limit group - if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.RateLimitGroupCounts, rateLimitGroup) { + if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.RateLimitGroupCounts, rateLimitGroup); exceeded { log.Infof(server.Logger, "Rate limited request for group %s - %s %s - %v", rateLimitGroup, method, routeParsed, rateLimitingDataMatch.RateLimitGroupCounts[rateLimitGroup]) - return &protos.RateLimitingStatus{Block: true, Trigger: "group"} + return &protos.RateLimitingStatus{Block: true, Trigger: "group", RetryAfterSeconds: retryAfterSeconds} } } else if user != "" { // Otherwise, if the user exists, we try to rate limit by user - if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.UserCounts, user) { + if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.UserCounts, user); exceeded { log.Infof(server.Logger, "Rate limited request for user %s - %s %s - %v", user, method, routeParsed, rateLimitingDataMatch.UserCounts[user]) - return &protos.RateLimitingStatus{Block: true, Trigger: "user"} + return &protos.RateLimitingStatus{Block: true, Trigger: "user", RetryAfterSeconds: retryAfterSeconds} } } else { // Otherwise, we try to rate limit by ip - if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.IpCounts, ip) { + if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.IpCounts, ip); exceeded { log.Infof(server.Logger, "Rate limited request for ip %s - %s %s - %v", ip, method, routeParsed, rateLimitingDataMatch.IpCounts[ip]) - return &protos.RateLimitingStatus{Block: true, Trigger: "ip"} + return &protos.RateLimitingStatus{Block: 
true, Trigger: "ip", RetryAfterSeconds: retryAfterSeconds} } } diff --git a/lib/agent/grpc/request_test.go b/lib/agent/grpc/request_test.go index 8f4186c8..f3d42f7a 100644 --- a/lib/agent/grpc/request_test.go +++ b/lib/agent/grpc/request_test.go @@ -3,6 +3,7 @@ package grpc import ( "sync" "testing" + "time" . "main/aikido_types" "main/utils" @@ -72,3 +73,42 @@ func TestAttackWaveThrottling(t *testing.T) { assert.True(t, server.AttackWave.LastSent[ip] > 0, "LastSent should be set after event is sent") }) } + +func TestComputeRetryAfterSeconds(t *testing.T) { + t.Run("returns full window size when sliding window is nil", func(t *testing.T) { + result := computeRetryAfterSeconds(nil, 5) + assert.Equal(t, int32(300), result) + }) + + t.Run("returns full window size for a freshly created window", func(t *testing.T) { + sw := NewSlidingWindow() + result := computeRetryAfterSeconds(sw, 5) + assert.True(t, result >= 299 && result <= 300, "expected ~300, got %d", result) + }) + + t.Run("decreases over time", func(t *testing.T) { + sw := NewSlidingWindow() + sw.CreatedAt = time.Now().Add(-60 * time.Second) + result := computeRetryAfterSeconds(sw, 5) + assert.True(t, result >= 239 && result <= 241, "expected ~240, got %d", result) + }) + + t.Run("clamps to 1 when window has expired", func(t *testing.T) { + sw := NewSlidingWindow() + sw.CreatedAt = time.Now().Add(-600 * time.Second) + result := computeRetryAfterSeconds(sw, 5) + assert.Equal(t, int32(1), result) + }) + + t.Run("never exceeds the window size when CreatedAt is in the future", func(t *testing.T) { + sw := NewSlidingWindow() + sw.CreatedAt = time.Now().Add(-30 * time.Second) + + // Simulate one eviction advancing CreatedAt by 1 minute + sw.CreatedAt = sw.CreatedAt.Add(time.Minute) + + result := computeRetryAfterSeconds(sw, 2) // 2 min window = 120s + // CreatedAt is now ~30 seconds in the future; the result must be capped at the window size + assert.Equal(t, int32(120), result) + }) +} diff 
--git a/lib/ipc.proto b/lib/ipc.proto index 0155712d..f2e4a6e9 100644 --- a/lib/ipc.proto +++ b/lib/ipc.proto @@ -144,6 +144,7 @@ message CloudConfig { message RateLimitingStatus { bool block = 1; string trigger = 2; + int32 retry_after_seconds = 3; } message User { diff --git a/lib/php-extension/Action.cpp b/lib/php-extension/Action.cpp index 7538b75b..2e156ba7 100644 --- a/lib/php-extension/Action.cpp +++ b/lib/php-extension/Action.cpp @@ -38,6 +38,9 @@ ACTION_STATUS Action::executeStore(json &event) { if (trigger == "user-agent") { userAgent = event["user-agent"]; } + if (event.contains("retryAfterSeconds") && event["retryAfterSeconds"].is_number()) { + retryAfterSeconds = event["retryAfterSeconds"].get<int>(); + } return CONTINUE; } @@ -102,6 +105,7 @@ void Action::Reset() { description = ""; ip = ""; userAgent = ""; + retryAfterSeconds = 0; } bool Action::Exit() { @@ -135,3 +139,7 @@ char *Action::Ip() { char *Action::UserAgent() { return (char *)userAgent.c_str(); } + +int Action::RetryAfterSeconds() { + return retryAfterSeconds; +} diff --git a/lib/php-extension/HandleShouldBlockRequest.cpp b/lib/php-extension/HandleShouldBlockRequest.cpp index 8d3cf2d2..8786b164 100644 --- a/lib/php-extension/HandleShouldBlockRequest.cpp +++ b/lib/php-extension/HandleShouldBlockRequest.cpp @@ -82,6 +82,7 @@ ZEND_FUNCTION(should_block_request) { zend_update_property_string(blockingStatusClass, obj, "description", sizeof("description") - 1, action.Description()); zend_update_property_string(blockingStatusClass, obj, "ip", sizeof("ip") - 1, action.Ip()); zend_update_property_string(blockingStatusClass, obj, "user_agent", sizeof("user_agent") - 1, action.UserAgent()); + zend_update_property_long(blockingStatusClass, obj, "retry_after_seconds", sizeof("retry_after_seconds") - 1, action.RetryAfterSeconds()); } ZEND_FUNCTION(auto_block_request) { @@ -145,6 +146,7 @@ void RegisterAikidoBlockRequestStatusClass() { zend_declare_property_string(blockingStatusClass, "description", 
sizeof("description") - 1, "", ZEND_ACC_PUBLIC); zend_declare_property_string(blockingStatusClass, "ip", sizeof("ip") - 1, "", ZEND_ACC_PUBLIC); zend_declare_property_string(blockingStatusClass, "user_agent", sizeof("user_agent") - 1, "", ZEND_ACC_PUBLIC); + zend_declare_property_long(blockingStatusClass, "retry_after_seconds", sizeof("retry_after_seconds") - 1, 0, ZEND_ACC_PUBLIC); } void RegisterAikidoWhitelistRequestStatusClass() { diff --git a/lib/php-extension/include/Action.h b/lib/php-extension/include/Action.h index 131776db..0e940df2 100644 --- a/lib/php-extension/include/Action.h +++ b/lib/php-extension/include/Action.h @@ -18,6 +18,7 @@ class Action { std::string description; std::string ip; std::string userAgent; + int retryAfterSeconds = 0; private: ACTION_STATUS executeThrow(json &event); @@ -49,4 +50,5 @@ class Action { char* Description(); char* Ip(); char* UserAgent(); + int RetryAfterSeconds(); }; diff --git a/lib/request-processor/handle_blocking_request.go b/lib/request-processor/handle_blocking_request.go index 1df5d700..228bcf45 100644 --- a/lib/request-processor/handle_blocking_request.go +++ b/lib/request-processor/handle_blocking_request.go @@ -29,6 +29,24 @@ func GetAction(actionHandling, actionType, trigger, description, data string, re return string(actionJson) } +func GetRateLimitedAction(trigger, ip string, retryAfterSeconds int) string { + actionMap := map[string]interface{}{ + "action": "store", + "type": "ratelimited", + "trigger": trigger, + "description": html.EscapeString("configured rate limit exceeded by current ip"), + "message": fmt.Sprintf("Your %s (%s) is blocked due to: %s!", trigger, ip, "configured rate limit exceeded by current ip"), + trigger: ip, + "response_code": 429, + "retryAfterSeconds": retryAfterSeconds, + } + actionJson, err := json.Marshal(actionMap) + if err != nil { + return "" + } + return string(actionJson) +} + func OnGetBlockingStatus(instance *instance.RequestProcessorInstance) string { 
log.Debugf(instance, "OnGetBlockingStatus called!") @@ -68,7 +86,7 @@ func OnGetBlockingStatus(instance *instance.RequestProcessorInstance) string { if rateLimitingStatus != nil && rateLimitingStatus.Block { context.ContextSetIsEndpointRateLimited(instance) log.Infof(instance, "Request made from IP \"%s\" is ratelimited by \"%s\"!", ip, rateLimitingStatus.Trigger) - return GetAction("store", "ratelimited", rateLimitingStatus.Trigger, "configured rate limit exceeded by current ip", ip, 429) + return GetRateLimitedAction(rateLimitingStatus.Trigger, ip, int(rateLimitingStatus.RetryAfterSeconds)) } } diff --git a/tests/server/test_rate_limiting_retry_after_header/env.json b/tests/server/test_rate_limiting_retry_after_header/env.json new file mode 100644 index 00000000..870e7b16 --- /dev/null +++ b/tests/server/test_rate_limiting_retry_after_header/env.json @@ -0,0 +1,5 @@ +{ + "AIKIDO_BLOCK": "1", + "AIKIDO_LOCALHOST_ALLOWED_BY_DEFAULT": "0", + "AIKIDO_FEATURE_COLLECT_API_SCHEMA": "1" +} diff --git a/tests/server/test_rate_limiting_retry_after_header/index.php b/tests/server/test_rate_limiting_retry_after_header/index.php new file mode 100644 index 00000000..80d4b9f5 --- /dev/null +++ b/tests/server/test_rate_limiting_retry_after_header/index.php @@ -0,0 +1,20 @@ +block && $decision->type == "ratelimited") { + http_response_code(429); + header("Retry-After: " . 
$decision->retry_after_seconds); + echo "Rate limit exceeded"; + exit(); + } +} + +echo "Request successful!"; + +?> diff --git a/tests/server/test_rate_limiting_retry_after_header/start_config.json b/tests/server/test_rate_limiting_retry_after_header/start_config.json new file mode 100644 index 00000000..a25cd209 --- /dev/null +++ b/tests/server/test_rate_limiting_retry_after_header/start_config.json @@ -0,0 +1,22 @@ +{ + "success": true, + "serviceId": 1, + "heartbeatIntervalInMS": 600000, + "endpoints": [ + { + "method": "GET", + "route": "\/", + "forceProtectionOff": false, + "graphql": null, + "allowedIPAddresses": [], + "rateLimiting": { + "enabled": true, + "maxRequests": 3, + "windowSizeInMS": 120000 + } + } + ], + "blockedUserIds": [], + "allowedIPAddresses": [], + "receivedAnyStats": true +} diff --git a/tests/server/test_rate_limiting_retry_after_header/test.py b/tests/server/test_rate_limiting_retry_after_header/test.py new file mode 100644 index 00000000..80e48ec9 --- /dev/null +++ b/tests/server/test_rate_limiting_retry_after_header/test.py @@ -0,0 +1,113 @@ +import requests +import time +import sys +from testlib import * + +''' +Window: 2 minutes (2 buckets of 1 minute each), max 3 requests. + +Phase 1: Basic Retry-After + countdown + - Send 3 requests (OK), trigger rate limiting, check Retry-After <= 120. + - Sleep 5s, trigger again, assert Retry-After decreased. + +Phase 2: Retry-After stays accurate across bucket evictions + - Wait for window to expire (~120s). Rate limit resets. + - Send 1 request at T=0 (bucket 0 = 1, total = 1). + - Sleep 65s so bucket advances. Send 1 request (bucket 1 = 1, total = 2). + - Sleep 65s. Bucket 0 is evicted, total drops to 1, window survives. + - Send 2 more requests to push total back to 3 → rate limited again. + - Assert Retry-After > 10 and <= 120. + +Phase 3: Full reset + - Wait for window to expire. Trigger rate limiting again. + - Assert Retry-After resets back near 120. 
+''' + +def run_test(): + # --- Phase 1: basic countdown --- + for _ in range(3): + response = php_server_get("/") + assert_response_code_is(response, 200) + assert_response_body_contains(response, "Request successful") + + time.sleep(3) + + for _ in range(3): + response = php_server_get("/") + + response = php_server_get("/") + assert_response_code_is(response, 429) + assert_response_header_contains(response, "Content-Type", "text") + assert_response_body_contains(response, "Rate limit exceeded") + assert "Retry-After" in response.headers, f"Retry-After header missing: {response.headers}" + first_retry_after = int(response.headers["Retry-After"]) + assert first_retry_after > 0, f"Retry-After should be > 0, got {first_retry_after}" + assert first_retry_after <= 120, f"Retry-After should be <= 120 (2 min window), got {first_retry_after}" + print(f"Phase 1a - First Retry-After: {first_retry_after}") + + time.sleep(5) + + response = php_server_get("/") + assert_response_code_is(response, 429) + assert "Retry-After" in response.headers, f"Retry-After header missing: {response.headers}" + second_retry_after = int(response.headers["Retry-After"]) + assert second_retry_after > 0, f"Retry-After should be > 0, got {second_retry_after}" + assert second_retry_after < first_retry_after, f"Retry-After should decrease: {second_retry_after} < {first_retry_after}" + print(f"Phase 1b - Second Retry-After: {second_retry_after} (decreased from {first_retry_after})") + + # --- Phase 2: multi-bucket eviction (catches CreatedAt bug) --- + # Wait for window to fully expire + time.sleep(120) + + # T=0: send 1 request into bucket 0 + response = php_server_get("/") + assert_response_code_is(response, 200) + + # Wait for bucket advance (~65s) + time.sleep(65) + + # T=65: send 1 request into bucket 1 (total=2, spread across 2 buckets) + response = php_server_get("/") + assert_response_code_is(response, 200) + + # Wait for next tick to evict bucket 0 (~65s) + # After eviction: bucket 0 
(count=1) is dropped, total=1, window survives via bucket 1 + time.sleep(65) + + # T=130: send 2 requests to push total back to 3 (re-triggering rate limit) + response = php_server_get("/") + assert_response_code_is(response, 200) + response = php_server_get("/") + assert_response_code_is(response, 200) + + # Now total=3, next request should be rate limited + response = php_server_get("/") + assert_response_code_is(response, 429) + assert "Retry-After" in response.headers, f"Retry-After header missing: {response.headers}" + eviction_retry_after = int(response.headers["Retry-After"]) + print(f"Phase 2 - Retry-After after bucket eviction: {eviction_retry_after} (would be 1 without CreatedAt fix)") + assert eviction_retry_after > 10, f"Retry-After should be > 10 after eviction (would be 1 without fix), got {eviction_retry_after}" + assert eviction_retry_after <= 120, f"Retry-After should be <= 120 (2 min window), got {eviction_retry_after}" + + # --- Phase 3: full reset --- + time.sleep(120) + + for _ in range(3): + response = php_server_get("/") + assert_response_code_is(response, 200) + + time.sleep(3) + + for _ in range(3): + response = php_server_get("/") + + response = php_server_get("/") + assert_response_code_is(response, 429) + assert "Retry-After" in response.headers, f"Retry-After header missing: {response.headers}" + reset_retry_after = int(response.headers["Retry-After"]) + assert reset_retry_after > 100, f"Retry-After should reset near 120 after full window expiry, got {reset_retry_after}" + print(f"Phase 3 - Retry-After after full reset: {reset_retry_after}") + +if __name__ == "__main__": + load_test_args() + run_test() diff --git a/tests/testlib/testlib.py b/tests/testlib/testlib.py index 1c4f3fb2..761afdb5 100644 --- a/tests/testlib/testlib.py +++ b/tests/testlib/testlib.py @@ -16,7 +16,8 @@ from requests.adapters import HTTPAdapter, Retry s = requests.Session() retries = Retry(connect=10, - backoff_factor=1) + backoff_factor=1, + 
respect_retry_after_header=False) s.mount('http://', HTTPAdapter(max_retries=retries))