diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index 03a521f0986..c38feaf56ad 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -114,20 +114,53 @@ type modelSelector struct { environment *azdext.Environment flags *initFlags + // supportedRegions is the hosted-agent region allowlist, populated once at + // construction by newModelSelector. Pass this to agentModelFilter for any + // "all regions" catalog query so init flows stay within supported regions + // instead of surfacing models from regions hosted agents can't run in. + supportedRegions []string + modelCatalog map[string]*azdext.AiModel locationWarningShown bool } -func (a *InitAction) getModelSelector() *modelSelector { - if a.models == nil { - a.models = &modelSelector{ - azdClient: a.azdClient, - azureContext: a.azureContext, - environment: a.environment, - flags: a.flags, - } +// newModelSelector constructs a modelSelector, fetching the hosted-agent +// supported-regions allowlist eagerly so callers don't need to re-fetch it at +// every catalog-query site. The fetch is itself globally cached by +// supportedRegionsForInit, so repeated construction is cheap. +func newModelSelector( + ctx context.Context, + azdClient *azdext.AzdClient, + azureContext *azdext.AzureContext, + environment *azdext.Environment, + flags *initFlags, +) (*modelSelector, error) { + supportedRegions, err := supportedRegionsForInit(ctx) + if err != nil { + return nil, err + } + + return &modelSelector{ + azdClient: azdClient, + azureContext: azureContext, + environment: environment, + flags: flags, + supportedRegions: supportedRegions, + }, nil +} + +func (a *InitAction) getModelSelector(ctx context.Context) (*modelSelector, error) { + if a.models != nil { + return a.models, nil } - return a.models + + ms, err := newModelSelector(ctx, a.azdClient, a.azureContext, a.environment, a.flags) + if err != nil { + return nil, err + } + + a.models = ms + return ms, nil } // GitHubUrlInfo holds parsed information from a GitHub URL diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go index 7f09764f9dc..7755fb615fb 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go @@ -1007,6 +1007,13 @@ func promptLocationForInit( return locationResponse.Location.Name, nil } +// agentModelFilter builds the ListModels/PromptAiModel filter used by init flows. +// +// Passing a nil or empty locations slice disables region filtering entirely, which +// will surface models from regions that are not supported for hosted agents. Init +// flows should pass the result of supportedRegionsForInit(ctx) (optionally further +// narrowed to the current scope location) so the catalog stays within the +// hosted-agent allowlist. func agentModelFilter(locations []string, excludeModelNames []string) *azdext.AiModelFilterOptions { filter := &azdext.AiModelFilterOptions{ Capabilities: []string{agentsV2ModelCapability}, diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go index 822f1048749..ac3167fe77d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go @@ -963,11 +963,9 @@ func (a *InitFromCodeAction) resolveSelectedModelDeployment( return nil, exterrors.FromAiService(err, exterrors.CodeModelResolutionFailed) } - selector := &modelSelector{ - azdClient: a.azdClient, - azureContext: a.azureContext, - environment: a.environment, - flags: a.flags, + selector, err := newModelSelector(ctx, a.azdClient, a.azureContext, a.environment, a.flags) + if err != nil { + return nil, err } // allowSkip=false: in this recovery path the user already explicitly chose diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go index 9e6178a2e20..7bc53c6be94 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go @@ -50,7 +50,7 @@ func (a *modelSelector) loadAiCatalog(ctx context.Context) error { modelResp, err := a.azdClient.Ai().ListModels(ctx, &azdext.ListModelsRequest{ AzureContext: a.azureContext, - Filter: agentModelFilter(nil, nil), + Filter: agentModelFilter(a.supportedRegions, nil), }) stopErr := spinner.Stop(ctx) if err != nil { @@ -383,7 +383,12 @@ func (a *InitAction) getModelDeploymentDetails( } } - modelDetails, err := a.getModelSelector().getModelDetails(ctx, model.Id, true) + selector, err := a.getModelSelector(ctx) + if err != nil { + return nil, false, err + } + + modelDetails, err := selector.getModelDetails(ctx, model.Id, true) if err != nil { if errors.Is(err, errModelSkipped) { // Propagate the sentinel unwrapped so ProcessModels can detect @@ -683,7 +688,7 @@ func (a *modelSelector) promptForAlternativeModel( regionChoices := []*azdext.SelectChoice{ {Label: fmt.Sprintf("Models available in my current region (%s)", a.azureContext.Scope.Location), Value: "region"}, - {Label: "All available models", Value: "all"}, + {Label: "All models supported for hosted agents", Value: "all"}, } regionResp, err := a.azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ @@ -697,9 +702,12 @@ func (a *modelSelector) promptForAlternativeModel( return nil, fmt.Errorf("failed to prompt for region choice: %w", err) } + // Default to the "all" branch: every model in the catalog, but restricted to + // the hosted-agent-supported regions so we don't surface models from regions + // that hosted agents can't run in. promptReq := &azdext.PromptAiModelRequest{ AzureContext: a.azureContext, - Filter: agentModelFilter(nil, nil), + Filter: agentModelFilter(a.supportedRegions, nil), SelectOptions: &azdext.SelectOptions{ Message: "Select a model", }, @@ -761,7 +769,7 @@ func (a *modelSelector) promptForModelLocationMismatch( choices := []*azdext.SelectChoice{ {Label: modelChoiceLabel, Value: "model"}, - {Label: "Choose a different model (all regions)", Value: "model_all_regions"}, + {Label: "Choose a different model (all supported regions)", Value: "model_all_regions"}, {Label: fmt.Sprintf("Choose a different location for %s", currentModel.Name), Value: "location"}, {Label: "Exit setup", Value: "exit"}, } @@ -837,17 +845,17 @@ func (a *modelSelector) promptForModelLocationMismatch( if selectedChoice == "model_all_regions" { modelResp, err := a.azdClient.Prompt().PromptAiModel(ctx, &azdext.PromptAiModelRequest{ AzureContext: a.azureContext, - Filter: agentModelFilter(nil, []string{currentModel.Name}), + Filter: agentModelFilter(a.supportedRegions, []string{currentModel.Name}), Quota: &azdext.QuotaCheckOptions{ MinRemainingCapacity: 1, }, SelectOptions: &azdext.SelectOptions{ - Message: "Select a model from all regions", + Message: "Select a model from all supported regions", }, }) if err != nil { if hasAiErrorReason(err, azdext.AiErrorReasonNoModelsMatch) { - message = "No alternative models were found across all regions." + message = "No alternative models were found across all supported regions." continue } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models_test.go index 02f3894c3eb..38bb939428b 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_models_test.go @@ -4,11 +4,14 @@ package cmd import ( - "azureaiagent/internal/project" "context" "errors" + "net/http" + "net/http/httptest" "testing" + "azureaiagent/internal/project" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -317,3 +320,83 @@ func TestUpdateEnvLocation(t *testing.T) { }) } } + +func TestNewModelSelector_PopulatesSupportedRegions(t *testing.T) { + resetRegionsCache(t, []string{"eastus2", "westus3"}) + + ms, err := newModelSelector( + t.Context(), + nil, // azdClient is unused by construction + &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}}, + &azdext.Environment{Name: "test-env"}, + &initFlags{}, + ) + require.NoError(t, err) + require.NotNil(t, ms) + assert.Equal(t, []string{"eastus2", "westus3"}, ms.supportedRegions) + // Mutating the returned slice must not affect the global cache. + ms.supportedRegions[0] = "mutated" + again, err := supportedRegionsForInit(t.Context()) + require.NoError(t, err) + assert.Equal(t, []string{"eastus2", "westus3"}, again) +} + +func TestGetModelSelector_MemoizesAcrossCalls(t *testing.T) { + resetRegionsCache(t, []string{"eastus2"}) + + action := &InitAction{ + azureContext: &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}}, + environment: &azdext.Environment{Name: "test-env"}, + flags: &initFlags{}, + } + + first, err := action.getModelSelector(t.Context()) + require.NoError(t, err) + require.NotNil(t, first) + assert.Equal(t, []string{"eastus2"}, first.supportedRegions) + + // Mutate selector state so we can confirm the same instance is returned — + // memoization is what preserves modelCatalog/locationWarningShown across + // the per-model loop in ProcessModels. + first.locationWarningShown = true + first.modelCatalog = map[string]*azdext.AiModel{"gpt-4.1-mini": {Name: "gpt-4.1-mini"}} + + second, err := action.getModelSelector(t.Context()) + require.NoError(t, err) + assert.Same(t, first, second, "getModelSelector must return the cached instance") + assert.True(t, second.locationWarningShown) + assert.Contains(t, second.modelCatalog, "gpt-4.1-mini") +} + +func TestNewModelSelector_PropagatesContextCancellation(t *testing.T) { + // Empty cache forces a fetch; canceled ctx makes the select in + // supportedRegionsForInit return ctx.Err without waiting on the background fetch. + resetRegionsCache(t, nil) + + // Point the fetch at a server that hangs long enough that the canceled-ctx + // branch wins the select even on slow CI. The fetch goroutine itself uses + // context.WithoutCancel and will eventually time out via the fetch's own + // timeout; we don't wait for it. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + t.Cleanup(server.Close) + + prev := hostedAgentRegionsURL + hostedAgentRegionsURL = server.URL + t.Cleanup(func() { hostedAgentRegionsURL = prev }) + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + ms, err := newModelSelector( + ctx, + nil, + &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}}, + &azdext.Environment{Name: "test-env"}, + &initFlags{}, + ) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, ms) +}