Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions cli/azd/pkg/auth/caching_credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package auth

import (
"context"
"strconv"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"golang.org/x/sync/singleflight"
)

// tokenRefreshOffset is how long before a cached token's expiry it is treated as stale and
// re-acquired. It mirrors the default refresh window used by the azcore bearer-token policy.
const tokenRefreshOffset = 5 * time.Minute

// cachingCredential wraps a TokenCredential with an in-memory token cache.
//
// It exists primarily for the Azure CLI delegated-auth path. AzureCLICredential performs no
// caching of its own and spawns an `az account get-access-token` subprocess on every GetToken
// call. Flows that fan out many concurrent Azure SDK clients (for example, listing the AI model
// catalog across every region) otherwise trigger one `az` subprocess per request, serialized
// behind the credential's internal mutex, which is slow. Caching collapses repeated requests for
// the same scope/tenant into a single subprocess invocation and reuses the token until it is near
// expiry. The cache is in-memory only and lives for the lifetime of the credential instance.
type cachingCredential struct {
inner azcore.TokenCredential

mu sync.RWMutex
cache map[string]azcore.AccessToken

// group deduplicates concurrent acquisitions for the same cache key so that only a single
// inner GetToken call (and thus a single `az` subprocess) runs at a time per key.
group singleflight.Group
}

// newCachingCredential wraps inner with an in-memory token cache.
func newCachingCredential(inner azcore.TokenCredential) *cachingCredential {
return &cachingCredential{
inner: inner,
cache: map[string]azcore.AccessToken{},
}
}

// GetToken returns a cached token for the requested options when one is available and not near
// expiry; otherwise it acquires a new token from the wrapped credential and caches it.
func (c *cachingCredential) GetToken(
ctx context.Context,
opts policy.TokenRequestOptions,
) (azcore.AccessToken, error) {
key := tokenCacheKey(opts)

if tk, ok := c.cachedToken(key); ok {
return tk, nil
}

// singleflight shares the result of the first in-flight acquisition for this key with all
// concurrent callers, so N goroutines requesting the same scope share a single `az` call.
result, err, _ := c.group.Do(key, func() (any, error) {
// Another goroutine may have populated the cache while this call was queued.
if tk, ok := c.cachedToken(key); ok {
return tk, nil
}

tk, err := c.inner.GetToken(ctx, opts)
if err != nil {
return azcore.AccessToken{}, err
}

c.mu.Lock()
c.cache[key] = tk
c.mu.Unlock()

return tk, nil
})
if err != nil {
return azcore.AccessToken{}, err
}

return result.(azcore.AccessToken), nil
}

// cachedToken returns the cached token for key when present and not within the refresh offset of
// its expiry.
func (c *cachingCredential) cachedToken(key string) (azcore.AccessToken, bool) {
c.mu.RLock()
defer c.mu.RUnlock()

tk, ok := c.cache[key]
if !ok {
return azcore.AccessToken{}, false
}

if time.Now().Add(tokenRefreshOffset).After(tk.ExpiresOn) {
return azcore.AccessToken{}, false
}

return tk, true
}

// tokenCacheKey derives a stable cache key from the token request options. Tokens differ by their
// requested scopes, tenant, CAE flag, and any claims challenge, so all are included in the key.
func tokenCacheKey(opts policy.TokenRequestOptions) string {
return strings.Join(opts.Scopes, " ") + "\n" +
opts.TenantID + "\n" +
opts.Claims + "\n" +
strconv.FormatBool(opts.EnableCAE)
}

var _ azcore.TokenCredential = (*cachingCredential)(nil)
144 changes: 144 additions & 0 deletions cli/azd/pkg/auth/caching_credential_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package auth

import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/stretchr/testify/require"
)

// countingCredential is a TokenCredential that records how many times GetToken is invoked and,
// optionally, blocks until released so concurrent behavior can be exercised deterministically.
type countingCredential struct {
calls atomic.Int32
token azcore.AccessToken
err error
gate chan struct{}
gateReady chan struct{}
}

func (c *countingCredential) GetToken(
_ context.Context,
_ policy.TokenRequestOptions,
) (azcore.AccessToken, error) {
c.calls.Add(1)
if c.gate != nil {
c.gateReady <- struct{}{}
<-c.gate
}
if c.err != nil {
return azcore.AccessToken{}, c.err
}
return c.token, nil
}

func TestCachingCredentialReusesToken(t *testing.T) {
inner := &countingCredential{
token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)},
}
cred := newCachingCredential(inner)

opts := policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}}

for range 5 {
tk, err := cred.GetToken(t.Context(), opts)
require.NoError(t, err)
require.Equal(t, "abc", tk.Token)
}

// All five requests for the same scope should have hit the underlying credential exactly once.
require.Equal(t, int32(1), inner.calls.Load())
}

func TestCachingCredentialSeparatesByScope(t *testing.T) {
inner := &countingCredential{
token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)},
}
cred := newCachingCredential(inner)

_, err := cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-a"}})
require.NoError(t, err)
_, err = cred.GetToken(t.Context(), policy.TokenRequestOptions{Scopes: []string{"scope-b"}})
require.NoError(t, err)

// Distinct scopes are cached independently, so each triggers its own acquisition.
require.Equal(t, int32(2), inner.calls.Load())
}

func TestCachingCredentialRefreshesNearExpiry(t *testing.T) {
inner := &countingCredential{
// Token already within the refresh offset of expiry, so it must not be served from cache.
token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Minute)},
}
cred := newCachingCredential(inner)

opts := policy.TokenRequestOptions{Scopes: []string{"scope"}}

_, err := cred.GetToken(t.Context(), opts)
require.NoError(t, err)
_, err = cred.GetToken(t.Context(), opts)
require.NoError(t, err)

require.Equal(t, int32(2), inner.calls.Load())
}

func TestCachingCredentialDoesNotCacheErrors(t *testing.T) {
inner := &countingCredential{err: errors.New("boom")}
cred := newCachingCredential(inner)

opts := policy.TokenRequestOptions{Scopes: []string{"scope"}}

// A failed acquisition must not be written to the cache, otherwise a single transient `az`
// failure would wedge auth for the rest of the command.
_, err := cred.GetToken(t.Context(), opts)
require.Error(t, err)

// A subsequent successful acquisition for the same key should reach the inner credential and
// return the fresh token rather than a cached error.
inner.err = nil
inner.token = azcore.AccessToken{Token: "ok", ExpiresOn: time.Now().Add(time.Hour)}

tk, err := cred.GetToken(t.Context(), opts)
require.NoError(t, err)
require.Equal(t, "ok", tk.Token)
require.Equal(t, int32(2), inner.calls.Load())
}

func TestCachingCredentialSingleFlight(t *testing.T) {
inner := &countingCredential{
token: azcore.AccessToken{Token: "abc", ExpiresOn: time.Now().Add(time.Hour)},
gate: make(chan struct{}),
gateReady: make(chan struct{}),
}
cred := newCachingCredential(inner)

opts := policy.TokenRequestOptions{Scopes: []string{"scope"}}

const goroutines = 8
var wg sync.WaitGroup
for range goroutines {
wg.Go(func() {
tk, err := cred.GetToken(t.Context(), opts)
require.NoError(t, err)
require.Equal(t, "abc", tk.Token)
})
}

// Wait until the single in-flight acquisition is running, then release it. The remaining
// goroutines should resolve from the shared singleflight result rather than calling the inner
// credential again.
<-inner.gateReady
close(inner.gate)
wg.Wait()

require.Equal(t, int32(1), inner.calls.Load())
}
Comment thread
JeffreyCA marked this conversation as resolved.
49 changes: 42 additions & 7 deletions cli/azd/pkg/auth/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
Expand Down Expand Up @@ -100,6 +101,13 @@ type Manager struct {
externalAuthCfg ExternalAuthConfiguration
azCli az.AzCli
userAgent string

// azCliCredentials caches az CLI credentials keyed by tenant ID when auth.useAzCliAuth is set.
// Each entry is a cachingCredential wrapping an AzureCLICredential. Sharing a single instance per
// tenant lets concurrent callers reuse one in-memory token instead of each spawning its own `az`
// subprocess, which avoids exhausting memory and serializing many `az` processes in parallel.
azCliCredentials map[string]azcore.TokenCredential
azCliCredentialsMu sync.Mutex
}

// UserAgent is a typed string for the application user-agent,
Expand Down Expand Up @@ -169,6 +177,7 @@ func NewManager(
externalAuthCfg: externalAuthCfg,
azCli: azCli,
userAgent: string(userAgent),
azCliCredentials: map[string]azcore.TokenCredential{},
}, nil
}

Expand Down Expand Up @@ -258,13 +267,7 @@ func (m *Manager) CredentialForCurrentUser(

if shouldUseLegacyAuth(userConfig) {
log.Printf("delegating auth to az since %s is set to true", useAzCliAuthKey)
cred, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
TenantID: options.TenantID,
})
if err != nil {
return nil, fmt.Errorf("failed to create credential: %w: %w", err, ErrNoCurrentUser)
}
return cred, nil
return m.azCliCredentialForTenant(options.TenantID)
}

authConfig, err := m.readAuthConfig()
Expand Down Expand Up @@ -396,6 +399,38 @@ func (m *Manager) CredentialForCurrentUser(

type ClaimsForCurrentUserOptions = CredentialForCurrentUserOptions

// azCliCredentialForTenant returns a cached az CLI credential for the given tenant, creating one if needed.
//
// The returned credential is a cachingCredential wrapping an AzureCLICredential. A single instance is
// shared per tenant so that concurrent callers reuse one in-memory token. AzureCLICredential performs no
// caching of its own and spawns an `az account get-access-token` subprocess on every GetToken call, so
// without this wrapper a fan-out of concurrent Azure SDK clients (e.g. listing the AI model catalog across
// every region) would spawn one `az` subprocess per caller, serialized behind the credential's internal
// mutex. That is both slow and, in constrained environments like Cloud Shell, memory-exhausting.
func (m *Manager) azCliCredentialForTenant(tenantID string) (azcore.TokenCredential, error) {
Comment thread
JeffreyCA marked this conversation as resolved.
m.azCliCredentialsMu.Lock()
defer m.azCliCredentialsMu.Unlock()

if cred, ok := m.azCliCredentials[tenantID]; ok {
return cred, nil
}

azCred, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
TenantID: tenantID,
})
if err != nil {
return nil, fmt.Errorf("failed to create credential: %w: %w", err, ErrNoCurrentUser)
}

cred := newCachingCredential(azCred)

if m.azCliCredentials == nil {
m.azCliCredentials = map[string]azcore.TokenCredential{}
}
m.azCliCredentials[tenantID] = cred
return cred, nil
}

// ClaimsForCurrentUser returns claims for the currently logged in user.
func (m *Manager) ClaimsForCurrentUser(ctx context.Context, options *ClaimsForCurrentUserOptions) (TokenClaims, error) {
if options == nil {
Expand Down
6 changes: 4 additions & 2 deletions cli/azd/pkg/auth/manager_coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ func TestCredentialForCurrentUser_LegacyAuth_Error(t *testing.T) {
// With default options (nil)
cred, err := m.CredentialForCurrentUser(t.Context(), nil)
require.NoError(t, err)
require.IsType(t, new(azidentity.AzureCLICredential), cred)
require.IsType(t, new(cachingCredential), cred)
require.IsType(t, new(azidentity.AzureCLICredential), cred.(*cachingCredential).inner)
}

func TestCredentialForCurrentUser_LegacyAuthWithTenant(t *testing.T) {
Expand All @@ -256,7 +257,8 @@ func TestCredentialForCurrentUser_LegacyAuthWithTenant(t *testing.T) {
TenantID: "my-tenant",
})
require.NoError(t, err)
require.IsType(t, new(azidentity.AzureCLICredential), cred)
require.IsType(t, new(cachingCredential), cred)
require.IsType(t, new(azidentity.AzureCLICredential), cred.(*cachingCredential).inner)
}

func TestCredentialForCurrentUser_ManagedIdentityNoClientID(t *testing.T) {
Expand Down
Loading
Loading