Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,53 @@ type modelSelector struct {
environment *azdext.Environment
flags *initFlags

// supportedRegions is the hosted-agent region allowlist, populated once at
// construction by newModelSelector. Pass this to agentModelFilter for any
// "all regions" catalog query so init flows stay within supported regions
// instead of surfacing models from regions hosted agents can't run in.
supportedRegions []string

modelCatalog map[string]*azdext.AiModel
locationWarningShown bool
}

func (a *InitAction) getModelSelector() *modelSelector {
if a.models == nil {
a.models = &modelSelector{
azdClient: a.azdClient,
azureContext: a.azureContext,
environment: a.environment,
flags: a.flags,
}
// newModelSelector builds a modelSelector whose supportedRegions field is
// filled in up front from supportedRegionsForInit, so catalog-query call sites
// can read the hosted-agent region allowlist off the selector instead of
// fetching it themselves. supportedRegionsForInit caches its result globally,
// which keeps repeated construction cheap.
//
// If the region lookup fails, its error is returned unchanged and no selector
// is created.
func newModelSelector(
	ctx context.Context,
	azdClient *azdext.AzdClient,
	azureContext *azdext.AzureContext,
	environment *azdext.Environment,
	flags *initFlags,
) (*modelSelector, error) {
	regions, err := supportedRegionsForInit(ctx)
	if err != nil {
		return nil, err
	}

	selector := &modelSelector{
		azdClient:    azdClient,
		azureContext: azureContext,
		environment:  environment,
		flags:        flags,
	}
	selector.supportedRegions = regions

	return selector, nil
}

// getModelSelector returns the InitAction's modelSelector, building it with
// newModelSelector on first use and storing it in a.models so that later calls
// receive the same instance (and therefore the same modelCatalog and
// locationWarningShown state).
//
// If construction fails, the error is returned and a.models is left nil, so
// the next call attempts construction again.
func (a *InitAction) getModelSelector(ctx context.Context) (*modelSelector, error) {
	// Fast path: a selector was already built by an earlier call.
	if a.models != nil {
		return a.models, nil
	}

	ms, err := newModelSelector(ctx, a.azdClient, a.azureContext, a.environment, a.flags)
	if err != nil {
		return nil, err
	}

	// Only cache on success so a failed lookup is not remembered.
	a.models = ms
	return ms, nil
}

// GitHubUrlInfo holds parsed information from a GitHub URL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,13 @@ func promptLocationForInit(
return locationResponse.Location.Name, nil
}

// agentModelFilter builds the ListModels/PromptAiModel filter used by init flows.
//
// Passing a nil or empty locations slice disables region filtering entirely, which
// will surface models from regions that are not supported for hosted agents. Init
// flows should pass the result of supportedRegionsForInit(ctx) (optionally further
// narrowed to the current scope location) so the catalog stays within the
// hosted-agent allowlist.
func agentModelFilter(locations []string, excludeModelNames []string) *azdext.AiModelFilterOptions {
filter := &azdext.AiModelFilterOptions{
Capabilities: []string{agentsV2ModelCapability},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -963,11 +963,9 @@ func (a *InitFromCodeAction) resolveSelectedModelDeployment(
return nil, exterrors.FromAiService(err, exterrors.CodeModelResolutionFailed)
}

selector := &modelSelector{
azdClient: a.azdClient,
azureContext: a.azureContext,
environment: a.environment,
flags: a.flags,
selector, err := newModelSelector(ctx, a.azdClient, a.azureContext, a.environment, a.flags)
if err != nil {
return nil, err
}

// allowSkip=false: in this recovery path the user already explicitly chose
Expand Down
24 changes: 16 additions & 8 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (a *modelSelector) loadAiCatalog(ctx context.Context) error {

modelResp, err := a.azdClient.Ai().ListModels(ctx, &azdext.ListModelsRequest{
AzureContext: a.azureContext,
Filter: agentModelFilter(nil, nil),
Filter: agentModelFilter(a.supportedRegions, nil),
})
stopErr := spinner.Stop(ctx)
if err != nil {
Expand Down Expand Up @@ -383,7 +383,12 @@ func (a *InitAction) getModelDeploymentDetails(
}
}

modelDetails, err := a.getModelSelector().getModelDetails(ctx, model.Id, true)
selector, err := a.getModelSelector(ctx)
if err != nil {
return nil, false, err
}

modelDetails, err := selector.getModelDetails(ctx, model.Id, true)
if err != nil {
if errors.Is(err, errModelSkipped) {
// Propagate the sentinel unwrapped so ProcessModels can detect
Expand Down Expand Up @@ -683,7 +688,7 @@ func (a *modelSelector) promptForAlternativeModel(

regionChoices := []*azdext.SelectChoice{
{Label: fmt.Sprintf("Models available in my current region (%s)", a.azureContext.Scope.Location), Value: "region"},
{Label: "All available models", Value: "all"},
{Label: "All models supported for hosted agents", Value: "all"},
}

regionResp, err := a.azdClient.Prompt().Select(ctx, &azdext.SelectRequest{
Expand All @@ -697,9 +702,12 @@ func (a *modelSelector) promptForAlternativeModel(
return nil, fmt.Errorf("failed to prompt for region choice: %w", err)
}

// Default to the "all" branch: every model in the catalog, but restricted to
// the hosted-agent-supported regions so we don't surface models from regions
// that hosted agents can't run in.
promptReq := &azdext.PromptAiModelRequest{
AzureContext: a.azureContext,
Filter: agentModelFilter(nil, nil),
Filter: agentModelFilter(a.supportedRegions, nil),
SelectOptions: &azdext.SelectOptions{
Message: "Select a model",
},
Expand Down Expand Up @@ -761,7 +769,7 @@ func (a *modelSelector) promptForModelLocationMismatch(

choices := []*azdext.SelectChoice{
{Label: modelChoiceLabel, Value: "model"},
{Label: "Choose a different model (all regions)", Value: "model_all_regions"},
{Label: "Choose a different model (all supported regions)", Value: "model_all_regions"},
{Label: fmt.Sprintf("Choose a different location for %s", currentModel.Name), Value: "location"},
{Label: "Exit setup", Value: "exit"},
}
Expand Down Expand Up @@ -837,17 +845,17 @@ func (a *modelSelector) promptForModelLocationMismatch(
if selectedChoice == "model_all_regions" {
modelResp, err := a.azdClient.Prompt().PromptAiModel(ctx, &azdext.PromptAiModelRequest{
AzureContext: a.azureContext,
Filter: agentModelFilter(nil, []string{currentModel.Name}),
Filter: agentModelFilter(a.supportedRegions, []string{currentModel.Name}),
Quota: &azdext.QuotaCheckOptions{
MinRemainingCapacity: 1,
},
SelectOptions: &azdext.SelectOptions{
Message: "Select a model from all regions",
Message: "Select a model from all supported regions",
},
})
if err != nil {
if hasAiErrorReason(err, azdext.AiErrorReasonNoModelsMatch) {
message = "No alternative models were found across all regions."
message = "No alternative models were found across all supported regions."
continue
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
package cmd

import (
"azureaiagent/internal/project"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"azureaiagent/internal/project"

"github.com/azure/azure-dev/cli/azd/pkg/azdext"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -317,3 +320,83 @@ func TestUpdateEnvLocation(t *testing.T) {
})
}
}

// TestNewModelSelector_PopulatesSupportedRegions seeds the regions cache and
// verifies that newModelSelector copies those regions onto the selector, and
// that writing through the selector's slice does not change what
// supportedRegionsForInit returns afterwards.
func TestNewModelSelector_PopulatesSupportedRegions(t *testing.T) {
	resetRegionsCache(t, []string{"eastus2", "westus3"})

	azureCtx := &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}}
	env := &azdext.Environment{Name: "test-env"}

	// The azd client is not used during construction, so nil is passed for it.
	selector, err := newModelSelector(t.Context(), nil, azureCtx, env, &initFlags{})
	require.NoError(t, err)
	require.NotNil(t, selector)
	assert.Equal(t, []string{"eastus2", "westus3"}, selector.supportedRegions)

	// Overwrite an element of the selector's slice, then re-read the regions:
	// the global cache must still hold the original values.
	selector.supportedRegions[0] = "mutated"
	cached, err := supportedRegionsForInit(t.Context())
	require.NoError(t, err)
	assert.Equal(t, []string{"eastus2", "westus3"}, cached)
}

// TestGetModelSelector_MemoizesAcrossCalls verifies that InitAction.getModelSelector
// builds the selector once and hands back that same pointer on later calls,
// so state written to it between calls is still visible.
func TestGetModelSelector_MemoizesAcrossCalls(t *testing.T) {
	resetRegionsCache(t, []string{"eastus2"})

	initAction := &InitAction{
		azureContext: &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}},
		environment:  &azdext.Environment{Name: "test-env"},
		flags:        &initFlags{},
	}

	initial, err := initAction.getModelSelector(t.Context())
	require.NoError(t, err)
	require.NotNil(t, initial)
	assert.Equal(t, []string{"eastus2"}, initial.supportedRegions)

	// Leave fingerprints on the selector. Memoization is what lets
	// modelCatalog and locationWarningShown survive across the per-model loop
	// in ProcessModels, so both must still be set on the next lookup.
	initial.modelCatalog = map[string]*azdext.AiModel{"gpt-4.1-mini": {Name: "gpt-4.1-mini"}}
	initial.locationWarningShown = true

	repeat, err := initAction.getModelSelector(t.Context())
	require.NoError(t, err)
	assert.Same(t, initial, repeat, "getModelSelector must return the cached instance")
	assert.True(t, repeat.locationWarningShown)
	assert.Contains(t, repeat.modelCatalog, "gpt-4.1-mini")
}

// TestNewModelSelector_PropagatesContextCancellation verifies that constructing
// a selector with an already-canceled context fails with context.Canceled and
// yields no selector.
func TestNewModelSelector_PropagatesContextCancellation(t *testing.T) {
	// With nothing cached a fetch is required; the canceled ctx should make
	// the select in supportedRegionsForInit return ctx.Err without waiting on
	// the background fetch.
	resetRegionsCache(t, nil)

	// Serve the regions URL from a handler that blocks until its request is
	// torn down, so the canceled-ctx branch wins the select even on slow CI.
	// The fetch goroutine uses context.WithoutCancel and ends via the fetch's
	// own timeout; this test does not wait for it.
	hang := func(_ http.ResponseWriter, req *http.Request) {
		<-req.Context().Done()
	}
	stub := httptest.NewServer(http.HandlerFunc(hang))
	t.Cleanup(stub.Close)

	// Redirect the fetch to the stub and restore the real URL afterwards.
	originalURL := hostedAgentRegionsURL
	hostedAgentRegionsURL = stub.URL
	t.Cleanup(func() { hostedAgentRegionsURL = originalURL })

	canceledCtx, cancel := context.WithCancel(t.Context())
	cancel()

	azureCtx := &azdext.AzureContext{Scope: &azdext.AzureScope{Location: "eastus2"}}
	env := &azdext.Environment{Name: "test-env"}

	selector, err := newModelSelector(canceledCtx, nil, azureCtx, env, &initFlags{})
	require.Error(t, err)
	require.ErrorIs(t, err, context.Canceled)
	assert.Nil(t, selector)
}
Loading