Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions docs/should_block_request.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,14 @@ class AikidoMiddleware implements MiddlewareInterface
else if ($decision->trigger == "group") {
$message = "Your group exceeded the rate limit for this endpoint!";
}
return new Response([
$response = new Response([
'message' => $message,
], 429);
// Set Retry-After header to inform the client how long to wait (in seconds)
if ($decision->retry_after_seconds > 0) {
$response = $response->withHeader('Retry-After', (string) $decision->retry_after_seconds);
}
return $response;
}

// Aikido decided to block but decision type is not implemented
Expand Down Expand Up @@ -145,15 +150,21 @@ class AikidoMiddleware
}
}
else if ($decision->type == "ratelimited") {
$message = '';
if ($decision->trigger == "user") {
return response('Your user exceeded the rate limit for this endpoint!', 429);
$message = 'Your user exceeded the rate limit for this endpoint!';
}
else if ($decision->trigger == "ip") {
return response("Your IP ({$decision->ip}) exceeded the rate limit for this endpoint!", 429);
$message = "Your IP ({$decision->ip}) exceeded the rate limit for this endpoint!";
}
else if ($decision->trigger == "group") {
return response("Your group exceeded the rate limit for this endpoint!", 429);
$message = "Your group exceeded the rate limit for this endpoint!";
}
$resp = response($message, 429);
if ($decision->retry_after_seconds > 0) {
$resp->header('Retry-After', $decision->retry_after_seconds);
}
return $resp;
}
}

Expand Down Expand Up @@ -266,10 +277,14 @@ class AikidoEventSubscriber implements EventSubscriberInterface
$message = "Your group exceeded the rate limit for this endpoint!";
}

$event->setResponse(new JsonResponse(
$response = new JsonResponse(
['message' => $message],
429
));
);
if ($decision->retry_after_seconds > 0) {
$response->headers->set('Retry-After', (string) $decision->retry_after_seconds);
}
$event->setResponse($response);
return;
}
}
Expand Down
13 changes: 9 additions & 4 deletions lib/agent/aikido_types/sliding_window.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package aikido_types

import "time"

type SuspiciousRequest struct {
Method string `json:"method"`
Url string `json:"url"`
Expand All @@ -8,15 +10,17 @@ type SuspiciousRequest struct {
// SlidingWindow represents a time-based sliding window counter.
// It maintains a queue of counts per time bucket and a running total.
type SlidingWindow struct {
Total int // Running total of all counts in the window
Queue Queue[int] // Queue of counts per time bucket
Samples []SuspiciousRequest // Sample requests collected for attack wave detection (max MaxSamplesPerIP)
Total int // Running total of all counts in the window
Queue Queue[int] // Queue of counts per time bucket
Samples []SuspiciousRequest // Sample requests collected for attack wave detection (max MaxSamplesPerIP)
CreatedAt time.Time // Timestamp when this window was first created
}

// NewSlidingWindow creates a new sliding window with the specified size.
func NewSlidingWindow() *SlidingWindow {
sw := &SlidingWindow{
Queue: NewQueue[int](0), // no max size, we handle it manually
Queue: NewQueue[int](0), // no max size, we handle it manually
CreatedAt: time.Now(),
}
// Ensure there is a current bucket
sw.Queue.Push(0)
Expand All @@ -32,6 +36,7 @@ func (sw *SlidingWindow) Advance(windowSize int) {
if sw.Total >= dropped { // safety check to avoid negative total
sw.Total -= dropped
}
sw.CreatedAt = sw.CreatedAt.Add(time.Minute)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment why we do this?

}
// Add a new bucket for the current time period
sw.Queue.Push(0)
Expand Down
44 changes: 44 additions & 0 deletions lib/agent/aikido_types/sliding_window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aikido_types

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -107,6 +108,49 @@ func TestSlidingWindowAdvance(t *testing.T) {
sw.Advance(2) // evict bucket with value 10, but total is only 5
assert.Equal(t, 5, sw.Total) // should not go negative
})

t.Run("does not advance CreatedAt when below capacity", func(t *testing.T) {
sw := NewSlidingWindow()
originalCreatedAt := sw.CreatedAt
sw.Increment()

sw.Advance(5) // queue length 1 < window size 5, no eviction
assert.Equal(t, originalCreatedAt, sw.CreatedAt)
})

t.Run("advances CreatedAt by one minute when evicting a bucket", func(t *testing.T) {
sw := NewSlidingWindow()
originalCreatedAt := sw.CreatedAt
sw.Increment()

sw.Advance(2) // queue grows to 2 (capacity)
sw.Increment()

sw.Advance(2) // evicts oldest bucket, CreatedAt should shift by 1 minute
assert.Equal(t, originalCreatedAt.Add(time.Minute), sw.CreatedAt)
})

t.Run("advances CreatedAt by one minute per eviction over multiple advances", func(t *testing.T) {
sw := NewSlidingWindow()
originalCreatedAt := sw.CreatedAt
sw.Increment()

// Fill up the window (size 3)
sw.Advance(3)
sw.Increment()
sw.Advance(3)
sw.Increment()
// Queue is now at capacity (3 buckets)

sw.Advance(3) // eviction 1
assert.Equal(t, originalCreatedAt.Add(1*time.Minute), sw.CreatedAt)

sw.Advance(3) // eviction 2
assert.Equal(t, originalCreatedAt.Add(2*time.Minute), sw.CreatedAt)

sw.Advance(3) // eviction 3
assert.Equal(t, originalCreatedAt.Add(3*time.Minute), sw.CreatedAt)
})
}

func TestSlidingWindowIsEmpty(t *testing.T) {
Expand Down
33 changes: 24 additions & 9 deletions lib/agent/grpc/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"main/utils"
"slices"
"strings"
"time"
)

func storeTotalStats(server *ServerData, rateLimited bool) {
Expand Down Expand Up @@ -304,20 +305,34 @@ func getRateLimitingDataForEndpoint(server *ServerData, method, route, routePars
return wildcardMatches[0]
}

func computeRetryAfterSeconds(sw *SlidingWindow, windowSizeInMinutes int) int32 {
windowSizeInSeconds := int32(windowSizeInMinutes * 60)
if sw == nil {
return windowSizeInSeconds
}
elapsed := int32(time.Since(sw.CreatedAt).Seconds())
retryAfter := windowSizeInSeconds - elapsed
if retryAfter < 1 {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment why?

retryAfter = 1
}
return retryAfter
}

func isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch *RateLimitingValue,
countsMap map[string]*SlidingWindow,
key string,
) bool {
) (bool, int32) {
Comment thread
PopoviciMarian marked this conversation as resolved.
rateLimitingDataMatch.Mutex.Lock()
defer rateLimitingDataMatch.Mutex.Unlock()

if isRateLimitingThresholdExceeded(&rateLimitingDataMatch.Config, countsMap, key) {
return true
retryAfter := computeRetryAfterSeconds(countsMap[key], rateLimitingDataMatch.Config.WindowSizeInMinutes)
return true, retryAfter
}

incrementSlidingWindowEntry(countsMap, key)

return false
return false, 0
}

func getRateLimitingStatus(server *ServerData, method, route, routeParsed, user, ip, rateLimitGroup string) *protos.RateLimitingStatus {
Expand All @@ -333,21 +348,21 @@ func getRateLimitingStatus(server *ServerData, method, route, routeParsed, user,

if rateLimitGroup != "" {
// If the rate limit group exists, we only try to rate limit by rate limit group
if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.RateLimitGroupCounts, rateLimitGroup) {
if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.RateLimitGroupCounts, rateLimitGroup); exceeded {
log.Infof(server.Logger, "Rate limited request for group %s - %s %s - %v", rateLimitGroup, method, routeParsed, rateLimitingDataMatch.RateLimitGroupCounts[rateLimitGroup])
return &protos.RateLimitingStatus{Block: true, Trigger: "group"}
return &protos.RateLimitingStatus{Block: true, Trigger: "group", RetryAfterSeconds: retryAfterSeconds}
}
} else if user != "" {
// Otherwise, if the user exists, we try to rate limit by user
if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.UserCounts, user) {
if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.UserCounts, user); exceeded {
log.Infof(server.Logger, "Rate limited request for user %s - %s %s - %v", user, method, routeParsed, rateLimitingDataMatch.UserCounts[user])
return &protos.RateLimitingStatus{Block: true, Trigger: "user"}
return &protos.RateLimitingStatus{Block: true, Trigger: "user", RetryAfterSeconds: retryAfterSeconds}
}
} else {
// Otherwise, we try to rate limit by ip
if isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.IpCounts, ip) {
if exceeded, retryAfterSeconds := isRateLimitingThresholdExceededAndIncrement(rateLimitingDataMatch, rateLimitingDataMatch.IpCounts, ip); exceeded {
log.Infof(server.Logger, "Rate limited request for ip %s - %s %s - %v", ip, method, routeParsed, rateLimitingDataMatch.IpCounts[ip])
return &protos.RateLimitingStatus{Block: true, Trigger: "ip"}
return &protos.RateLimitingStatus{Block: true, Trigger: "ip", RetryAfterSeconds: retryAfterSeconds}
}
}

Expand Down
40 changes: 40 additions & 0 deletions lib/agent/grpc/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package grpc
import (
"sync"
"testing"
"time"

. "main/aikido_types"
"main/utils"
Expand Down Expand Up @@ -72,3 +73,42 @@ func TestAttackWaveThrottling(t *testing.T) {
assert.True(t, server.AttackWave.LastSent[ip] > 0, "LastSent should be set after event is sent")
})
}

func TestComputeRetryAfterSeconds(t *testing.T) {
t.Run("returns full window size when sliding window is nil", func(t *testing.T) {
result := computeRetryAfterSeconds(nil, 5)
assert.Equal(t, int32(300), result)
})

t.Run("returns full window size for a freshly created window", func(t *testing.T) {
sw := NewSlidingWindow()
result := computeRetryAfterSeconds(sw, 5)
assert.True(t, result >= 299 && result <= 300, "expected ~300, got %d", result)
})

t.Run("decreases over time", func(t *testing.T) {
sw := NewSlidingWindow()
sw.CreatedAt = time.Now().Add(-60 * time.Second)
result := computeRetryAfterSeconds(sw, 5)
assert.True(t, result >= 239 && result <= 241, "expected ~240, got %d", result)
})

t.Run("clamps to 1 when window has expired", func(t *testing.T) {
sw := NewSlidingWindow()
sw.CreatedAt = time.Now().Add(-600 * time.Second)
result := computeRetryAfterSeconds(sw, 5)
assert.Equal(t, int32(1), result)
})

t.Run("stays accurate after CreatedAt is advanced by eviction", func(t *testing.T) {
sw := NewSlidingWindow()
sw.CreatedAt = time.Now().Add(-30 * time.Second)

// Simulate one eviction advancing CreatedAt by 1 minute
sw.CreatedAt = sw.CreatedAt.Add(time.Minute)

result := computeRetryAfterSeconds(sw, 2) // 2 min window = 120s
// CreatedAt is now ~30 seconds in the future, so retryAfter ≈ 120 + 30 = 150
assert.True(t, result >= 149 && result <= 151, "expected ~150, got %d", result)
})
}
1 change: 1 addition & 0 deletions lib/ipc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ message CloudConfig {
message RateLimitingStatus {
bool block = 1;
string trigger = 2;
int32 retry_after_seconds = 3;
}

message User {
Expand Down
8 changes: 8 additions & 0 deletions lib/php-extension/Action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ ACTION_STATUS Action::executeStore(json &event) {
if (trigger == "user-agent") {
userAgent = event["user-agent"];
}
if (event.contains("retryAfterSeconds") && event["retryAfterSeconds"].is_number()) {
retryAfterSeconds = event["retryAfterSeconds"].get<int>();
}
return CONTINUE;
}

Expand Down Expand Up @@ -102,6 +105,7 @@ void Action::Reset() {
description = "";
ip = "";
userAgent = "";
retryAfterSeconds = 0;
}

bool Action::Exit() {
Expand Down Expand Up @@ -135,3 +139,7 @@ char *Action::Ip() {
char *Action::UserAgent() {
return (char *)userAgent.c_str();
}

int Action::RetryAfterSeconds() {
return retryAfterSeconds;
}
2 changes: 2 additions & 0 deletions lib/php-extension/HandleShouldBlockRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ ZEND_FUNCTION(should_block_request) {
zend_update_property_string(blockingStatusClass, obj, "description", sizeof("description") - 1, action.Description());
zend_update_property_string(blockingStatusClass, obj, "ip", sizeof("ip") - 1, action.Ip());
zend_update_property_string(blockingStatusClass, obj, "user_agent", sizeof("user_agent") - 1, action.UserAgent());
zend_update_property_long(blockingStatusClass, obj, "retry_after_seconds", sizeof("retry_after_seconds") - 1, action.RetryAfterSeconds());
}

ZEND_FUNCTION(auto_block_request) {
Expand Down Expand Up @@ -145,6 +146,7 @@ void RegisterAikidoBlockRequestStatusClass() {
zend_declare_property_string(blockingStatusClass, "description", sizeof("description") - 1, "", ZEND_ACC_PUBLIC);
zend_declare_property_string(blockingStatusClass, "ip", sizeof("ip") - 1, "", ZEND_ACC_PUBLIC);
zend_declare_property_string(blockingStatusClass, "user_agent", sizeof("user_agent") - 1, "", ZEND_ACC_PUBLIC);
zend_declare_property_long(blockingStatusClass, "retry_after_seconds", sizeof("retry_after_seconds") - 1, 0, ZEND_ACC_PUBLIC);
}

void RegisterAikidoWhitelistRequestStatusClass() {
Expand Down
2 changes: 2 additions & 0 deletions lib/php-extension/include/Action.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Action {
std::string description;
std::string ip;
std::string userAgent;
int retryAfterSeconds = 0;

private:
ACTION_STATUS executeThrow(json &event);
Expand Down Expand Up @@ -49,4 +50,5 @@ class Action {
char* Description();
char* Ip();
char* UserAgent();
int RetryAfterSeconds();
};
20 changes: 19 additions & 1 deletion lib/request-processor/handle_blocking_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ func GetAction(actionHandling, actionType, trigger, description, data string, re
return string(actionJson)
}

func GetRateLimitedAction(trigger, ip string, retryAfterSeconds int) string {
actionMap := map[string]interface{}{
"action": "store",
"type": "ratelimited",
"trigger": trigger,
"description": html.EscapeString("configured rate limit exceeded by current ip"),
"message": fmt.Sprintf("Your %s (%s) is blocked due to: %s!", trigger, ip, "configured rate limit exceeded by current ip"),
trigger: ip,
"response_code": 429,
"retryAfterSeconds": retryAfterSeconds,
}
actionJson, err := json.Marshal(actionMap)
if err != nil {
return ""
}
return string(actionJson)
}

func OnGetBlockingStatus(instance *instance.RequestProcessorInstance) string {
log.Debugf(instance, "OnGetBlockingStatus called!")

Expand Down Expand Up @@ -68,7 +86,7 @@ func OnGetBlockingStatus(instance *instance.RequestProcessorInstance) string {
if rateLimitingStatus != nil && rateLimitingStatus.Block {
context.ContextSetIsEndpointRateLimited(instance)
log.Infof(instance, "Request made from IP \"%s\" is ratelimited by \"%s\"!", ip, rateLimitingStatus.Trigger)
return GetAction("store", "ratelimited", rateLimitingStatus.Trigger, "configured rate limit exceeded by current ip", ip, 429)
return GetRateLimitedAction(rateLimitingStatus.Trigger, ip, int(rateLimitingStatus.RetryAfterSeconds))
}
}

Expand Down
5 changes: 5 additions & 0 deletions tests/server/test_rate_limiting_retry_after_header/env.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"AIKIDO_BLOCK": "1",
"AIKIDO_LOCALHOST_ALLOWED_BY_DEFAULT": "0",
"AIKIDO_FEATURE_COLLECT_API_SCHEMA": "1"
}
Loading
Loading