-
Notifications
You must be signed in to change notification settings - Fork 313
Cache Azure CLI delegated-auth tokens to fix memory exhaustion and slow model catalog loads #8458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Copilot
wants to merge
4
commits into
main
Choose a base branch
from
copilot/fix-azd-ai-agent-init-issue
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+342
−10
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
d843028
Initial plan
Copilot 6dbd7a8
Cache AzureCLICredential per tenant to avoid concurrent az subprocess…
Copilot bf43b46
Cache az CLI tokens in-memory to avoid per-request az subprocess spawns
JeffreyCA 5a4e4d6
Add test verifying caching credential does not cache errors
JeffreyCA File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| package auth | ||
|
|
||
| import ( | ||
| "context" | ||
| "strconv" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/Azure/azure-sdk-for-go/sdk/azcore" | ||
| "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" | ||
| "golang.org/x/sync/singleflight" | ||
| ) | ||
|
|
||
// tokenRefreshOffset is the safety margin applied to a cached token's expiry: once a token is
// this close to expiring it is considered stale and a fresh one is requested instead. The value
// mirrors the default refresh window used by the azcore bearer-token policy.
const tokenRefreshOffset time.Duration = 5 * time.Minute
|
|
||
| // cachingCredential wraps a TokenCredential with an in-memory token cache. | ||
| // | ||
| // It exists primarily for the Azure CLI delegated-auth path. AzureCLICredential performs no | ||
| // caching of its own and spawns an `az account get-access-token` subprocess on every GetToken | ||
| // call. Flows that fan out many concurrent Azure SDK clients (for example, listing the AI model | ||
| // catalog across every region) otherwise trigger one `az` subprocess per request, serialized | ||
| // behind the credential's internal mutex, which is slow. Caching collapses repeated requests for | ||
| // the same scope/tenant into a single subprocess invocation and reuses the token until it is near | ||
| // expiry. The cache is in-memory only and lives for the lifetime of the credential instance. | ||
| type cachingCredential struct { | ||
| inner azcore.TokenCredential | ||
|
|
||
| mu sync.RWMutex | ||
| cache map[string]azcore.AccessToken | ||
|
|
||
| // group deduplicates concurrent acquisitions for the same cache key so that only a single | ||
| // inner GetToken call (and thus a single `az` subprocess) runs at a time per key. | ||
| group singleflight.Group | ||
| } | ||
|
|
||
| // newCachingCredential wraps inner with an in-memory token cache. | ||
| func newCachingCredential(inner azcore.TokenCredential) *cachingCredential { | ||
| return &cachingCredential{ | ||
| inner: inner, | ||
| cache: map[string]azcore.AccessToken{}, | ||
| } | ||
| } | ||
|
|
||
| // GetToken returns a cached token for the requested options when one is available and not near | ||
| // expiry; otherwise it acquires a new token from the wrapped credential and caches it. | ||
| func (c *cachingCredential) GetToken( | ||
| ctx context.Context, | ||
| opts policy.TokenRequestOptions, | ||
| ) (azcore.AccessToken, error) { | ||
| key := tokenCacheKey(opts) | ||
|
|
||
| if tk, ok := c.cachedToken(key); ok { | ||
| return tk, nil | ||
| } | ||
|
|
||
| // singleflight shares the result of the first in-flight acquisition for this key with all | ||
| // concurrent callers, so N goroutines requesting the same scope share a single `az` call. | ||
| result, err, _ := c.group.Do(key, func() (any, error) { | ||
| // Another goroutine may have populated the cache while this call was queued. | ||
| if tk, ok := c.cachedToken(key); ok { | ||
| return tk, nil | ||
| } | ||
|
|
||
| tk, err := c.inner.GetToken(ctx, opts) | ||
| if err != nil { | ||
| return azcore.AccessToken{}, err | ||
| } | ||
|
|
||
| c.mu.Lock() | ||
| c.cache[key] = tk | ||
| c.mu.Unlock() | ||
|
|
||
| return tk, nil | ||
| }) | ||
| if err != nil { | ||
| return azcore.AccessToken{}, err | ||
| } | ||
|
|
||
| return result.(azcore.AccessToken), nil | ||
| } | ||
|
|
||
| // cachedToken returns the cached token for key when present and not within the refresh offset of | ||
| // its expiry. | ||
| func (c *cachingCredential) cachedToken(key string) (azcore.AccessToken, bool) { | ||
| c.mu.RLock() | ||
| defer c.mu.RUnlock() | ||
|
|
||
| tk, ok := c.cache[key] | ||
| if !ok { | ||
| return azcore.AccessToken{}, false | ||
| } | ||
|
|
||
| if time.Now().Add(tokenRefreshOffset).After(tk.ExpiresOn) { | ||
| return azcore.AccessToken{}, false | ||
| } | ||
|
|
||
| return tk, true | ||
| } | ||
|
|
||
| // tokenCacheKey derives a stable cache key from the token request options. Tokens differ by their | ||
| // requested scopes, tenant, CAE flag, and any claims challenge, so all are included in the key. | ||
| func tokenCacheKey(opts policy.TokenRequestOptions) string { | ||
| return strings.Join(opts.Scopes, " ") + "\n" + | ||
| opts.TenantID + "\n" + | ||
| opts.Claims + "\n" + | ||
| strconv.FormatBool(opts.EnableCAE) | ||
| } | ||
|
|
||
| var _ azcore.TokenCredential = (*cachingCredential)(nil) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| package auth | ||
|
|
||
| import ( | ||
| "context" | ||
| "errors" | ||
| "sync" | ||
| "sync/atomic" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/Azure/azure-sdk-for-go/sdk/azcore" | ||
| "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| // countingCredential is a TokenCredential that records how many times GetToken is invoked and, | ||
| // optionally, blocks until released so concurrent behavior can be exercised deterministically. | ||
| type countingCredential struct { | ||
| calls atomic.Int32 | ||
| token azcore.AccessToken | ||
| err error | ||
| gate chan struct{} | ||
| gateReady chan struct{} | ||
| } | ||
|
|
||
| func (c *countingCredential) GetToken( | ||
| _ context.Context, | ||
| _ policy.TokenRequestOptions, | ||
| ) (azcore.AccessToken, error) { | ||
| c.calls.Add(1) | ||
| if c.gate != nil { | ||
| c.gateReady <- struct{}{} | ||
| <-c.gate | ||
| } | ||
| if c.err != nil { | ||
| return azcore.AccessToken{}, c.err | ||
| } | ||
| return c.token, nil | ||
| } | ||
|
|
||
| func TestCachingCredentialReusesToken(t *testing.T) { | ||
| inner := &countingCredential{ | ||
| token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, | ||
| } | ||
| cred := newCachingCredential(inner) | ||
|
|
||
| opts := policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}} | ||
|
|
||
| for range 5 { | ||
| tk, err := cred.GetToken(t.Context(), opts) | ||
| require.NoError(t, err) | ||
| require.Equal(t, "abc", tk.Token) | ||
| } | ||
|
|
||
| // All five requests for the same scope should have hit the underlying credential exactly once. | ||
| require.Equal(t, int32(1), inner.calls.Load()) | ||
| } | ||
|
|
||
| func TestCachingCredentialSeparatesByScope(t *testing.T) { | ||
| inner := &countingCredential{ | ||
| token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, | ||
| } | ||
| cred := newCachingCredential(inner) | ||
|
|
||
| _, err := cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-a"}}) | ||
| require.NoError(t, err) | ||
| _, err = cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-b"}}) | ||
| require.NoError(t, err) | ||
|
|
||
| // Distinct scopes are cached independently, so each triggers its own acquisition. | ||
| require.Equal(t, int32(2), inner.calls.Load()) | ||
| } | ||
|
|
||
| func TestCachingCredentialRefreshesNearExpiry(t *testing.T) { | ||
| inner := &countingCredential{ | ||
| // Token already within the refresh offset of expiry, so it must not be served from cache. | ||
| token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Minute)}, | ||
| } | ||
| cred := newCachingCredential(inner) | ||
|
|
||
| opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} | ||
|
|
||
| _, err := cred.GetToken(t.Context(), opts) | ||
| require.NoError(t, err) | ||
| _, err = cred.GetToken(t.Context(), opts) | ||
| require.NoError(t, err) | ||
|
|
||
| require.Equal(t, int32(2), inner.calls.Load()) | ||
| } | ||
|
|
||
| func TestCachingCredentialDoesNotCacheErrors(t *testing.T) { | ||
| inner := &countingCredential{err: errors.New("boom")} | ||
| cred := newCachingCredential(inner) | ||
|
|
||
| opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} | ||
|
|
||
| // A failed acquisition must not be written to the cache, otherwise a single transient `az` | ||
| // failure would wedge auth for the rest of the command. | ||
| _, err := cred.GetToken(t.Context(), opts) | ||
| require.Error(t, err) | ||
|
|
||
| // A subsequent successful acquisition for the same key should reach the inner credential and | ||
| // return the fresh token rather than a cached error. | ||
| inner.err = nil | ||
| inner.token = azcore.AccessToken{Token: "ok", ExpiresOn: time.Now().Add(time.Hour)} | ||
|
|
||
| tk, err := cred.GetToken(t.Context(), opts) | ||
| require.NoError(t, err) | ||
| require.Equal(t, "ok", tk.Token) | ||
| require.Equal(t, int32(2), inner.calls.Load()) | ||
| } | ||
|
|
||
| func TestCachingCredentialSingleFlight(t *testing.T) { | ||
| inner := &countingCredential{ | ||
| token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, | ||
| gate: make(chan struct{}), | ||
| gateReady: make(chan struct{}), | ||
| } | ||
| cred := newCachingCredential(inner) | ||
|
|
||
| opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} | ||
|
|
||
| const goroutines = 8 | ||
| var wg sync.WaitGroup | ||
| for range goroutines { | ||
| wg.Go(func() { | ||
| tk, err := cred.GetToken(t.Context(), opts) | ||
| require.NoError(t, err) | ||
| require.Equal(t, "abc", tk.Token) | ||
| }) | ||
| } | ||
|
|
||
| // Wait until the single in-flight acquisition is running, then release it. The remaining | ||
| // goroutines should resolve from the shared singleflight result rather than calling the inner | ||
| // credential again. | ||
| <-inner.gateReady | ||
| close(inner.gate) | ||
| wg.Wait() | ||
|
|
||
| require.Equal(t, int32(1), inner.calls.Load()) | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.