diff --git a/cli/azd/pkg/auth/caching_credential.go b/cli/azd/pkg/auth/caching_credential.go new file mode 100644 index 00000000000..81690e45f06 --- /dev/null +++ b/cli/azd/pkg/auth/caching_credential.go @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "context" + "strconv" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "golang.org/x/sync/singleflight" +) + +// tokenRefreshOffset is how long before a cached token's expiry it is treated as stale and +// re-acquired. It mirrors the default refresh window used by the azcore bearer-token policy. +const tokenRefreshOffset = 5 * time.Minute + +// cachingCredential wraps a TokenCredential with an in-memory token cache. +// +// It exists primarily for the Azure CLI delegated-auth path. AzureCLICredential performs no +// caching of its own and spawns an `az account get-access-token` subprocess on every GetToken +// call. Flows that fan out many concurrent Azure SDK clients (for example, listing the AI model +// catalog across every region) otherwise trigger one `az` subprocess per request, serialized +// behind the credential's internal mutex, which is slow. Caching collapses repeated requests for +// the same scope/tenant into a single subprocess invocation and reuses the token until it is near +// expiry. The cache is in-memory only and lives for the lifetime of the credential instance. +type cachingCredential struct { + inner azcore.TokenCredential + + mu sync.RWMutex + cache map[string]azcore.AccessToken + + // group deduplicates concurrent acquisitions for the same cache key so that only a single + // inner GetToken call (and thus a single `az` subprocess) runs at a time per key. + group singleflight.Group +} + +// newCachingCredential wraps inner with an in-memory token cache. +func newCachingCredential(inner azcore.TokenCredential) *cachingCredential { + return &cachingCredential{ + inner: inner, + cache: map[string]azcore.AccessToken{}, + } +} + +// GetToken returns a cached token for the requested options when one is available and not near +// expiry; otherwise it acquires a new token from the wrapped credential and caches it. +func (c *cachingCredential) GetToken( + ctx context.Context, + opts policy.TokenRequestOptions, +) (azcore.AccessToken, error) { + key := tokenCacheKey(opts) + + if tk, ok := c.cachedToken(key); ok { + return tk, nil + } + + // singleflight shares the result of the first in-flight acquisition for this key with all + // concurrent callers, so N goroutines requesting the same scope share a single `az` call. + result, err, _ := c.group.Do(key, func() (any, error) { + // Another goroutine may have populated the cache while this call was queued. + if tk, ok := c.cachedToken(key); ok { + return tk, nil + } + + tk, err := c.inner.GetToken(ctx, opts) + if err != nil { + return azcore.AccessToken{}, err + } + + c.mu.Lock() + c.cache[key] = tk + c.mu.Unlock() + + return tk, nil + }) + if err != nil { + return azcore.AccessToken{}, err + } + + return result.(azcore.AccessToken), nil +} + +// cachedToken returns the cached token for key when present and not within the refresh offset of +// its expiry. +func (c *cachingCredential) cachedToken(key string) (azcore.AccessToken, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + tk, ok := c.cache[key] + if !ok { + return azcore.AccessToken{}, false + } + + if time.Now().Add(tokenRefreshOffset).After(tk.ExpiresOn) { + return azcore.AccessToken{}, false + } + + return tk, true +} + +// tokenCacheKey derives a stable cache key from the token request options. Tokens differ by their +// requested scopes, tenant, CAE flag, and any claims challenge, so all are included in the key. +func tokenCacheKey(opts policy.TokenRequestOptions) string { + return strings.Join(opts.Scopes, " ") + "\n" + + opts.TenantID + "\n" + + opts.Claims + "\n" + + strconv.FormatBool(opts.EnableCAE) +} + +var _ azcore.TokenCredential = (*cachingCredential)(nil) diff --git a/cli/azd/pkg/auth/caching_credential_test.go b/cli/azd/pkg/auth/caching_credential_test.go new file mode 100644 index 00000000000..11dcec1a12e --- /dev/null +++ b/cli/azd/pkg/auth/caching_credential_test.go @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/stretchr/testify/require" +) + +// countingCredential is a TokenCredential that records how many times GetToken is invoked and, +// optionally, blocks until released so concurrent behavior can be exercised deterministically. +type countingCredential struct { + calls atomic.Int32 + token azcore.AccessToken + err error + gate chan struct{} + gateReady chan struct{} +} + +func (c *countingCredential) GetToken( + _ context.Context, + _ policy.TokenRequestOptions, +) (azcore.AccessToken, error) { + c.calls.Add(1) + if c.gate != nil { + c.gateReady <- struct{}{} + <-c.gate + } + if c.err != nil { + return azcore.AccessToken{}, c.err + } + return c.token, nil +} + +func TestCachingCredentialReusesToken(t *testing.T) { + inner := &countingCredential{ + token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, + } + cred := newCachingCredential(inner) + + opts := policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}} + + for range 5 { + tk, err := cred.GetToken(t.Context(), opts) + require.NoError(t, err) + require.Equal(t, "abc", tk.Token) + } + + // All five requests for the same scope should have hit the underlying credential exactly once. + require.Equal(t, int32(1), inner.calls.Load()) +} + +func TestCachingCredentialSeparatesByScope(t *testing.T) { + inner := &countingCredential{ + token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, + } + cred := newCachingCredential(inner) + + _, err := cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-a"}}) + require.NoError(t, err) + _, err = cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-b"}}) + require.NoError(t, err) + + // Distinct scopes are cached independently, so each triggers its own acquisition. + require.Equal(t, int32(2), inner.calls.Load()) +} + +func TestCachingCredentialRefreshesNearExpiry(t *testing.T) { + inner := &countingCredential{ + // Token already within the refresh offset of expiry, so it must not be served from cache. + token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Minute)}, + } + cred := newCachingCredential(inner) + + opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} + + _, err := cred.GetToken(t.Context(), opts) + require.NoError(t, err) + _, err = cred.GetToken(t.Context(), opts) + require.NoError(t, err) + + require.Equal(t, int32(2), inner.calls.Load()) +} + +func TestCachingCredentialDoesNotCacheErrors(t *testing.T) { + inner := &countingCredential{err: errors.New("boom")} + cred := newCachingCredential(inner) + + opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} + + // A failed acquisition must not be written to the cache, otherwise a single transient `az` + // failure would wedge auth for the rest of the command. + _, err := cred.GetToken(t.Context(), opts) + require.Error(t, err) + + // A subsequent successful acquisition for the same key should reach the inner credential and + // return the fresh token rather than a cached error. + inner.err = nil + inner.token = azcore.AccessToken{Token: "ok", ExpiresOn: time.Now().Add(time.Hour)} + + tk, err := cred.GetToken(t.Context(), opts) + require.NoError(t, err) + require.Equal(t, "ok", tk.Token) + require.Equal(t, int32(2), inner.calls.Load()) +} + +func TestCachingCredentialSingleFlight(t *testing.T) { + inner := &countingCredential{ + token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)}, + gate: make(chan struct{}), + gateReady: make(chan struct{}), + } + cred := newCachingCredential(inner) + + opts := policy.TokenRequestOptions{Scopes: []string{"scope"}} + + const goroutines = 8 + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + tk, err := cred.GetToken(t.Context(), opts) + require.NoError(t, err) + require.Equal(t, "abc", tk.Token) + }) + } + + // Wait until the single in-flight acquisition is running, then release it. The remaining + // goroutines should resolve from the shared singleflight result rather than calling the inner + // credential again. + <-inner.gateReady + close(inner.gate) + wg.Wait() + + require.Equal(t, int32(1), inner.calls.Load()) +} diff --git a/cli/azd/pkg/auth/manager.go b/cli/azd/pkg/auth/manager.go index f91566c8c3c..cb73e6ab342 100644 --- a/cli/azd/pkg/auth/manager.go +++ b/cli/azd/pkg/auth/manager.go @@ -16,6 +16,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "github.com/Azure/azure-sdk-for-go/sdk/azcore" azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" @@ -100,6 +101,13 @@ type Manager struct { externalAuthCfg ExternalAuthConfiguration azCli az.AzCli userAgent string + + // azCliCredentials caches az CLI credentials keyed by tenant ID when auth.useAzCliAuth is set. + // Each entry is a cachingCredential wrapping an AzureCLICredential. Sharing a single instance per + // tenant lets concurrent callers reuse one in-memory token instead of each spawning its own `az` + // subprocess, which avoids exhausting memory and serializing many `az` processes in parallel. + azCliCredentials map[string]azcore.TokenCredential + azCliCredentialsMu sync.Mutex } // UserAgent is a typed string for the application user-agent, @@ -169,6 +177,7 @@ func NewManager( externalAuthCfg: externalAuthCfg, azCli: azCli, userAgent: string(userAgent), + azCliCredentials: map[string]azcore.TokenCredential{}, }, nil } @@ -258,13 +267,7 @@ func (m *Manager) CredentialForCurrentUser( if shouldUseLegacyAuth(userConfig) { log.Printf("delegating auth to az since %s is set to true", useAzCliAuthKey) - cred, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ - TenantID: options.TenantID, - }) - if err != nil { - return nil, fmt.Errorf("failed to create credential: %w: %w", err, ErrNoCurrentUser) - } - return cred, nil + return m.azCliCredentialForTenant(options.TenantID) } authConfig, err := m.readAuthConfig() @@ -396,6 +399,38 @@ func (m *Manager) CredentialForCurrentUser( type ClaimsForCurrentUserOptions = CredentialForCurrentUserOptions +// azCliCredentialForTenant returns a cached az CLI credential for the given tenant, creating one if needed. +// +// The returned credential is a cachingCredential wrapping an AzureCLICredential. A single instance is +// shared per tenant so that concurrent callers reuse one in-memory token. AzureCLICredential performs no +// caching of its own and spawns an `az account get-access-token` subprocess on every GetToken call, so +// without this wrapper a fan-out of concurrent Azure SDK clients (e.g. listing the AI model catalog across +// every region) would spawn one `az` subprocess per caller, serialized behind the credential's internal +// mutex. That is both slow and, in constrained environments like Cloud Shell, memory-exhausting. +func (m *Manager) azCliCredentialForTenant(tenantID string) (azcore.TokenCredential, error) { + m.azCliCredentialsMu.Lock() + defer m.azCliCredentialsMu.Unlock() + + if cred, ok := m.azCliCredentials[tenantID]; ok { + return cred, nil + } + + azCred, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ + TenantID: tenantID, + }) + if err != nil { + return nil, fmt.Errorf("failed to create credential: %w: %w", err, ErrNoCurrentUser) + } + + cred := newCachingCredential(azCred) + + if m.azCliCredentials == nil { + m.azCliCredentials = map[string]azcore.TokenCredential{} + } + m.azCliCredentials[tenantID] = cred + return cred, nil +} + // ClaimsForCurrentUser returns claims for the currently logged in user. func (m *Manager) ClaimsForCurrentUser(ctx context.Context, options *ClaimsForCurrentUserOptions) (TokenClaims, error) { if options == nil { diff --git a/cli/azd/pkg/auth/manager_coverage_test.go b/cli/azd/pkg/auth/manager_coverage_test.go index ea7700ff8cd..97811003dab 100644 --- a/cli/azd/pkg/auth/manager_coverage_test.go +++ b/cli/azd/pkg/auth/manager_coverage_test.go @@ -238,7 +238,8 @@ func TestCredentialForCurrentUser_LegacyAuth_Error(t *testing.T) { // With default options (nil) cred, err := m.CredentialForCurrentUser(t.Context(), nil) require.NoError(t, err) - require.IsType(t, new(azidentity.AzureCLICredential), cred) + require.IsType(t, new(cachingCredential), cred) + require.IsType(t, new(azidentity.AzureCLICredential), cred.(*cachingCredential).inner) } func TestCredentialForCurrentUser_LegacyAuthWithTenant(t *testing.T) { @@ -256,7 +257,8 @@ func TestCredentialForCurrentUser_LegacyAuthWithTenant(t *testing.T) { TenantID: "my-tenant", }) require.NoError(t, err) - require.IsType(t, new(azidentity.AzureCLICredential), cred) + require.IsType(t, new(cachingCredential), cred) + require.IsType(t, new(azidentity.AzureCLICredential), cred.(*cachingCredential).inner) } func TestCredentialForCurrentUser_ManagedIdentityNoClientID(t *testing.T) { diff --git a/cli/azd/pkg/auth/manager_test.go b/cli/azd/pkg/auth/manager_test.go index d29ca303d65..668f063c917 100644 --- a/cli/azd/pkg/auth/manager_test.go +++ b/cli/azd/pkg/auth/manager_test.go @@ -187,7 +187,43 @@ func TestLegacyAzCliCredentialSupport(t *testing.T) { cred, err := m.CredentialForCurrentUser(t.Context(), nil) require.NoError(t, err) - require.IsType(t, new(azidentity.AzureCLICredential), cred) + // The credential is wrapped in a cachingCredential that reuses tokens across concurrent callers, + // backed by an AzureCLICredential. + require.IsType(t, new(cachingCredential), cred) + require.IsType(t, new(azidentity.AzureCLICredential), cred.(*cachingCredential).inner) +} + +func TestLegacyAzCliCredentialIsCached(t *testing.T) { + mgr := newMemoryUserConfigManager() + + cfg, err := mgr.Load() + require.NoError(t, err) + + err = cfg.Set(useAzCliAuthKey, "true") + require.NoError(t, err) + + err = mgr.Save(cfg) + require.NoError(t, err) + + m := Manager{ + userConfigManager: mgr, + } + + // The same credential instance should be returned on subsequent calls for the same tenant so that the + // azidentity SDK can collapse concurrent token requests into a single `az` subprocess. + first, err := m.CredentialForCurrentUser(t.Context(), nil) + require.NoError(t, err) + + second, err := m.CredentialForCurrentUser(t.Context(), nil) + require.NoError(t, err) + + require.Same(t, first, second) + + // A different tenant should yield a distinct credential instance. + other, err := m.CredentialForCurrentUser(t.Context(), &CredentialForCurrentUserOptions{TenantID: "other-tenant"}) + require.NoError(t, err) + + require.NotSame(t, first, other) } func TestCloudShellCredentialSupport(t *testing.T) {