diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c51b3c075e..9e72f77675 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,10 +16,10 @@ on: type: boolean default: false -# 环境变量:合并 workflow_dispatch 输入和 repository variable -# tag push 触发时读取 vars.SIMPLE_RELEASE,workflow_dispatch 时使用输入参数 +# 环境变量:仅允许 workflow_dispatch 手动启用 simple release。 +# 正式 tag 发布必须始终生成完整的二进制资产,否则安装脚本无法工作。 env: - SIMPLE_RELEASE: ${{ github.event.inputs.simple_release == 'true' || vars.SIMPLE_RELEASE == 'true' }} + SIMPLE_RELEASE: ${{ github.event_name == 'workflow_dispatch' && (github.event.inputs.simple_release == 'true' || vars.SIMPLE_RELEASE == 'true') }} permissions: contents: write diff --git a/.gitignore b/.gitignore index 297c1d6f03..254afd993f 100644 --- a/.gitignore +++ b/.gitignore @@ -132,4 +132,5 @@ docs/* frontend/coverage/ aicodex output/ +release.sh diff --git a/.goreleaser.simple.yaml b/.goreleaser.simple.yaml index 14f67fd1cb..2fa4413b93 100644 --- a/.goreleaser.simple.yaml +++ b/.goreleaser.simple.yaml @@ -1,4 +1,5 @@ -# 简化版 GoReleaser 配置 - 仅发布 x86_64 GHCR 镜像 +# 简化版 GoReleaser 配置 - 仅用于手动触发的镜像发布。 +# 不生成安装脚本依赖的二进制资产,禁止用于正式 tag 发布。 version: 2 project_name: sub2api diff --git a/Dockerfile b/Dockerfile index a16eb958f2..3f8bd1d4a9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26.1-alpine +ARG GOLANG_IMAGE=golang:1.26.2-alpine ARG ALPINE_IMAGE=alpine:3.21 ARG POSTGRES_IMAGE=postgres:18-alpine ARG GOPROXY=https://goproxy.cn,direct @@ -84,9 +84,9 @@ FROM ${POSTGRES_IMAGE} AS pg-client FROM ${ALPINE_IMAGE} # Labels -LABEL maintainer="Wei-Shaw " +LABEL maintainer="kw0ngr " LABEL description="Sub2API - AI API Gateway Platform" -LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" +LABEL org.opencontainers.image.source="https://github.com/kw0ngr/sub2api" # Install runtime dependencies RUN apk add --no-cache \ diff --git a/Dockerfile.goreleaser b/Dockerfile.goreleaser index f251d154c6..6f6d8d110a 100644 --- a/Dockerfile.goreleaser +++ b/Dockerfile.goreleaser @@ -12,9 +12,9 @@ FROM ${POSTGRES_IMAGE} AS pg-client FROM ${ALPINE_IMAGE} -LABEL maintainer="Wei-Shaw " +LABEL maintainer="kw0ngr " LABEL description="Sub2API - AI API Gateway Platform" -LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" +LABEL org.opencontainers.image.source="https://github.com/kw0ngr/sub2api" # Install runtime dependencies RUN apk add --no-cache \ diff --git a/README.md b/README.md index 99753e4569..e5d31ffc87 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ One-click installation script that downloads pre-built binaries from GitHub Rele #### Installation Steps ```bash -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash ``` The script will: @@ -156,7 +156,7 @@ sudo journalctl -u sub2api -f sudo systemctl restart sub2api # Uninstall -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y ``` --- @@ -179,7 +179,7 @@ Use the automated deployment script for easy setup: mkdir -p sub2api-deploy && cd sub2api-deploy # Download and run deployment preparation script -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/docker-deploy.sh | bash # Start services docker compose up -d @@ -201,7 +201,7 @@ If you prefer manual setup: ```bash # 1. Clone the repository -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api/deploy # 2. Copy environment configuration @@ -340,7 +340,7 @@ Build and run from source code for development or customization. ```bash # 1. Clone the repository -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api # 2. Install pnpm (if not already installed) @@ -566,11 +566,11 @@ sub2api/ ## Star History - + - - - Star History Chart + + + Star History Chart diff --git a/README_CN.md b/README_CN.md index 8b6feaba0d..2fc1ec4b07 100644 --- a/README_CN.md +++ b/README_CN.md @@ -105,7 +105,7 @@ Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`), #### 安装步骤 ```bash -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash ``` 脚本会自动: @@ -155,7 +155,7 @@ sudo journalctl -u sub2api -f sudo systemctl restart sub2api # 卸载 -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y ``` --- @@ -178,7 +178,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install mkdir -p sub2api-deploy && cd sub2api-deploy # 下载并运行部署准备脚本 -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/docker-deploy.sh | bash # 启动服务 docker compose up -d @@ -200,7 +200,7 @@ docker compose logs -f sub2api ```bash # 1. 克隆仓库 -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api/deploy # 2. 复制环境配置文件 @@ -351,7 +351,7 @@ rm -rf data/ postgres_data/ redis_data/ ```bash # 1. 克隆仓库 -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api # 2. 安装 pnpm(如果还没有安装) @@ -627,11 +627,11 @@ sub2api/ ## Star History - + - - - Star History Chart + + + Star History Chart diff --git a/README_JA.md b/README_JA.md index 1266bd845c..6bdb4fe52d 100644 --- a/README_JA.md +++ b/README_JA.md @@ -106,7 +106,7 @@ GitHub Releases からビルド済みバイナリをダウンロードするワ #### インストール手順 ```bash -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash ``` スクリプトは以下を実行します: @@ -156,7 +156,7 @@ sudo journalctl -u sub2api -f sudo systemctl restart sub2api # アンインストール -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y ``` --- @@ -179,7 +179,7 @@ PostgreSQL と Redis のコンテナを含む Docker Compose でデプロイし mkdir -p sub2api-deploy && cd sub2api-deploy # デプロイ準備スクリプトをダウンロードして実行 -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/docker-deploy.sh | bash # サービスを起動 docker compose up -d @@ -201,7 +201,7 @@ docker compose logs -f sub2api ```bash # 1. リポジトリをクローン -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api/deploy # 2. 環境設定ファイルをコピー @@ -340,7 +340,7 @@ rm -rf data/ postgres_data/ redis_data/ ```bash # 1. リポジトリをクローン -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api # 2. pnpm をインストール(未インストールの場合) @@ -566,11 +566,11 @@ sub2api/ ## スター履歴 - + - - - Star History Chart + + + Star History Chart diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 9e3db2aa12..793d0e319f 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.106 +0.1.116 diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go index 63a81230c2..3908e7718b 100644 --- a/backend/ent/schema/error_passthrough_rule.go +++ b/backend/ent/schema/error_passthrough_rule.go @@ -16,7 +16,7 @@ import ( // // 错误透传规则用于控制上游错误如何返回给客户端: // - 匹配条件:错误码 + 关键词组合 -// - 响应行为:透传原始信息 或 自定义错误信息 +// - 响应行为:返回安全默认文案或自定义错误信息(不再透传原始上游文本) // - 响应状态码:可指定返回给客户端的状态码 // - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity) type ErrorPassthroughRule struct { @@ -93,14 +93,13 @@ func (ErrorPassthroughRule) Fields() []ent.Field { Optional(). Nillable(), - // passthrough_body: 是否透传上游原始错误信息 - // true: 使用上游返回的错误信息 - // false: 使用 custom_message 指定的错误信息 + // passthrough_body: Deprecated. + // 该字段为历史兼容保留,运行时强制视为 false(不会透传上游原始错误信息)。 field.Bool("passthrough_body"). - Default(true), + Default(false), // custom_message: 自定义错误信息 - // 当 passthrough_body=false 时使用此错误信息 + // 为空时使用链路默认安全文案 field.Text("custom_message"). Optional(). Nillable(), diff --git a/backend/go.mod b/backend/go.mod index 135cbd3eaf..919edd8039 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,6 +1,6 @@ module github.com/Wei-Shaw/sub2api -go 1.26.1 +go 1.26.2 require ( entgo.io/ent v0.14.5 diff --git a/backend/internal/handler/admin/account_apikey.go b/backend/internal/handler/admin/account_apikey.go new file mode 100644 index 0000000000..68d526afff --- /dev/null +++ b/backend/internal/handler/admin/account_apikey.go @@ -0,0 +1,518 @@ +package admin + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" + "net/http" + "strings" +) + +const rawAPIKeyImportPageSize = 500 + +type RawAPIKeyImportRequest struct { + RawText string `json:"raw_text" binding:"required"` + ValidateAfterImport bool `json:"validate_after_import"` + SkipDefaultGroupBind bool `json:"skip_default_group_bind"` +} + +type RawAPIKeyImportLineResult struct { + Line int `json:"line"` + KeyPreview string `json:"key_preview,omitempty"` + Platform string `json:"platform,omitempty"` + AccountID int64 `json:"account_id,omitempty"` + StatusCode int `json:"status_code,omitempty"` + Created bool `json:"created"` + Checked bool `json:"checked"` + Valid bool `json:"valid"` + InvalidDisabled bool `json:"invalid_disabled"` + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` +} + +type RawAPIKeyImportResult struct { + TotalLines int `json:"total_lines"` + Created int `json:"created"` + Checked int `json:"checked"` + Valid int `json:"valid"` + InvalidDisabled int `json:"invalid_disabled"` + Failed int `json:"failed"` + Results []RawAPIKeyImportLineResult `json:"results"` +} + +type APIKeyHealthCheckRequest struct { + AccountIDs []int64 `json:"account_ids"` +} + +type APIKeyHealthCheckItem struct { + AccountID int64 `json:"account_id"` + Name string `json:"name"` + Platform string `json:"platform"` + StatusCode int `json:"status_code,omitempty"` + Valid bool `json:"valid"` + InvalidDisabled bool `json:"invalid_disabled"` + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` +} + +type APIKeyHealthCheckResult struct { + Total int `json:"total"` + Checked int `json:"checked"` + Valid int `json:"valid"` + InvalidDisabled int `json:"invalid_disabled"` + Failed int `json:"failed"` + Results []APIKeyHealthCheckItem `json:"results"` +} + +type rawAPIKeyImportLine struct { + Line int + Key string + BaseURL string + Platform string +} + +func (h *AccountHandler) disableInvalidAPIKeyAccount(ctx context.Context, account *service.Account, message string) error { + if err := h.adminService.SetAccountError(ctx, account.ID, buildInvalidAPIKeyErrorMessage(account.Platform, message)); err != nil { + return err + } + if account != nil && account.Schedulable { + if _, err := h.adminService.SetAccountSchedulable(ctx, account.ID, false); err != nil { + return err + } + } + return nil +} + +func (h *AccountHandler) recoverValidAPIKeyAccount(ctx context.Context, account *service.Account) error { + if account == nil { + return nil + } + // If status is not active (error or disabled), restore it to active. + if !account.IsActive() { + if _, err := h.adminService.ClearAccountError(ctx, account.ID); err != nil { + return err + } + } + // Re-enable scheduling if it was turned off. + if !account.Schedulable { + if _, err := h.adminService.SetAccountSchedulable(ctx, account.ID, true); err != nil { + return err + } + } + return nil +} + +func (h *AccountHandler) ImportRawAPIKeys(c *gin.Context) { + var req RawAPIKeyImportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if req.ValidateAfterImport && h.accountTestService == nil { + response.Error(c, http.StatusServiceUnavailable, "API key health check service is unavailable") + return + } + + totalLines, lines, parseResults, err := parseRawAPIKeyImportLines(req.RawText) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + if totalLines == 0 { + response.BadRequest(c, "No API key lines found") + return + } + + result := RawAPIKeyImportResult{ + TotalLines: totalLines, + Results: make([]RawAPIKeyImportLineResult, 0, len(parseResults)), + } + + result.Results = append(result.Results, parseResults...) + for _, item := range parseResults { + if item.Error != "" { + result.Failed++ + } + } + + existingByIdentity, err := h.loadExistingAPIKeyIndex(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + for _, line := range lines { + identity := buildAPIKeyIdentity(line.Platform, line.Key, line.BaseURL) + if existing, ok := existingByIdentity[identity]; ok && existing != nil { + result.Failed++ + result.Results = append(result.Results, RawAPIKeyImportLineResult{ + Line: line.Line, + KeyPreview: maskRawAPIKey(line.Key), + Platform: line.Platform, + AccountID: existing.ID, + Error: "duplicate key already exists", + }) + continue + } + + credentials := map[string]any{ + "api_key": line.Key, + } + if line.BaseURL != "" { + credentials["base_url"] = line.BaseURL + } else if defaultBaseURL := service.DefaultAPIKeyBaseURL(line.Platform); defaultBaseURL != "" && line.Platform != service.PlatformAnthropic { + credentials["base_url"] = defaultBaseURL + } + + account, createErr := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ + Name: buildRawAPIKeyAccountName(line.Platform, line.Key), + Platform: line.Platform, + Type: service.AccountTypeAPIKey, + Credentials: credentials, + Concurrency: 3, + Priority: 50, + SkipDefaultGroupBind: req.SkipDefaultGroupBind, + }) + + item := RawAPIKeyImportLineResult{ + Line: line.Line, + KeyPreview: maskRawAPIKey(line.Key), + Platform: line.Platform, + } + + if createErr != nil { + item.Error = createErr.Error() + result.Failed++ + result.Results = append(result.Results, item) + continue + } + + item.Created = true + item.AccountID = account.ID + result.Created++ + existingByIdentity[identity] = account + + if req.ValidateAfterImport { + item.Checked = true + result.Checked++ + health, healthErr := h.accountTestService.CheckAPIKeyValidity(c.Request.Context(), account) + if healthErr != nil { + item.Error = healthErr.Error() + result.Failed++ + } else { + item.StatusCode = health.StatusCode + item.Message = health.Message + item.Valid = health.Valid + if health.Valid { + result.Valid++ + } + if health.Invalid { + item.InvalidDisabled = true + result.InvalidDisabled++ + if err := h.disableInvalidAPIKeyAccount(c.Request.Context(), account, health.Message); err != nil { + item.Error = err.Error() + item.InvalidDisabled = false + result.InvalidDisabled-- + result.Failed++ + } + } + } + } + + result.Results = append(result.Results, item) + } + + response.Success(c, result) +} + +func (h *AccountHandler) CheckAPIKeysHealth(c *gin.Context) { + var req APIKeyHealthCheckRequest + if err := c.ShouldBindJSON(&req); err != nil && err.Error() != "EOF" { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if h.accountTestService == nil { + response.Error(c, http.StatusServiceUnavailable, "API key health check service is unavailable") + return + } + + accounts, err := h.resolveAPIKeyHealthCheckAccounts(c.Request.Context(), req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + result := APIKeyHealthCheckResult{ + Total: len(accounts), + Results: make([]APIKeyHealthCheckItem, 0, len(accounts)), + } + + const maxConcurrency = 8 + g, ctx := errgroup.WithContext(c.Request.Context()) + g.SetLimit(maxConcurrency) + + type checkResult struct { + item APIKeyHealthCheckItem + // counters + valid bool + invalidDisabled bool + failed bool + } + resultCh := make(chan checkResult, len(accounts)) + + for _, account := range accounts { + acc := account + g.Go(func() error { + item := APIKeyHealthCheckItem{ + AccountID: acc.ID, + Name: acc.Name, + Platform: acc.Platform, + } + + health, healthErr := h.accountTestService.CheckAPIKeyValidity(ctx, acc) + + r := checkResult{item: item} + if healthErr != nil { + r.item.Error = healthErr.Error() + r.failed = true + resultCh <- r + return nil + } + + r.item.StatusCode = health.StatusCode + r.item.Message = health.Message + r.item.Valid = health.Valid + + if health.Valid { + r.valid = true + if !acc.IsActive() || !acc.Schedulable { + if err := h.recoverValidAPIKeyAccount(ctx, acc); err != nil { + r.item.Error = err.Error() + r.failed = true + } else if strings.TrimSpace(r.item.Message) == "" { + r.item.Message = "account re-enabled and scheduling restored after successful health check" + } else { + r.item.Message = r.item.Message + " | account re-enabled and scheduling restored after successful health check" + } + } + } + + if health.Invalid { + if err := h.disableInvalidAPIKeyAccount(ctx, acc, health.Message); err != nil { + r.item.Error = err.Error() + r.failed = true + } else { + r.item.InvalidDisabled = true + r.invalidDisabled = true + if acc.Schedulable { + if strings.TrimSpace(r.item.Message) == "" { + r.item.Message = "account marked invalid and scheduling disabled after health check" + } else { + r.item.Message += " | scheduling disabled after failed health check" + } + } + } + } + + resultCh <- r + return nil + }) + } + + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + close(resultCh) + + for r := range resultCh { + result.Checked++ + if r.valid { + result.Valid++ + } + if r.invalidDisabled { + result.InvalidDisabled++ + } + if r.failed { + result.Failed++ + } + result.Results = append(result.Results, r.item) + } + + response.Success(c, result) +} + +func (h *AccountHandler) resolveAPIKeyHealthCheckAccounts(ctx context.Context, accountIDs []int64) ([]*service.Account, error) { + if len(accountIDs) > 0 { + accounts, err := h.adminService.GetAccountsByIDs(ctx, accountIDs) + if err != nil { + return nil, err + } + return filterSupportedAPIKeyAccounts(accounts), nil + } + + var allAccounts []*service.Account + page := 1 + for { + items, total, err := h.adminService.ListAccounts(ctx, page, rawAPIKeyImportPageSize, "", service.AccountTypeAPIKey, "", "", 0, "") + if err != nil { + return nil, err + } + for i := range items { + account := items[i] + accCopy := account + allAccounts = append(allAccounts, &accCopy) + } + if len(allAccounts) >= int(total) || len(items) == 0 { + break + } + page++ + } + + return filterSupportedAPIKeyAccounts(allAccounts), nil +} + +func filterSupportedAPIKeyAccounts(accounts []*service.Account) []*service.Account { + result := make([]*service.Account, 0, len(accounts)) + for _, account := range accounts { + if account == nil || account.Type != service.AccountTypeAPIKey { + continue + } + switch account.Platform { + case service.PlatformAnthropic, service.PlatformOpenAI, service.PlatformGemini: + result = append(result, account) + } + } + return result +} + +func (h *AccountHandler) loadExistingAPIKeyIndex(ctx context.Context) (map[string]*service.Account, error) { + index := make(map[string]*service.Account) + page := 1 + fetched := 0 + for { + items, total, err := h.adminService.ListAccounts(ctx, page, rawAPIKeyImportPageSize, "", service.AccountTypeAPIKey, "", "", 0, "") + if err != nil { + return nil, err + } + fetched += len(items) + for i := range items { + account := items[i] + if account.Type != service.AccountTypeAPIKey { + continue + } + switch account.Platform { + case service.PlatformAnthropic, service.PlatformOpenAI, service.PlatformGemini: + default: + continue + } + accCopy := account + identity := buildAPIKeyIdentity(account.Platform, account.GetCredential("api_key"), account.GetCredential("base_url")) + // Keep the earliest account (lowest ID) for each identity to preserve original. + if existing, ok := index[identity]; !ok || account.ID < existing.ID { + index[identity] = &accCopy + } + } + if fetched >= int(total) || len(items) == 0 { + break + } + page++ + } + return index, nil +} + +func parseRawAPIKeyImportLines(raw string) (int, []rawAPIKeyImportLine, []RawAPIKeyImportLineResult, error) { + normalized := strings.ReplaceAll(raw, "\r\n", "\n") + normalized = strings.ReplaceAll(normalized, ",", ",") + rawLines := strings.Split(normalized, "\n") + + lines := make([]rawAPIKeyImportLine, 0, len(rawLines)) + results := make([]RawAPIKeyImportLineResult, 0, len(rawLines)) + total := 0 + + for idx, rawLine := range rawLines { + lineNo := idx + 1 + line := strings.TrimSpace(rawLine) + if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, "//") { + continue + } + total++ + + parts := strings.SplitN(line, ",", 3) + if len(parts) > 2 { + results = append(results, RawAPIKeyImportLineResult{ + Line: lineNo, + Error: "invalid line format, expected key or key,base_url", + }) + continue + } + + key := strings.TrimSpace(parts[0]) + baseURL := "" + if len(parts) == 2 { + baseURL = strings.TrimSpace(parts[1]) + } + if key == "" { + results = append(results, RawAPIKeyImportLineResult{ + Line: lineNo, + Error: "key cannot be empty", + }) + continue + } + + platform, ok := service.DetectAPIKeyPlatform(key) + if !ok { + results = append(results, RawAPIKeyImportLineResult{ + Line: lineNo, + KeyPreview: maskRawAPIKey(key), + Error: "unsupported key format, could not detect platform", + }) + continue + } + + lines = append(lines, rawAPIKeyImportLine{ + Line: lineNo, + Key: key, + BaseURL: baseURL, + Platform: platform, + }) + } + + return total, lines, results, nil +} + +func buildRawAPIKeyAccountName(platform, key string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(key))) + return fmt.Sprintf("%s-apikey-%s", platform, hex.EncodeToString(sum[:])[:10]) +} + +func buildAPIKeyIdentity(platform, key, baseURL string) string { + normalizedPlatform := strings.TrimSpace(platform) + normalizedKey := strings.TrimSpace(key) + normalizedBaseURL := strings.TrimSuffix(strings.TrimSpace(baseURL), "/") + if normalizedBaseURL == "" { + normalizedBaseURL = service.DefaultAPIKeyBaseURL(normalizedPlatform) + } + return normalizedPlatform + "|" + normalizedKey + "|" + normalizedBaseURL +} + +func maskRawAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 10 { + return key + } + return key[:6] + "..." + key[len(key)-4:] +} + +func buildInvalidAPIKeyErrorMessage(platform, message string) string { + prefix := fmt.Sprintf("API key auto-disabled after health check (%s)", platform) + if strings.TrimSpace(message) == "" { + return prefix + } + return prefix + ": " + strings.TrimSpace(message) +} diff --git a/backend/internal/handler/admin/account_apikey_test.go b/backend/internal/handler/admin/account_apikey_test.go new file mode 100644 index 0000000000..fd88f7eac7 --- /dev/null +++ b/backend/internal/handler/admin/account_apikey_test.go @@ -0,0 +1,103 @@ +package admin + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestParseRawAPIKeyImportLines(t *testing.T) { + total, lines, results, err := parseRawAPIKeyImportLines(` +# comment +sk-proj-123 +sk-ant-456,https://api.anthropic.com +AIzaSy789 +bad-key +`) + require.NoError(t, err) + require.Equal(t, 4, total) + require.Len(t, lines, 3) + require.Len(t, results, 1) + require.Equal(t, service.PlatformOpenAI, lines[0].Platform) + require.Equal(t, service.PlatformAnthropic, lines[1].Platform) + require.Equal(t, service.PlatformGemini, lines[2].Platform) + require.Contains(t, results[0].Error, "could not detect platform") +} + +func TestBuildAPIKeyIdentityUsesDefaultBaseURL(t *testing.T) { + a := buildAPIKeyIdentity(service.PlatformOpenAI, "sk-proj-1", "") + b := buildAPIKeyIdentity(service.PlatformOpenAI, "sk-proj-1", "https://api.openai.com/") + require.Equal(t, a, b) +} + +func TestDisableInvalidAPIKeyAccount_DisablesSchedulingWhenEnabled(t *testing.T) { + adminSvc := newStubAdminService() + handler := &AccountHandler{adminService: adminSvc} + account := &service.Account{ + ID: 42, + Platform: service.PlatformOpenAI, + Schedulable: true, + } + + err := handler.disableInvalidAPIKeyAccount(context.Background(), account, "invalid api key") + require.NoError(t, err) + require.Len(t, adminSvc.setAccountErrCalls, 1) + require.Equal(t, int64(42), adminSvc.setAccountErrCalls[0].id) + require.Len(t, adminSvc.setSchedulableCalls, 1) + require.Equal(t, int64(42), adminSvc.setSchedulableCalls[0].id) + require.False(t, adminSvc.setSchedulableCalls[0].schedulable) +} + +func TestDisableInvalidAPIKeyAccount_SkipsSchedulingUpdateWhenAlreadyDisabled(t *testing.T) { + adminSvc := newStubAdminService() + handler := &AccountHandler{adminService: adminSvc} + account := &service.Account{ + ID: 43, + Platform: service.PlatformAnthropic, + Schedulable: false, + } + + err := handler.disableInvalidAPIKeyAccount(context.Background(), account, "invalid x-api-key") + require.NoError(t, err) + require.Len(t, adminSvc.setAccountErrCalls, 1) + require.Equal(t, int64(43), adminSvc.setAccountErrCalls[0].id) + require.Empty(t, adminSvc.setSchedulableCalls) +} + +func TestRecoverValidAPIKeyAccount_ClearsErrorAndEnablesScheduling(t *testing.T) { + // When health check confirms the key is valid via real chat completions, + // status=error accounts should be fully recovered. + adminSvc := newStubAdminService() + handler := &AccountHandler{adminService: adminSvc} + account := &service.Account{ + ID: 44, + Status: service.StatusError, + Schedulable: false, + } + + err := handler.recoverValidAPIKeyAccount(context.Background(), account) + require.NoError(t, err) + require.Equal(t, []int64{44}, adminSvc.clearedAccountErrIDs) + require.Len(t, adminSvc.setSchedulableCalls, 1) + require.Equal(t, int64(44), adminSvc.setSchedulableCalls[0].id) + require.True(t, adminSvc.setSchedulableCalls[0].schedulable) +} + +func TestRecoverValidAPIKeyAccount_EnablesSchedulingWithoutClearingActiveAccount(t *testing.T) { + adminSvc := newStubAdminService() + handler := &AccountHandler{adminService: adminSvc} + account := &service.Account{ + ID: 45, + Status: service.StatusActive, + Schedulable: false, + } + + err := handler.recoverValidAPIKeyAccount(context.Background(), account) + require.NoError(t, err) + require.Empty(t, adminSvc.clearedAccountErrIDs) + require.Len(t, adminSvc.setSchedulableCalls, 1) + require.Equal(t, int64(45), adminSvc.setSchedulableCalls[0].id) + require.True(t, adminSvc.setSchedulableCalls[0].schedulable) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 9759cef5c0..433defe388 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -26,7 +26,16 @@ type stubAdminService struct { updateAccountErr error bulkUpdateAccountErr error checkMixedErr error - lastMixedCheck struct { + clearedAccountErrIDs []int64 + setAccountErrCalls []struct { + id int64 + msg string + } + setSchedulableCalls []struct { + id int64 + schedulable bool + } + lastMixedCheck struct { accountID int64 platform string groupIDs []int64 @@ -234,15 +243,30 @@ func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int } func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) { + s.mu.Lock() + s.clearedAccountErrIDs = append(s.clearedAccountErrIDs, id) + s.mu.Unlock() account := service.Account{ID: id, Name: "account", Status: service.StatusActive} return &account, nil } func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error { + s.mu.Lock() + s.setAccountErrCalls = append(s.setAccountErrCalls, struct { + id int64 + msg string + }{id: id, msg: errorMsg}) + s.mu.Unlock() return nil } func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) { + s.mu.Lock() + s.setSchedulableCalls = append(s.setSchedulableCalls, struct { + id int64 + schedulable bool + }{id: id, schedulable: schedulable}) + s.mu.Unlock() account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable} return &account, nil } diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go index 25aaa5c72b..38ee31dfd7 100644 --- a/backend/internal/handler/admin/error_passthrough_handler.go +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -30,7 +30,7 @@ type CreateErrorPassthroughRuleRequest struct { Platforms []string `json:"platforms"` PassthroughCode *bool `json:"passthrough_code"` ResponseCode *int `json:"response_code"` - PassthroughBody *bool `json:"passthrough_body"` + PassthroughBody *bool `json:"passthrough_body"` // Deprecated: ignored, always treated as false CustomMessage *string `json:"custom_message"` SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` @@ -47,7 +47,7 @@ type UpdateErrorPassthroughRuleRequest struct { Platforms []string `json:"platforms"` PassthroughCode *bool `json:"passthrough_code"` ResponseCode *int `json:"response_code"` - PassthroughBody *bool `json:"passthrough_body"` + PassthroughBody *bool `json:"passthrough_body"` // Deprecated: ignored, always treated as false CustomMessage *string `json:"custom_message"` SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` @@ -119,11 +119,8 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) { } else { rule.PassthroughCode = true } - if req.PassthroughBody != nil { - rule.PassthroughBody = *req.PassthroughBody - } else { - rule.PassthroughBody = true - } + // Deprecated capability removed: never passthrough upstream raw error body. + rule.PassthroughBody = false if req.SkipMonitoring != nil { rule.SkipMonitoring = *req.SkipMonitoring } @@ -227,9 +224,8 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) { if req.ResponseCode != nil { rule.ResponseCode = req.ResponseCode } - if req.PassthroughBody != nil { - rule.PassthroughBody = *req.PassthroughBody - } + // Deprecated capability removed: never passthrough upstream raw error body. + rule.PassthroughBody = false if req.CustomMessage != nil { rule.CustomMessage = req.CustomMessage } diff --git a/backend/internal/handler/failover_rule_sanitization_test.go b/backend/internal/handler/failover_rule_sanitization_test.go new file mode 100644 index 0000000000..4dd19f746f --- /dev/null +++ b/backend/internal/handler/failover_rule_sanitization_test.go @@ -0,0 +1,142 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type errorPassthroughRepoStub struct { + rules []*model.ErrorPassthroughRule +} + +func (s *errorPassthroughRepoStub) List(context.Context) ([]*model.ErrorPassthroughRule, error) { + return s.rules, nil +} +func (s *errorPassthroughRepoStub) GetByID(_ context.Context, id int64) (*model.ErrorPassthroughRule, error) { + for _, r := range s.rules { + if r.ID == id { + return r, nil + } + } + return nil, nil +} +func (s *errorPassthroughRepoStub) Create(_ context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + return rule, nil +} +func (s *errorPassthroughRepoStub) Update(_ context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + return rule, nil +} +func (s *errorPassthroughRepoStub) Delete(context.Context, int64) error { return nil } + +func newRuleServiceForHandlerTest(rule *model.ErrorPassthroughRule) *service.ErrorPassthroughService { + return service.NewErrorPassthroughService(&errorPassthroughRepoStub{ + rules: []*model.ErrorPassthroughRule{rule}, + }, nil) +} + +func TestOpenAIHandleFailoverExhausted_RuleNeverLeaksUpstreamBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + ruleSvc := newRuleServiceForHandlerTest(&model.ErrorPassthroughRule{ + ID: 1, + Name: "legacy-openai-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{http.StatusBadRequest}, + MatchMode: model.MatchModeAny, + Platforms: []string{"openai"}, + PassthroughCode: true, + PassthroughBody: true, // Deprecated legacy value, should be ignored. + }) + + h := &OpenAIGatewayHandler{errorPassthroughService: ruleSvc} + h.handleFailoverExhausted(c, &service.UpstreamFailoverError{ + StatusCode: http.StatusBadRequest, + ResponseBody: []byte(`{"error":{"message":"SECRET_OPENAI_UPSTREAM"}}`), + }, false) + + require.Equal(t, http.StatusBadRequest, rec.Code) + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errObj := payload["error"].(map[string]any) + require.Equal(t, "upstream_error", errObj["type"]) + require.Equal(t, "Upstream request failed", errObj["message"]) + require.NotContains(t, rec.Body.String(), "SECRET_OPENAI_UPSTREAM") +} + +func TestGatewayHandleFailoverExhausted_RuleNeverLeaksUpstreamBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + ruleSvc := newRuleServiceForHandlerTest(&model.ErrorPassthroughRule{ + ID: 2, + Name: "legacy-anthropic-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{http.StatusForbidden}, + MatchMode: model.MatchModeAny, + Platforms: []string{service.PlatformAnthropic}, + PassthroughCode: true, + PassthroughBody: true, + }) + + h := &GatewayHandler{errorPassthroughService: ruleSvc} + h.handleFailoverExhausted(c, &service.UpstreamFailoverError{ + StatusCode: http.StatusForbidden, + ResponseBody: []byte(`{"error":{"message":"SECRET_ANTHROPIC_UPSTREAM"}}`), + }, service.PlatformAnthropic, false) + + require.Equal(t, http.StatusForbidden, rec.Code) + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errObj := payload["error"].(map[string]any) + require.Equal(t, "upstream_error", errObj["type"]) + require.Equal(t, "Upstream access forbidden, please contact administrator", errObj["message"]) + require.NotContains(t, rec.Body.String(), "SECRET_ANTHROPIC_UPSTREAM") +} + +func TestGeminiHandleFailoverExhausted_RuleNeverLeaksUpstreamBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", nil) + + ruleSvc := newRuleServiceForHandlerTest(&model.ErrorPassthroughRule{ + ID: 3, + Name: "legacy-gemini-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{http.StatusTooManyRequests}, + MatchMode: model.MatchModeAny, + Platforms: []string{service.PlatformGemini}, + PassthroughCode: true, + PassthroughBody: true, + }) + + h := &GatewayHandler{errorPassthroughService: ruleSvc} + h.handleGeminiFailoverExhausted(c, &service.UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseBody: []byte(`{"error":{"message":"SECRET_GEMINI_UPSTREAM"}}`), + }) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errObj := payload["error"].(map[string]any) + require.Equal(t, float64(http.StatusTooManyRequests), errObj["code"]) + require.Equal(t, "Upstream rate limit exceeded, please retry later", errObj["message"]) + require.NotContains(t, rec.Body.String(), "SECRET_GEMINI_UPSTREAM") +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9f5..4a37ff272b 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1231,10 +1231,14 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se respCode = *rule.ResponseCode } - // 确定响应消息 - msg := service.ExtractUpstreamErrorMessage(responseBody) - if !rule.PassthroughBody && rule.CustomMessage != nil { - msg = *rule.CustomMessage + // 确定响应消息(不再透传上游原文) + _, _, mappedMsg := h.mapUpstreamError(statusCode) + msg := mappedMsg + if rule.CustomMessage != nil { + custom := strings.TrimSpace(*rule.CustomMessage) + if custom != "" { + msg = custom + } } if rule.SkipMonitoring { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 524c6b6de4..6bc7d512e4 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -579,10 +579,14 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE respCode = *rule.ResponseCode } - // 确定响应消息 - msg := service.ExtractUpstreamErrorMessage(responseBody) - if !rule.PassthroughBody && rule.CustomMessage != nil { - msg = *rule.CustomMessage + // 确定响应消息(不再透传上游原文) + _, mappedMsg := mapGeminiUpstreamError(statusCode) + msg := mappedMsg + if rule.CustomMessage != nil { + custom := strings.TrimSpace(*rule.CustomMessage) + if custom != "" { + msg = custom + } } if rule.SkipMonitoring { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index ae70cee40e..3ecf3ed886 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1434,10 +1434,14 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE respCode = *rule.ResponseCode } - // 确定响应消息 - msg := service.ExtractUpstreamErrorMessage(responseBody) - if !rule.PassthroughBody && rule.CustomMessage != nil { - msg = *rule.CustomMessage + // 确定响应消息(不再透传上游原文) + _, _, mappedMsg := h.mapUpstreamError(statusCode) + msg := mappedMsg + if rule.CustomMessage != nil { + custom := strings.TrimSpace(*rule.CustomMessage) + if custom != "" { + msg = custom + } } if rule.SkipMonitoring { diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index fe035b6f7f..c00d457609 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2064,6 +2064,9 @@ func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) ( func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { return nil, nil } +func (r *stubAccountRepoForHandler) FindByAPIKey(context.Context, string, string, string) (*service.Account, error) { + return nil, nil +} func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go index 620736cd87..bb96ce951d 100644 --- a/backend/internal/model/error_passthrough_rule.go +++ b/backend/internal/model/error_passthrough_rule.go @@ -1,25 +1,30 @@ // Package model 定义服务层使用的数据模型。 package model -import "time" +import ( + "strings" + "time" +) // ErrorPassthroughRule 全局错误透传规则 // 用于控制上游错误如何返回给客户端 type ErrorPassthroughRule struct { - ID int64 `json:"id"` - Name string `json:"name"` // 规则名称 - Enabled bool `json:"enabled"` // 是否启用 - Priority int `json:"priority"` // 优先级(数字越小优先级越高) - ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) - Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) - MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) - Platforms []string `json:"platforms"` // 适用平台列表 - PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 - ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) - PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 - CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) - SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录 - Description *string `json:"description"` // 规则描述 + ID int64 `json:"id"` + Name string `json:"name"` // 规则名称 + Enabled bool `json:"enabled"` // 是否启用 + Priority int `json:"priority"` // 优先级(数字越小优先级越高) + ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) + Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) + MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) + Platforms []string `json:"platforms"` // 适用平台列表 + PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 + ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) + // Deprecated: passthrough_body 已废弃,规则不再下发上游原始错误文本。 + // 该字段仅为兼容历史存量配置保留,运行时会被强制视为 false。 + PassthroughBody bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` // 自定义错误信息(为空时使用链路默认错误文案) + SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录 + Description *string `json:"description"` // 规则描述 CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -58,8 +63,8 @@ func (r *ErrorPassthroughRule) Validate() error { if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) { return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"} } - if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") { - return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"} + if r.CustomMessage != nil && strings.TrimSpace(*r.CustomMessage) == "" { + return &ValidationError{Field: "custom_message", Message: "custom_message cannot be empty"} } return nil } diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index dfca252f48..34d0be7edc 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -48,13 +48,13 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking var DefaultHeaders = map[string]string{ // Keep these in sync with recent Claude CLI traffic to reduce the chance // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage. - "User-Agent": "claude-cli/2.1.22 (external, cli)", + "User-Agent": "claude-cli/2.1.88 (external, cli)", "X-Stainless-Lang": "js", - "X-Stainless-Package-Version": "0.70.0", - "X-Stainless-OS": "Linux", + "X-Stainless-Package-Version": "0.74.0", + "X-Stainless-Os": "Linux", "X-Stainless-Arch": "arm64", "X-Stainless-Runtime": "node", - "X-Stainless-Runtime-Version": "v24.13.0", + "X-Stainless-Runtime-Version": "v22.13.0", "X-Stainless-Retry-Count": "0", "X-Stainless-Timeout": "600", "X-App": "cli", diff --git a/backend/internal/pkg/googleapi/error.go b/backend/internal/pkg/googleapi/error.go index b6374e021e..f0d79ddbde 100644 --- a/backend/internal/pkg/googleapi/error.go +++ b/backend/internal/pkg/googleapi/error.go @@ -107,3 +107,38 @@ func IsServiceDisabledError(body string) bool { return false } + +// permanentDisableReasons contains ErrorInfo reasons that indicate a permanent, +// non-recoverable account or billing issue on the Google Cloud / Gemini platform. +var permanentDisableReasons = map[string]bool{ + "BILLING_DISABLED": true, + "CONSUMER_SUSPENDED": true, + "PROJECT_DISABLED": true, + "SERVICE_DISABLED": true, + "API_KEY_INVALID": true, +} + +// IsPermanentlyDisabledError checks if the error indicates a permanent account/billing +// issue that requires manual intervention (billing disabled, project suspended, etc.). +// This covers 403 PERMISSION_DENIED with known permanent-disable reasons. +func IsPermanentlyDisabledError(body string) bool { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return false + } + + if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" { + return false + } + + for _, detailRaw := range errResp.Error.Details { + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if permanentDisableReasons[info.Reason] { + return true + } + } + } + + return false +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a1297..05dd554e96 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -285,6 +285,39 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID return &accounts[0], nil } +// FindByAPIKey returns the first non-deleted account matching the given +// platform + credentials.api_key + credentials.base_url combination. +// Returns (nil, nil) when no match is found. +func (r *accountRepository) FindByAPIKey(ctx context.Context, platform, apiKey, baseURL string) (*service.Account, error) { + if platform == "" || apiKey == "" { + return nil, nil + } + m, err := r.client.Account.Query(). + Where( + dbaccount.PlatformEQ(platform), + dbaccount.TypeEQ("apikey"), + dbaccount.DeletedAtIsNil(), + func(s *entsql.Selector) { + s.Where(sqljson.ValueEQ(dbaccount.FieldCredentials, apiKey, sqljson.Path("api_key"))) + }, + ). + First(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + // Verify base_url match in Go to avoid JSONB cast complexity for empty values. + account := accountEntityToService(m) + existingBaseURL := strings.TrimSuffix(strings.TrimSpace(account.GetCredential("base_url")), "/") + wantBaseURL := strings.TrimSuffix(strings.TrimSpace(baseURL), "/") + if existingBaseURL != wantBaseURL { + return nil, nil + } + return account, nil +} + func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { rows, err := r.sql.QueryContext(ctx, ` SELECT id, extra->>'crs_account_id' diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 64d321924d..a5bd86a739 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -63,6 +63,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { _ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露 return nil, nil, err } + if err := applyEmbeddedStartupSQLFixes(migrationCtx, drv.DB()); err != nil { + _ = drv.Close() + return nil, nil, err + } // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go index ae989359fc..7e8c8b49b5 100644 --- a/backend/internal/repository/error_passthrough_repo.go +++ b/backend/internal/repository/error_passthrough_repo.go @@ -54,7 +54,7 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody). + SetPassthroughBody(false). SetSkipMonitoring(rule.SkipMonitoring) if len(rule.ErrorCodes) > 0 { @@ -91,7 +91,7 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody). + SetPassthroughBody(false). SetSkipMonitoring(rule.SkipMonitoring) // 处理可选字段 @@ -150,7 +150,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model MatchMode: e.MatchMode, Platforms: e.Platforms, PassthroughCode: e.PassthroughCode, - PassthroughBody: e.PassthroughBody, + PassthroughBody: false, SkipMonitoring: e.SkipMonitoring, CreatedAt: e.CreatedAt, UpdatedAt: e.UpdatedAt, diff --git a/backend/internal/repository/startup_sql_fixes.go b/backend/internal/repository/startup_sql_fixes.go new file mode 100644 index 0000000000..876aee1b6b --- /dev/null +++ b/backend/internal/repository/startup_sql_fixes.go @@ -0,0 +1,151 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "log" + "strings" + + "github.com/Wei-Shaw/sub2api/migrations" +) + +type startupSQLFix struct { + filename string + needsRun func(context.Context, *sql.DB) (bool, error) +} + +var startupSQLFixes = []startupSQLFix{ + { + filename: "083_disable_error_passthrough_body.sql", + needsRun: needsDisableErrorPassthroughBodyFix, + }, + { + filename: "084_migrate_openai_passthrough_key.sql", + needsRun: needsMigrateOpenAIPassthroughKeyFix, + }, +} + +func applyEmbeddedStartupSQLFixes(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("apply embedded startup sql fixes: nil sql db") + } + + for _, fix := range startupSQLFixes { + need, err := fix.needsRun(ctx, db) + if err != nil { + return fmt.Errorf("check startup sql fix %s: %w", fix.filename, err) + } + if !need { + continue + } + + sqlText, err := readEmbeddedMigrationSQL(fix.filename) + if err != nil { + return fmt.Errorf("load startup sql fix %s: %w", fix.filename, err) + } + if _, err := db.ExecContext(ctx, sqlText); err != nil { + return fmt.Errorf("apply startup sql fix %s: %w", fix.filename, err) + } + log.Printf("[DBInit] Applied startup SQL fix: %s", fix.filename) + } + + return nil +} + +func readEmbeddedMigrationSQL(filename string) (string, error) { + contentBytes, err := fs.ReadFile(migrations.FS, filename) + if err != nil { + return "", err + } + content := strings.TrimSpace(string(contentBytes)) + if content == "" { + return "", fmt.Errorf("empty sql content") + } + return content, nil +} + +func needsDisableErrorPassthroughBodyFix(ctx context.Context, db *sql.DB) (bool, error) { + hasTable, err := tableExists(ctx, db, "error_passthrough_rules") + if err != nil { + return false, fmt.Errorf("check error_passthrough_rules table: %w", err) + } + if !hasTable { + return false, nil + } + + hasColumn, err := columnExists(ctx, db, "error_passthrough_rules", "passthrough_body") + if err != nil { + return false, fmt.Errorf("check passthrough_body column: %w", err) + } + if !hasColumn { + return false, nil + } + + var columnDefault sql.NullString + if err := db.QueryRowContext(ctx, ` + SELECT column_default + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 AND column_name = $2 + `, "error_passthrough_rules", "passthrough_body").Scan(&columnDefault); err != nil { + return false, fmt.Errorf("query passthrough_body default: %w", err) + } + + defaultExpr := strings.ToLower(strings.TrimSpace(columnDefault.String)) + defaultIsFalse := strings.Contains(defaultExpr, "false") + + var hasEnabledRows bool + if err := db.QueryRowContext( + ctx, + "SELECT EXISTS (SELECT 1 FROM error_passthrough_rules WHERE passthrough_body = true)", + ).Scan(&hasEnabledRows); err != nil { + return false, fmt.Errorf("query passthrough_body data: %w", err) + } + + return !defaultIsFalse || hasEnabledRows, nil +} + +func needsMigrateOpenAIPassthroughKeyFix(ctx context.Context, db *sql.DB) (bool, error) { + hasTable, err := tableExists(ctx, db, "accounts") + if err != nil { + return false, fmt.Errorf("check accounts table: %w", err) + } + if !hasTable { + return false, nil + } + + hasExtraColumn, err := columnExists(ctx, db, "accounts", "extra") + if err != nil { + return false, fmt.Errorf("check accounts.extra column: %w", err) + } + if !hasExtraColumn { + return false, nil + } + + var hasLegacyKeys bool + if err := db.QueryRowContext(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM accounts + WHERE platform = 'openai' + AND extra IS NOT NULL + AND (extra ? 'openai_passthrough' OR extra ? 'openai_oauth_passthrough') + ) + `).Scan(&hasLegacyKeys); err != nil { + return false, fmt.Errorf("query openai passthrough legacy keys: %w", err) + } + return hasLegacyKeys, nil +} + +func columnExists(ctx context.Context, db *sql.DB, tableName, columnName string) (bool, error) { + var exists bool + err := db.QueryRowContext(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 AND column_name = $2 + ) + `, tableName, columnName).Scan(&exists) + return exists, err +} diff --git a/backend/internal/repository/startup_sql_fixes_test.go b/backend/internal/repository/startup_sql_fixes_test.go new file mode 100644 index 0000000000..880817372d --- /dev/null +++ b/backend/internal/repository/startup_sql_fixes_test.go @@ -0,0 +1,75 @@ +package repository + +import ( + "context" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestApplyEmbeddedStartupSQLFixes_AppliesWhenNeeded(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("error_passthrough_rules"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("error_passthrough_rules", "passthrough_body"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT column_default\\s+FROM information_schema.columns"). + WithArgs("error_passthrough_rules", "passthrough_body"). + WillReturnRows(sqlmock.NewRows([]string{"column_default"}).AddRow("true")) + mock.ExpectQuery("SELECT EXISTS \\(SELECT 1 FROM error_passthrough_rules WHERE passthrough_body = true\\)"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("ALTER TABLE error_passthrough_rules"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("accounts"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("accounts", "extra"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\(\\s*SELECT 1\\s*FROM accounts"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectExec("UPDATE accounts"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err = applyEmbeddedStartupSQLFixes(context.Background(), db) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyEmbeddedStartupSQLFixes_SkipsWhenAlreadySatisfied(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("error_passthrough_rules"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("error_passthrough_rules", "passthrough_body"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT column_default\\s+FROM information_schema.columns"). + WithArgs("error_passthrough_rules", "passthrough_body"). + WillReturnRows(sqlmock.NewRows([]string{"column_default"}).AddRow("'false'::boolean")) + mock.ExpectQuery("SELECT EXISTS \\(SELECT 1 FROM error_passthrough_rules WHERE passthrough_body = true\\)"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("accounts"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("accounts", "extra"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\(\\s*SELECT 1\\s*FROM accounts"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + err = applyEmbeddedStartupSQLFixes(context.Background(), db) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e04dae8521..e52b586c44 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -276,6 +276,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.GET("/data", h.Admin.Account.ExportData) accounts.POST("/data", h.Admin.Account.ImportData) + accounts.POST("/raw-import", h.Admin.Account.ImportRawAPIKeys) + accounts.POST("/apikey-health-check", h.Admin.Account.CheckAPIKeysHealth) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 512195e334..aec7df056c 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -971,13 +971,18 @@ func (a *Account) IsOveragesEnabled() bool { // IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 // -// 新字段:accounts.extra.openai_passthrough。 -// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 +// 新字段:accounts.extra.forward_passthrough_only。 +// 兼容字段: +// - accounts.extra.openai_passthrough +// - accounts.extra.openai_oauth_passthrough(历史 OAuth 开关) // 字段缺失或类型不正确时,按 false(关闭)处理。 func (a *Account) IsOpenAIPassthroughEnabled() bool { if a == nil || !a.IsOpenAI() || a.Extra == nil { return false } + if enabled, ok := a.Extra["forward_passthrough_only"].(bool); ok { + return enabled + } if enabled, ok := a.Extra["openai_passthrough"].(bool); ok { return enabled } diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index 50c2b7cb86..0353fcad7e 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -11,6 +11,29 @@ func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Extra: map[string]any{ + "forward_passthrough_only": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("新字段优先于兼容字段", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "forward_passthrough_only": true, + "openai_passthrough": false, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("兼容旧字段 openai_passthrough", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, Extra: map[string]any{ "openai_passthrough": true, }, @@ -18,7 +41,7 @@ func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { require.True(t, account.IsOpenAIPassthroughEnabled()) }) - t.Run("兼容旧字段", func(t *testing.T) { + t.Run("兼容旧字段 openai_oauth_passthrough", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 328790a87f..6f01b092df 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -28,6 +28,10 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + // FindByAPIKey returns the first non-deleted account that matches the given + // platform + api_key + base_url combination. + // Returns (nil, nil) when no match is found. + FindByAPIKey(ctx context.Context, platform, apiKey, baseURL string) (*Account, error) // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 81169a029b..60a22f413b 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } +func (s *accountRepoStub) FindByAPIKey(ctx context.Context, platform, apiKey, baseURL string) (*Account, error) { + panic("unexpected FindByAPIKey call") +} + func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { panic("unexpected FindByExtraField call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8218c2db0e..e3c6fd3d45 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -314,16 +314,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account body, _ := io.ReadAll(resp.Body) errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)) - // 403 表示账号被上游封禁,标记为 error 状态 - if resp.StatusCode == http.StatusForbidden { - _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + if account.Type == AccountTypeAPIKey && s.accountRepo != nil { + applyTestConnectionAction(ctx, s.accountRepo, account, resp.StatusCode, resp.Header, body) } return s.sendErrorAndEnd(c, errMsg) } // Process SSE stream - return s.processClaudeStream(c, resp.Body) + return s.processClaudeStream(c, ctx, account, resp.Body) } // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke @@ -481,7 +480,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses" + apiURL = buildOpenAIResponsesURL(normalizedBaseURL) } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -551,16 +550,21 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account account.RateLimitResetAt = resetAt } } - // 401 Unauthorized: 标记账号为永久错误 - if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { + if account.Type == AccountTypeAPIKey && s.accountRepo != nil { + applyTestConnectionAction(ctx, s.accountRepo, account, resp.StatusCode, resp.Header, body) + } else if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { + // OAuth 401: mark account as error and disable scheduling errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + if account.Schedulable { + _ = s.accountRepo.SetSchedulable(ctx, account.ID, false) + } } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } // Process SSE stream - return s.processOpenAIStream(c, resp.Body) + return s.processOpenAIStream(c, ctx, account, resp.Body) } // testGeminiAccountConnection tests a Gemini account's connection @@ -627,6 +631,11 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + + if account.Type == AccountTypeAPIKey && s.accountRepo != nil { + applyTestConnectionAction(ctx, s.accountRepo, account, resp.StatusCode, resp.Header, body) + } + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } @@ -1643,7 +1652,7 @@ func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any { } // processClaudeStream processes the SSE stream from Claude API -func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error { +func (s *AccountTestService) processClaudeStream(c *gin.Context, ctx context.Context, account *Account, body io.Reader) error { reader := bufio.NewReader(body) for { @@ -1686,65 +1695,214 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) return nil case "error": errorMsg := "Unknown error" + errorCode := "" if errData, ok := data["error"].(map[string]any); ok { if msg, ok := errData["message"].(string); ok { errorMsg = msg } + if t, ok := errData["type"].(string); ok { + errorCode = t + } + } + // For API Key accounts, classify the in-stream error and mark account state. + if account != nil && account.Type == AccountTypeAPIKey && s.accountRepo != nil { + syntheticBody := []byte(`{"error":{"message":` + fmt.Sprintf("%q", errorMsg) + `,"type":` + fmt.Sprintf("%q", errorCode) + `}}`) + action := ClassifyAPIKeyStatusAction(account, http.StatusForbidden, syntheticBody) + switch action { + case APIKeyStatusActionPermanentDisable: + msg := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key permanently disabled after test connection (stream error)") + _ = s.accountRepo.SetError(ctx, account.ID, msg) + if account.Schedulable { + _ = s.accountRepo.SetSchedulable(ctx, account.ID, false) + } + case APIKeyStatusActionTemporaryCooldown: + reason := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key temporary cooldown after test connection (stream error)") + until := time.Now().Add(apiKeyProbeCooldown) + _ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason) + } } return s.sendErrorAndEnd(c, errorMsg) } } } -// processOpenAIStream processes the SSE stream from OpenAI Responses API -func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error { +// classifyOpenAIStreamTextAsAccountError checks if accumulated response text from a +// 200 OK SSE stream indicates an account-level error returned as plain text content. +// Some third-party OpenAI-compatible APIs return account errors as text deltas instead +// of structured error events, causing the system to miss account disabling. +// Wraps the text as a synthetic JSON error body and delegates to ClassifyAPIKeyStatusAction +// so there is a single source of truth for all keyword/code matching. +// applyTestConnectionAction classifies a non-200 response from a test connection and writes +// account state accordingly. For APIKey accounts only. +func applyTestConnectionAction(ctx context.Context, repo AccountRepository, account *Account, statusCode int, headers http.Header, body []byte) { + switch ClassifyAPIKeyStatusAction(account, statusCode, body) { + case APIKeyStatusActionPermanentDisable: + msg := buildAPIKeyRuntimeErrorMessage(statusCode, body, "API key permanently disabled after test connection") + _ = repo.SetError(ctx, account.ID, msg) + if account.Schedulable { + _ = repo.SetSchedulable(ctx, account.ID, false) + } + case APIKeyStatusActionTemporaryCooldown: + if statusCode == http.StatusTooManyRequests { + resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(headers) + if resetAt == nil { + t := time.Now().Add(apiKey429Cooldown) + resetAt = &t + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + } else { + reason := buildAPIKeyRuntimeErrorMessage(statusCode, body, "API key temporary cooldown after test connection") + until := time.Now().Add(apiKeyProbeCooldown) + _ = repo.SetTempUnschedulable(ctx, account.ID, until, reason) + } + } +} + +// knownStreamErrorPhrases are verbatim phrases that third-party OpenAI-compatible +// APIs return as plain text deltas (HTTP 200) instead of structured error events. +// Rules for adding entries: +// 1. Must be a complete, verbatim phrase from an actual upstream error response. +// 2. Must NOT be a phrase that could appear in a normal AI-generated reply. +// 3. Prefer the shortest unique prefix that still uniquely identifies the error. +var knownStreamErrorPhrases = []string{ + "your account is not active", + "account is not active", + "your account has been suspended", + "account has been suspended", + "your account has been deactivated", + "account has been deactivated", + "organization has been disabled", + "workspace has been deactivated", + "workspace has been disabled", + "api key has been disabled", + "api key is disabled", + "key has been revoked", +} + +// isStreamOnlyErrorText returns true only when the accumulated stream text is a +// verbatim upstream error message, NOT normal AI content. +// We require ALL of: +// 1. No response.completed event was seen (a completed stream is a successful response) +// 2. Only a single delta was received (error messages come as one chunk, not multi-turn) +// 3. Text exactly matches a known error phrase (prefix match, case-insensitive) +func isStreamOnlyErrorText(text string, deltaCount int, completedSeen bool) bool { + if completedSeen || deltaCount != 1 { + return false + } + lower := strings.ToLower(strings.TrimSpace(text)) + for _, phrase := range knownStreamErrorPhrases { + if strings.HasPrefix(lower, phrase) { + return true + } + } + return false +} + +// processOpenAIStream processes the SSE stream from OpenAI Responses API. +// account may be nil for OAuth accounts where state marking is not needed. +func (s *AccountTestService) processOpenAIStream(c *gin.Context, ctx context.Context, account *Account, body io.Reader) error { reader := bufio.NewReader(body) + var ( + accumulatedText strings.Builder + deltaCount int + completedSeen bool + ) + + applyStreamErrorText := func() (bool, error) { + txt := accumulatedText.String() + if !isStreamOnlyErrorText(txt, deltaCount, completedSeen) { + return false, nil + } + if account != nil && account.Type == AccountTypeAPIKey && s.accountRepo != nil { + syntheticBody := []byte(`{"error":{"message":` + fmt.Sprintf("%q", txt) + `}}`) + action := ClassifyAPIKeyStatusAction(account, http.StatusForbidden, syntheticBody) + switch action { + case APIKeyStatusActionPermanentDisable: + msg := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key permanently disabled after test connection (stream error)") + _ = s.accountRepo.SetError(ctx, account.ID, msg) + if account.Schedulable { + _ = s.accountRepo.SetSchedulable(ctx, account.ID, false) + } + case APIKeyStatusActionTemporaryCooldown: + reason := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key temporary cooldown after test connection (stream error)") + until := time.Now().Add(apiKeyProbeCooldown) + _ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason) + } + } + return true, s.sendErrorAndEnd(c, txt) + } + for { line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } + isEOF := err == io.EOF + if err != nil && !isEOF { return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) } line = strings.TrimSpace(line) - if line == "" || !sseDataPrefix.MatchString(line) { - continue - } - - jsonStr := sseDataPrefix.ReplaceAllString(line, "") - if jsonStr == "[DONE]" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - var data map[string]any - if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { - continue + if line != "" && sseDataPrefix.MatchString(line) { + jsonStr := sseDataPrefix.ReplaceAllString(line, "") + + if jsonStr != "" && jsonStr != "[DONE]" { + var data map[string]any + if jsonErr := json.Unmarshal([]byte(jsonStr), &data); jsonErr == nil { + eventType, _ := data["type"].(string) + switch eventType { + case "response.output_text.delta": + if delta, ok := data["delta"].(string); ok && delta != "" { + accumulatedText.WriteString(delta) + deltaCount++ + s.sendEvent(c, TestEvent{Type: "content", Text: delta}) + } + if !isEOF { + continue + } + case "response.completed": + completedSeen = true + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + case "error": + errorMsg := "Unknown error" + errorCode := "" + if errData, ok := data["error"].(map[string]any); ok { + if msg, ok := errData["message"].(string); ok { + errorMsg = msg + } + if code, ok := errData["code"].(string); ok { + errorCode = code + } + } + // Structured error event: classify and mark account state. + if account != nil && account.Type == AccountTypeAPIKey && s.accountRepo != nil { + syntheticBody := []byte(`{"error":{"message":` + fmt.Sprintf("%q", errorMsg) + `,"code":` + fmt.Sprintf("%q", errorCode) + `}}`) + action := ClassifyAPIKeyStatusAction(account, http.StatusForbidden, syntheticBody) + switch action { + case APIKeyStatusActionPermanentDisable: + msg := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key permanently disabled after test connection (stream error)") + _ = s.accountRepo.SetError(ctx, account.ID, msg) + if account.Schedulable { + _ = s.accountRepo.SetSchedulable(ctx, account.ID, false) + } + case APIKeyStatusActionTemporaryCooldown: + reason := buildAPIKeyRuntimeErrorMessage(http.StatusForbidden, syntheticBody, "API key temporary cooldown after test connection (stream error)") + until := time.Now().Add(apiKeyProbeCooldown) + _ = s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason) + } + } + return s.sendErrorAndEnd(c, errorMsg) + } + } + } } - eventType, _ := data["type"].(string) - - switch eventType { - case "response.output_text.delta": - // OpenAI Responses API uses "delta" field for text content - if delta, ok := data["delta"].(string); ok && delta != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: delta}) + if isEOF { + if triggered, errResult := applyStreamErrorText(); triggered { + return errResult } - case "response.completed": s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil - case "error": - errorMsg := "Unknown error" - if errData, ok := data["error"].(map[string]any); ok { - if msg, ok := errData["message"].(string); ok { - errorMsg = msg - } - } - return s.sendErrorAndEnd(c, errorMsg) } } } diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index efa6f7da78..a08655f135 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -100,3 +101,184 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) } } + +func TestAccountTestService_OpenAIApiKeyUsesV1ResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newSoraTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} + +`)) + + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{}, + } + account := &Account{ + ID: 90, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.openai.com"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.NoError(t, err) + require.Len(t, upstream.requests, 1) + require.Equal(t, "https://api.openai.com/v1/responses", upstream.requests[0].URL.String()) + require.Contains(t, recorder.Body.String(), "test_complete") +} + +// openAIStreamTextErrorRepo tracks SetError and SetSchedulable calls. +type openAIStreamTextErrorRepo struct { + mockAccountRepoForGemini + setErrorCalls int + lastErrorMsg string + setSchedulableCalls int + lastSchedulable bool +} + +func (r *openAIStreamTextErrorRepo) SetError(_ context.Context, _ int64, errorMsg string) error { + r.setErrorCalls++ + r.lastErrorMsg = errorMsg + return nil +} + +func (r *openAIStreamTextErrorRepo) SetSchedulable(_ context.Context, _ int64, schedulable bool) error { + r.setSchedulableCalls++ + r.lastSchedulable = schedulable + return nil +} + +func TestAccountTestService_OpenAIApiKey_StreamBehavior(t *testing.T) { + gin.SetMode(gin.TestMode) + + cases := []struct { + name string + sseBody string + wantDisable bool // expect SetError+SetSchedulable(false) + wantSuccess bool // expect test_complete in output + }{ + { + // Single delta matching known error phrase, no response.completed → disable + name: "single_delta_error_phrase_no_completed", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Your account is not active, please check your billing details on our website.\"}\n\n" + + "data: [DONE]\n\n", + wantDisable: true, + wantSuccess: false, + }, + { + // Same error phrase but stream ends with response.completed → normal reply, no disable + name: "error_phrase_with_completed_event_not_flagged", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Your account is not active, please check your billing details on our website.\"}\n\n" + + "data: {\"type\":\"response.completed\"}\n\n", + wantDisable: false, + wantSuccess: true, + }, + { + // Multiple deltas → normal multi-token reply, no disable even if one delta matches + name: "multiple_deltas_not_flagged", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Your account is not active\"}\n\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\", please check your billing details.\"}\n\n" + + "data: [DONE]\n\n", + wantDisable: false, + wantSuccess: true, + }, + { + // Normal AI reply: no disable + name: "normal_response_not_flagged", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello! How can I help you?\"}\n\n" + + "data: {\"type\":\"response.completed\"}\n\n", + wantDisable: false, + wantSuccess: true, + }, + { + // The original bug: "Hi — what can I help with?" must never trigger disable + name: "hi_response_not_flagged", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hi \\u2014 what can I help with?\"}\n\n" + + "data: {\"type\":\"response.completed\"}\n\n", + wantDisable: false, + wantSuccess: true, + }, + { + // Single delta, no completed, but text does not match any known error phrase → success + name: "single_delta_unknown_text_not_flagged", + sseBody: "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Sure, happy to help!\"}\n\n" + + "data: [DONE]\n\n", + wantDisable: false, + wantSuccess: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, recorder := newSoraTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(tc.sseBody)) + + repo := &openAIStreamTextErrorRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{ + httpUpstream: upstream, + accountRepo: repo, + cfg: &config.Config{}, + } + account := &Account{ + ID: 91, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.openai.com"}, + } + + _ = svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + + if tc.wantDisable { + require.Equal(t, 1, repo.setErrorCalls, "SetError should be called") + require.Equal(t, 1, repo.setSchedulableCalls, "SetSchedulable should be called") + require.False(t, repo.lastSchedulable) + require.NotContains(t, recorder.Body.String(), "test_complete") + } else { + require.Equal(t, 0, repo.setErrorCalls, "SetError should NOT be called") + require.Equal(t, 0, repo.setSchedulableCalls, "SetSchedulable should NOT be called") + if tc.wantSuccess { + require.Contains(t, recorder.Body.String(), "test_complete") + } + } + }) + } +} + +func TestIsStreamOnlyErrorText(t *testing.T) { + cases := []struct { + text string + deltaCount int + completedSeen bool + want bool + }{ + // Known error phrase, single delta, no completed → detect + {"Your account is not active, please check your billing details on our website.", 1, false, true}, + {"account is not active", 1, false, true}, + // response.completed seen → never detect + {"Your account is not active, please check your billing details on our website.", 1, true, false}, + // Multiple deltas → normal reply + {"Your account is not active", 2, false, false}, + // Text doesn't match any phrase + {"Hello! How can I help you?", 1, false, false}, + {"Hi — what can I help with?", 1, false, false}, + {"Sure, happy to help!", 1, false, false}, + {"", 1, false, false}, + } + + for _, tc := range cases { + got := isStreamOnlyErrorText(tc.text, tc.deltaCount, tc.completedSeen) + if got != tc.want { + t.Errorf("text=%q deltaCount=%d completedSeen=%v: want %v, got %v", + tc.text, tc.deltaCount, tc.completedSeen, tc.want, got) + } + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index b6d7d634df..332902dbee 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1578,6 +1578,21 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Deduplicate API Key accounts: reject creation if the same key already exists. + if input.Type == AccountTypeAPIKey { + apiKey, _ := input.Credentials["api_key"].(string) + baseURL, _ := input.Credentials["base_url"].(string) + if strings.TrimSpace(apiKey) != "" { + existing, err := s.accountRepo.FindByAPIKey(ctx, input.Platform, strings.TrimSpace(apiKey), strings.TrimSpace(baseURL)) + if err != nil { + return nil, fmt.Errorf("duplicate check failed: %w", err) + } + if existing != nil { + return nil, fmt.Errorf("duplicate api key: account %d (%s) already uses this key", existing.ID, existing.Name) + } + } + } + account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -2221,6 +2236,10 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR Message: err.Error(), UpdatedAt: time.Now(), }) + if proxy.Status == StatusActive { + proxy.Status = "inactive" + _ = s.proxyRepo.Update(ctx, proxy) + } return &ProxyTestResult{ Success: false, Message: err.Error(), @@ -2239,6 +2258,10 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR City: exitInfo.City, UpdatedAt: time.Now(), }) + if proxy.Status != StatusActive { + proxy.Status = StatusActive + _ = s.proxyRepo.Update(ctx, proxy) + } return &ProxyTestResult{ Success: true, Message: "Proxy is accessible", @@ -2265,6 +2288,21 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), } + finalize := func(exitInfo *ProxyExitInfo) (*ProxyQualityCheckResult, error) { + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + if result.FailedCount > 0 { + if proxy.Status == StatusActive { + proxy.Status = "inactive" + _ = s.proxyRepo.Update(ctx, proxy) + } + } else if proxy.Status != StatusActive { + proxy.Status = StatusActive + _ = s.proxyRepo.Update(ctx, proxy) + } + return result, nil + } + proxyURL := proxy.URL() if s.proxyProber == nil { result.Items = append(result.Items, ProxyQualityCheckItem{ @@ -2273,9 +2311,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr Message: "代理探测服务未配置", }) result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, nil) - return result, nil + return finalize(nil) } exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) @@ -2287,9 +2323,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr Message: err.Error(), }) result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, nil) - return result, nil + return finalize(nil) } result.ExitIP = exitInfo.IP @@ -2316,9 +2350,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr Message: fmt.Sprintf("创建检测客户端失败: %v", err), }) result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) - return result, nil + return finalize(exitInfo) } for _, target := range proxyQualityTargets { @@ -2336,9 +2368,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr } } - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) - return result, nil + return finalize(exitInfo) } func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { diff --git a/backend/internal/service/apikey_health.go b/backend/internal/service/apikey_health.go new file mode 100644 index 0000000000..4187eadeb6 --- /dev/null +++ b/backend/internal/service/apikey_health.go @@ -0,0 +1,432 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" +) + +const apiKeyProbeCooldown = 60 * time.Minute + +type APIKeyHealthCheckResult struct { + Platform string `json:"platform"` + StatusCode int `json:"status_code"` + Valid bool `json:"valid"` + Invalid bool `json:"invalid"` + Message string `json:"message,omitempty"` +} + +type APIKeyStatusAction int + +const ( + APIKeyStatusActionIgnore APIKeyStatusAction = iota + APIKeyStatusActionValid + APIKeyStatusActionPermanentDisable + APIKeyStatusActionTemporaryCooldown +) + +func DetectAPIKeyPlatform(rawKey string) (string, bool) { + key := strings.TrimSpace(rawKey) + switch { + case strings.HasPrefix(key, "sk-ant-"): + return PlatformAnthropic, true + case strings.HasPrefix(key, "AIza"): + return PlatformGemini, true + case strings.HasPrefix(strings.ToLower(key), "sk-"): + return PlatformOpenAI, true + default: + return "", false + } +} + +func DefaultAPIKeyBaseURL(platform string) string { + switch strings.TrimSpace(platform) { + case PlatformAnthropic: + return "https://api.anthropic.com" + case PlatformOpenAI: + return "https://api.openai.com" + case PlatformGemini: + return "https://generativelanguage.googleapis.com" + default: + return "" + } +} + +func ShouldDisableAPIKeyStatus(account *Account, statusCode int, responseBody []byte) bool { + return ClassifyAPIKeyStatusAction(account, statusCode, responseBody) == APIKeyStatusActionPermanentDisable +} + +func ClassifyAPIKeyStatusAction(account *Account, statusCode int, responseBody []byte) APIKeyStatusAction { + if account == nil || account.Type != AccountTypeAPIKey { + return APIKeyStatusActionIgnore + } + if statusCode == http.StatusOK { + return APIKeyStatusActionValid + } + + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody))) + code := strings.ToLower(strings.TrimSpace(extractUpstreamErrorCode(responseBody))) + bodyUpper := strings.ToUpper(string(responseBody)) + + // 5xx and 529 are always temporary cooldowns regardless of platform + switch statusCode { + case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return APIKeyStatusActionTemporaryCooldown + } + + switch account.Platform { + case PlatformOpenAI: + switch statusCode { + case http.StatusUnauthorized, http.StatusPaymentRequired: + return APIKeyStatusActionPermanentDisable + case http.StatusTooManyRequests: + // insufficient_quota is permanent billing exhaustion, not a temporary rate limit + if code == "insufficient_quota" || containsAny(msg, "exceeded your current quota", "insufficient_quota") { + return APIKeyStatusActionPermanentDisable + } + return APIKeyStatusActionTemporaryCooldown + case http.StatusBadRequest: + // Prefer structured error code: high-precision, no false positives + if containsAny(code, + "account_deactivated", + "deactivated_workspace", + "billing_not_active", + "account_inactive", + "billing_hard_limit_reached", + "invalid_api_key", + ) { + return APIKeyStatusActionPermanentDisable + } + // Message text fallback: use precise phrases that cannot appear in normal API errors + if containsAny(msg, + "organization has been disabled", + "project has been disabled", + "workspace has been deactivated", + "workspace has been disabled", + "account has been deactivated", + "account has been suspended", + "account has been blocked", + "key is disabled", + "api key disabled", + "account is not active", + "billing_hard_limit_reached", + "billing hard limit reached", + ) { + return APIKeyStatusActionPermanentDisable + } + // Unrecognized 400: could be a parameter issue or an unknown account error. + // Treat as temporary cooldown to avoid hammering a potentially disabled key. + return APIKeyStatusActionTemporaryCooldown + case http.StatusForbidden: + // Prefer structured error code + if containsAny(code, + "invalid_api_key", + "token_invalidated", + "token_revoked", + "account_deactivated", + "deactivated_workspace", + "billing_not_active", + "account_inactive", + ) { + return APIKeyStatusActionPermanentDisable + } + // Message text fallback: precise phrases only + if containsAny(msg, + "invalid api key", + "incorrect api key", + "no api key provided", + "token invalidated", + "token revoked", + "account has been deactivated", + "workspace has been deactivated", + "organization has been disabled", + "project has been disabled", + "key is disabled", + "api key disabled", + "account is not active", + "account has been suspended", + "account has been blocked", + ) { + return APIKeyStatusActionPermanentDisable + } + // Unrecognized 403: treat as temporary cooldown. + return APIKeyStatusActionTemporaryCooldown + } + case PlatformAnthropic: + switch statusCode { + case http.StatusUnauthorized: + // 401 is always a permanent key/auth failure for Anthropic + return APIKeyStatusActionPermanentDisable + case http.StatusForbidden: + // Anthropic 403: check for known account-level error types first. + // Some 403s are model-level permission issues (e.g. no access to claude-opus), + // not key invalidation. Use structured type field when available. + errType := strings.ToLower(strings.TrimSpace(extractUpstreamErrorType(responseBody))) + if containsAny(errType, + "authentication_error", + "permission_error", + ) { + return APIKeyStatusActionPermanentDisable + } + // Fallback: precise message phrases that only appear for account-level issues + if containsAny(msg, + "invalid api key", + "api key is invalid", + "account has been disabled", + "organization has been disabled", + "account has been deactivated", + ) { + return APIKeyStatusActionPermanentDisable + } + // Unknown 403: treat as temporary, not permanent — model access restriction + return APIKeyStatusActionTemporaryCooldown + case http.StatusPaymentRequired: + // 402 is temporary billing issue (payment needed), not permanent key invalidation + return APIKeyStatusActionTemporaryCooldown + case http.StatusTooManyRequests: + return APIKeyStatusActionTemporaryCooldown + case http.StatusBadRequest: + // Anthropic returns 400 for credit balance exhaustion (not 402/429) + if containsAny(msg, + "credit balance is too low", + "your credit balance is", + "insufficient credits", + "account has been disabled", + "organization has been disabled", + "account has been deactivated", + ) { + return APIKeyStatusActionPermanentDisable + } + return APIKeyStatusActionTemporaryCooldown + } + case PlatformGemini: + switch statusCode { + case http.StatusTooManyRequests: + return APIKeyStatusActionTemporaryCooldown + case http.StatusUnauthorized: + return APIKeyStatusActionPermanentDisable + case http.StatusForbidden: + // Use structured reason check first (covers BILLING_DISABLED, CONSUMER_SUSPENDED, PROJECT_DISABLED, SERVICE_DISABLED) + if googleapi.IsPermanentlyDisabledError(string(responseBody)) { + return APIKeyStatusActionPermanentDisable + } + // Match known permanent-disable message patterns. + // Avoid catch-all: some 403s indicate model-level permission issues (not account problems). + if containsAny(msg, + "billing is disabled", + "billing disabled", + "consumer suspended", + "project disabled", + "project has been suspended", + ) { + return APIKeyStatusActionPermanentDisable + } + // Unknown 403: treat as temporary cooldown rather than ignore. + // A model-level permission 403 is transient for this key/model combo; + // a temporary cooldown avoids hammering a key that may be account-level suspended. + return APIKeyStatusActionTemporaryCooldown + case http.StatusBadRequest: + if strings.Contains(bodyUpper, "API_KEY_INVALID") || googleapi.IsServiceDisabledError(string(responseBody)) { + return APIKeyStatusActionPermanentDisable + } + // FAILED_PRECONDITION with billing/free-tier messages: permanent disable. + // Bare FAILED_PRECONDITION without billing context may be a request issue, not key failure. + if strings.Contains(bodyUpper, "FAILED_PRECONDITION") && containsAny(msg, + "free tier is not available", + "enable billing", + "billing account", + "requires a billing", + ) { + return APIKeyStatusActionPermanentDisable + } + if containsAny(msg, + "api key not valid", + "invalid api key", + "api_key_invalid", + "api key is invalid", + "before or it is disabled", + "service disabled", + "api has not been used in project", + "unregistered callers", + "caller not registered", + "free tier is not available", + "enable billing", + ) { + return APIKeyStatusActionPermanentDisable + } + return APIKeyStatusActionTemporaryCooldown + } + } + + // All other non-200 status codes (404, 405, 422, etc.) that are not explicitly handled above: + // treat as temporary cooldown so the key is not scheduled again immediately. + // This covers endpoint-not-found, method-not-allowed, and any future unknown error codes. + return APIKeyStatusActionTemporaryCooldown +} + +func ShouldDisableAPIKeyAuthFailure(account *Account, statusCode int, responseBody []byte) bool { + return ShouldDisableAPIKeyStatus(account, statusCode, responseBody) +} + +// ClassifyAPIKeyProbeResponse classifies a probe response into (valid, invalid, cooldown, message). +// valid=true: key works. invalid=true: key is permanently disabled. cooldown=true: key needs temp cooldown. +func ClassifyAPIKeyProbeResponse(account *Account, statusCode int, responseBody []byte) (valid bool, invalid bool, cooldown bool, message string) { + if account == nil || account.Type != AccountTypeAPIKey { + return false, false, false, "unsupported account type" + } + + message = strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) + if message == "" { + message = http.StatusText(statusCode) + } + message = sanitizeUpstreamErrorMessage(message) + + switch account.Platform { + case PlatformAnthropic, PlatformOpenAI, PlatformGemini: + switch ClassifyAPIKeyStatusAction(account, statusCode, responseBody) { + case APIKeyStatusActionValid: + return true, false, false, message + case APIKeyStatusActionPermanentDisable: + return false, true, false, message + case APIKeyStatusActionTemporaryCooldown: + return false, false, true, message + default: + return false, false, false, message + } + default: + return false, false, false, message + } +} + +// CheckAPIKeyValidity tests an API key account using a real chat completions request, +// identical to the single-account "test connection" flow. This ensures health check +// results are authoritative and consistent with manual test results. +func (s *AccountTestService) CheckAPIKeyValidity(ctx context.Context, account *Account) (*APIKeyHealthCheckResult, error) { + if account == nil { + return nil, fmt.Errorf("account is required") + } + if account.Type != AccountTypeAPIKey { + return nil, fmt.Errorf("account %d is not an apikey account", account.ID) + } + if s == nil || s.httpUpstream == nil { + return nil, fmt.Errorf("account test service is not configured") + } + + // Run the same real chat completions test used by single-account "test connection". + // Account state (SetError, SetSchedulable, SetTempUnschedulable, etc.) is written + // inside the platform-specific test functions, so no additional state writes are needed here. + result, err := s.RunTestBackground(ctx, account.ID, "") + if err != nil { + return nil, err + } + + valid := result.Status == "success" + invalid := false + message := result.ResponseText + if !valid { + message = result.ErrorMessage + } + + return &APIKeyHealthCheckResult{ + Platform: account.Platform, + Valid: valid, + Invalid: invalid, + Message: message, + }, nil +} + +func buildAPIKeyProbeErrorMessage(statusCode int, upstreamMsg string) string { + msg := strings.TrimSpace(upstreamMsg) + if msg == "" { + msg = http.StatusText(statusCode) + } + return fmt.Sprintf("API key permanently disabled after probe (%d): %s", statusCode, msg) +} + +func (s *AccountTestService) buildAPIKeyProbeRequest(ctx context.Context, account *Account) (*http.Request, error) { + switch account.Platform { + case PlatformAnthropic: + return s.buildAnthropicAPIKeyProbeRequest(ctx, account) + case PlatformOpenAI: + return s.buildOpenAIAPIKeyProbeRequest(ctx, account) + case PlatformGemini: + return s.buildGeminiAPIKeyProbeRequest(ctx, account) + default: + return nil, fmt.Errorf("unsupported apikey platform: %s", account.Platform) + } +} + +func (s *AccountTestService) buildAnthropicAPIKeyProbeRequest(ctx context.Context, account *Account) (*http.Request, error) { + baseURL := strings.TrimSpace(account.GetBaseURL()) + if baseURL == "" { + baseURL = DefaultAPIKeyBaseURL(account.Platform) + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid anthropic base url: %w", err) + } + + // Use GET /v1/models for probe - no token consumption, pure auth check. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + strings.TrimSuffix(normalizedBaseURL, "/")+"/v1/models", + nil) + if err != nil { + return nil, err + } + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("x-api-key", account.GetCredential("api_key")) + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + return req, nil +} + +func (s *AccountTestService) buildOpenAIAPIKeyProbeRequest(ctx context.Context, account *Account) (*http.Request, error) { + baseURL := strings.TrimSpace(account.GetOpenAIBaseURL()) + if baseURL == "" { + baseURL = DefaultAPIKeyBaseURL(account.Platform) + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid openai base url: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSuffix(normalizedBaseURL, "/")+"/v1/models", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+account.GetCredential("api_key")) + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + return req, nil +} + +func (s *AccountTestService) buildGeminiAPIKeyProbeRequest(ctx context.Context, account *Account) (*http.Request, error) { + baseURL := strings.TrimSpace(account.GetBaseURL()) + if baseURL == "" { + baseURL = DefaultAPIKeyBaseURL(account.Platform) + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid gemini base url: %w", err) + } + + endpoint := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1beta/models?key=" + url.QueryEscape(account.GetCredential("api_key")) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + return req, nil +} + +func containsAny(haystack string, needles ...string) bool { + for _, needle := range needles { + if needle != "" && strings.Contains(haystack, needle) { + return true + } + } + return false +} diff --git a/backend/internal/service/apikey_health_test.go b/backend/internal/service/apikey_health_test.go new file mode 100644 index 0000000000..f0f2de7408 --- /dev/null +++ b/backend/internal/service/apikey_health_test.go @@ -0,0 +1,80 @@ +//go:build unit + +package service + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetectAPIKeyPlatform(t *testing.T) { + tests := []struct { + key string + platform string + ok bool + }{ + {key: "sk-ant-api03-abc", platform: PlatformAnthropic, ok: true}, + {key: "AIzaSyD-example", platform: PlatformGemini, ok: true}, + {key: "sk-proj-123", platform: PlatformOpenAI, ok: true}, + {key: "unknown-key", platform: "", ok: false}, + } + + for _, tt := range tests { + platform, ok := DetectAPIKeyPlatform(tt.key) + require.Equal(t, tt.platform, platform) + require.Equal(t, tt.ok, ok) + } +} + +func TestClassifyAPIKeyStatusAction(t *testing.T) { + openAI := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + anthropic := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + gemini := &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey} + + require.Equal(t, APIKeyStatusActionValid, ClassifyAPIKeyStatusAction(openAI, http.StatusOK, []byte(`{}`))) + require.Equal(t, APIKeyStatusActionPermanentDisable, ClassifyAPIKeyStatusAction(openAI, http.StatusForbidden, []byte(`{"error":{"message":"organization has been disabled","code":"account_deactivated"}}`))) + require.Equal(t, APIKeyStatusActionIgnore, ClassifyAPIKeyStatusAction(openAI, http.StatusForbidden, []byte(`{"error":{"message":"model not allowed for this project","code":"forbidden"}}`))) + require.Equal(t, APIKeyStatusActionIgnore, ClassifyAPIKeyStatusAction(anthropic, http.StatusMethodNotAllowed, []byte(`method not allowed`))) + require.Equal(t, APIKeyStatusActionTemporaryCooldown, ClassifyAPIKeyStatusAction(gemini, http.StatusTooManyRequests, []byte(`{"error":{"message":"quota exceeded"}}`))) + require.Equal(t, APIKeyStatusActionPermanentDisable, ClassifyAPIKeyStatusAction(gemini, http.StatusBadRequest, []byte(`{"error":{"message":"API key not valid. Please pass a valid API key.","status":"API_KEY_INVALID"}}`))) +} + +func TestClassifyAPIKeyProbeResponse(t *testing.T) { + openAIAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + geminiAccount := &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey} + anthropicAccount := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + valid, invalid, cooldown, _ := ClassifyAPIKeyProbeResponse(openAIAccount, http.StatusOK, []byte(`{}`)) + require.True(t, valid) + require.False(t, invalid) + require.False(t, cooldown) + + valid, invalid, cooldown, _ = ClassifyAPIKeyProbeResponse(openAIAccount, http.StatusPaymentRequired, []byte(`{"error":{"message":"insufficient balance"}}`)) + require.False(t, valid) + require.True(t, invalid) + require.False(t, cooldown) + + valid, invalid, cooldown, _ = ClassifyAPIKeyProbeResponse(geminiAccount, http.StatusBadRequest, []byte(`{"error":{"message":"API key not valid. Please pass a valid API key.","status":"API_KEY_INVALID"}}`)) + require.False(t, valid) + require.True(t, invalid) + require.False(t, cooldown) + + // OpenAI 403 unrecognized → cooldown (not permanent disable, not valid) + valid, invalid, cooldown, _ = ClassifyAPIKeyProbeResponse(openAIAccount, http.StatusForbidden, []byte(`{"error":{"message":"model not allowed for this project","code":"forbidden"}}`)) + require.False(t, valid) + require.False(t, invalid) + require.True(t, cooldown) + + // Unknown status code (405) → cooldown + valid, invalid, cooldown, _ = ClassifyAPIKeyProbeResponse(anthropicAccount, http.StatusMethodNotAllowed, []byte(`method not allowed`)) + require.False(t, valid) + require.False(t, invalid) + require.True(t, cooldown) + + valid, invalid, cooldown, _ = ClassifyAPIKeyProbeResponse(openAIAccount, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limited"}}`)) + require.False(t, valid) + require.False(t, invalid) + require.True(t, cooldown) +} diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go index 011c3ce4d5..c152b00a68 100644 --- a/backend/internal/service/error_passthrough_runtime.go +++ b/backend/internal/service/error_passthrough_runtime.go @@ -1,6 +1,10 @@ package service -import "github.com/gin-gonic/gin" +import ( + "strings" + + "github.com/gin-gonic/gin" +) const errorPassthroughServiceContextKey = "error_passthrough_service" @@ -56,9 +60,12 @@ func applyErrorPassthroughRule( status = *rule.ResponseCode } - errMsg = ExtractUpstreamErrorMessage(responseBody) - if !rule.PassthroughBody && rule.CustomMessage != nil { - errMsg = *rule.CustomMessage + errMsg = strings.TrimSpace(defaultErrMsg) + if rule.CustomMessage != nil { + custom := strings.TrimSpace(*rule.CustomMessage) + if custom != "" { + errMsg = custom + } } // 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。 diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 7032d15b95..39bae1fd0e 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -251,6 +251,42 @@ func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testi assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false") } +func TestApplyErrorPassthroughRule_DeprecatedPassthroughBodyDoesNotLeakUpstreamMessage(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + rule := &model.ErrorPassthroughRule{ + ID: 2, + Name: "deprecated-passthrough-body", + Enabled: true, + Priority: 1, + ErrorCodes: []int{http.StatusBadRequest}, + MatchMode: model.MatchModeAny, + PassthroughCode: true, + PassthroughBody: true, // Deprecated legacy value + } + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + BindErrorPassthroughService(c, ruleSvc) + + status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusBadRequest, + []byte(`{"error":{"message":"SECRET_UPSTREAM_TEXT"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + require.True(t, matched) + require.Equal(t, http.StatusBadRequest, status) + require.Equal(t, "upstream_error", errType) + require.Equal(t, "Upstream request failed", errMsg) + require.NotContains(t, errMsg, "SECRET_UPSTREAM_TEXT") +} + func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { return &model.ErrorPassthroughRule{ ID: 1, diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index 26fdf9a7dd..fdee254f3d 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -92,16 +92,29 @@ func NewErrorPassthroughService( // List 获取所有规则 func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { - return s.repo.List(ctx) + rules, err := s.repo.List(ctx) + if err != nil { + return nil, err + } + for _, rule := range rules { + normalizeErrorPassthroughRule(rule) + } + return rules, nil } // GetByID 根据 ID 获取规则 func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { - return s.repo.GetByID(ctx, id) + rule, err := s.repo.GetByID(ctx, id) + if err != nil || rule == nil { + return rule, err + } + normalizeErrorPassthroughRule(rule) + return rule, nil } // Create 创建规则 func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + normalizeErrorPassthroughRule(rule) if err := rule.Validate(); err != nil { return nil, err } @@ -121,6 +134,7 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP // Update 更新规则 func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + normalizeErrorPassthroughRule(rule) if err := rule.Validate(); err != nil { return nil, err } @@ -239,6 +253,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { cached := make([]*cachedPassthroughRule, len(rules)) for i, r := range rules { + normalizeErrorPassthroughRule(r) cr := &cachedPassthroughRule{ErrorPassthroughRule: r} if len(r.Keywords) > 0 { cr.lowerKeywords = make([]string, len(r.Keywords)) @@ -271,6 +286,24 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR s.localCacheMu.Unlock() } +func normalizeErrorPassthroughRule(rule *model.ErrorPassthroughRule) { + if rule == nil { + return + } + // Deprecated capability removal: never expose upstream raw body text via rules. + rule.PassthroughBody = false + if rule.CustomMessage != nil { + trimmed := strings.TrimSpace(*rule.CustomMessage) + if trimmed == "" { + rule.CustomMessage = nil + return + } + if trimmed != *rule.CustomMessage { + rule.CustomMessage = &trimmed + } + } +} + // clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。 func (s *ErrorPassthroughService) clearLocalCache() { s.localCacheMu.Lock() diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 96ddd6377e..d7c0afde8c 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -618,12 +618,12 @@ func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) { // 测试真实场景 // ============================================================================= -func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { - // 场景:上游返回 422 + "context limit has been reached",需要透传给客户端 +func TestMatchRule_RealWorldScenario_ContextLimitRule(t *testing.T) { + // 场景:上游返回 422 + "context limit has been reached",命中规则后返回统一错误契约 rules := []*model.ErrorPassthroughRule{ { ID: 1, - Name: "Context Limit Passthrough", + Name: "Context Limit Rule", Enabled: true, Priority: 1, ErrorCodes: []int{422}, @@ -631,7 +631,7 @@ func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { MatchMode: model.MatchModeAll, // 必须同时满足 Platforms: []string{"anthropic", "antigravity"}, PassthroughCode: true, - PassthroughBody: true, + PassthroughBody: true, // Deprecated: 运行时会被强制关闭 }, } @@ -643,7 +643,7 @@ func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { matched := svc.MatchRule("anthropic", 422, body) require.NotNil(t, matched) assert.True(t, matched.PassthroughCode) - assert.True(t, matched.PassthroughBody) + assert.False(t, matched.PassthroughBody, "passthrough_body 已废弃,运行时应强制关闭") }) // 测试 Antigravity 平台 @@ -716,7 +716,7 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { errorField string }{ { - name: "有效规则 - 透传模式(含错误码)", + name: "有效规则 - 兼容透传字段(含错误码)", rule: &model.ErrorPassthroughRule{ Name: "Valid Rule", MatchMode: model.MatchModeAny, @@ -727,7 +727,7 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { expectError: false, }, { - name: "有效规则 - 透传模式(含关键词)", + name: "有效规则 - 兼容透传字段(含关键词)", rule: &model.ErrorPassthroughRule{ Name: "Valid Rule", MatchMode: model.MatchModeAny, @@ -815,7 +815,7 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { errorField: "response_code", }, { - name: "自定义消息但未提供值", + name: "自定义消息可省略(走链路默认文案)", rule: &model.ErrorPassthroughRule{ Name: "Missing Message", MatchMode: model.MatchModeAny, @@ -824,8 +824,7 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { PassthroughBody: false, CustomMessage: nil, }, - expectError: true, - errorField: "custom_message", + expectError: false, }, { name: "自定义消息为空字符串", @@ -857,6 +856,30 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { } } +func TestErrorPassthroughService_Create_NormalizesDeprecatedPassthroughBody(t *testing.T) { + ctx := context.Background() + repo := &mockErrorPassthroughRepo{} + cache := newMockErrorPassthroughCache(nil, false) + svc := &ErrorPassthroughService{repo: repo, cache: cache} + + rule := &model.ErrorPassthroughRule{ + Name: "legacy-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + PassthroughCode: true, + PassthroughBody: true, + } + + created, err := svc.Create(ctx, rule) + require.NoError(t, err) + require.NotNil(t, created) + require.False(t, created.PassthroughBody) + require.NotEmpty(t, repo.rules) + require.False(t, repo.rules[0].PassthroughBody) +} + // ============================================================================= // 测试写路径缓存刷新(Create/Update/Delete) // ============================================================================= diff --git a/backend/internal/service/gateway_claude_beta_dynamic_test.go b/backend/internal/service/gateway_claude_beta_dynamic_test.go new file mode 100644 index 0000000000..36fd412bdc --- /dev/null +++ b/backend/internal/service/gateway_claude_beta_dynamic_test.go @@ -0,0 +1,101 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestBuildAuth2APIDynamicBetaTokens_NonHaiku(t *testing.T) { + got := buildAuth2APIDynamicBetaTokens("claude-sonnet-4-6", false) + require.Equal(t, []string{ + claude.BetaClaudeCode, + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaAdvancedToolUse, + betaEffort, + }, got) +} + +func TestBuildAuth2APIDynamicBetaTokens_HaikuStructured(t *testing.T) { + got := buildAuth2APIDynamicBetaTokens("claude-haiku-4-5-20251001", true) + require.Equal(t, []string{ + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaStructuredOutputs, + }, got) +} + +func TestGatewayService_GetBetaHeader_ClientProvided(t *testing.T) { + svc := &GatewayService{} + + withoutOAuth := svc.getBetaHeader("claude-sonnet-4-6", []byte(`{}`), "interleaved-thinking-2025-05-14,foo-beta") + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", withoutOAuth) + + withOAuth := svc.getBetaHeader("claude-sonnet-4-6", []byte(`{}`), "oauth-2025-04-20,foo-beta") + require.Equal(t, "oauth-2025-04-20,foo-beta", withOAuth) +} + +func TestGatewayService_GetBetaHeader_DynamicStructured(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"output_config":{"type":"json_schema","json_schema":{"name":"x","schema":{"type":"object"}}}}`) + got := svc.getBetaHeader("claude-sonnet-4-6", body, "") + require.Equal(t, strings.Join([]string{ + claude.BetaClaudeCode, + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaAdvancedToolUse, + betaEffort, + betaStructuredOutputs, + }, ","), got) +} + +func TestGatewayService_BuildUpstreamRequest_OAuthMimicUsesDynamicBetas(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{} + account := &Account{ + ID: 1001, + Name: "oauth-dynamic-beta", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Concurrency: 1, + } + + req, err := svc.buildUpstreamRequest( + context.Background(), + c, + account, + []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`), + "oauth-token", + "oauth", + "claude-sonnet-4-6", + false, + true, + ) + require.NoError(t, err) + beta := getHeaderRaw(req.Header, "anthropic-beta") + require.Contains(t, beta, claude.BetaClaudeCode) + require.Contains(t, beta, claude.BetaOAuth) + require.Contains(t, beta, betaRedactThinking) + require.Contains(t, beta, betaAdvancedToolUse) + require.Contains(t, beta, betaEffort) +} diff --git a/backend/internal/service/gateway_claude_cloaking_test.go b/backend/internal/service/gateway_claude_cloaking_test.go new file mode 100644 index 0000000000..e174b5e83a --- /dev/null +++ b/backend/internal/service/gateway_claude_cloaking_test.go @@ -0,0 +1,138 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestComputeClaudeCodeBillingFingerprint(t *testing.T) { + message := "01234567890123456789012345" + got := computeClaudeCodeBillingFingerprint(message, "2.1.88") + require.Equal(t, "d4e", got) +} + +func TestGenerateClaudeCodeBillingHeader(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"01234567890123456789012345"}]}]}`) + got := generateClaudeCodeBillingHeader(body, "2.1.88", "cli", "") + require.Equal(t, "x-anthropic-billing-header: cc_version=2.1.88.d4e; cc_entrypoint=cli;", got) +} + +func TestEnsureClaudeOAuthSystemCloaking_InsertsBillingAndPrefix(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-6","system":[{"type":"text","text":"custom system"}],"messages":[{"role":"user","content":[{"type":"text","text":"01234567890123456789012345"}]}]}`) + next, changed := ensureClaudeOAuthSystemCloaking(body, "2.1.88", "cli") + require.True(t, changed) + + system := gjson.GetBytes(next, "system") + require.True(t, system.Exists()) + require.True(t, system.IsArray()) + require.GreaterOrEqual(t, len(system.Array()), 3) + + first := system.Array()[0].Get("text").String() + require.Contains(t, first, "x-anthropic-billing-header:") + require.Contains(t, first, "cc_version=2.1.88.d4e") + require.Contains(t, first, "cc_entrypoint=cli") + + second := system.Array()[1].Get("text").String() + require.Equal(t, claudeCodeSystemPrompt, strings.TrimSpace(second)) + require.Equal(t, "custom system", system.Array()[2].Get("text").String()) +} + +func TestEnsureClaudeOAuthSystemCloaking_PreservesExistingBillingBlock(t *testing.T) { + existingBilling := "x-anthropic-billing-header: cc_version=2.0.0.abc; cc_entrypoint=cli;" + body := []byte(`{"model":"claude-sonnet-4-6","system":[{"type":"text","text":"` + existingBilling + `"},{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."},{"type":"text","text":"custom"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + next, changed := ensureClaudeOAuthSystemCloaking(body, "2.1.88", "cli") + require.True(t, changed) + + system := gjson.GetBytes(next, "system") + require.True(t, system.IsArray()) + items := system.Array() + require.GreaterOrEqual(t, len(items), 3) + require.Equal(t, existingBilling, items[0].Get("text").String()) + require.Equal(t, "You are Claude Code, Anthropic's official CLI for Claude.", items[1].Get("text").String()) + require.Equal(t, "custom", items[2].Get("text").String()) + require.Equal(t, 1, strings.Count(system.Raw, "x-anthropic-billing-header")) +} + +func TestBuildUpstreamRequest_OAuth_ForcesJSONMetadataAndSessionHeader(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + req.Header.Set("x-api-key", "test-key-session-metadata") + c.Request = req + + svc := &GatewayService{} + account := &Account{ + ID: 4242, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "account_uuid": "acc-uuid-42", + }, + } + + body := []byte(`{"model":"claude-sonnet-4-6","metadata":{"user_id":"legacy-user-id"},"messages":[{"role":"user","content":[{"type":"text","text":"01234567890123456789012345"}]}]}`) + upstreamReq, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-sonnet-4-6", false, false) + require.NoError(t, err) + + rawBody, err := io.ReadAll(upstreamReq.Body) + require.NoError(t, err) + uidRaw := gjson.GetBytes(rawBody, "metadata.user_id").String() + require.NotEmpty(t, uidRaw) + + parsed := ParseMetadataUserID(uidRaw) + require.NotNil(t, parsed) + require.True(t, parsed.IsNewFormat) + require.NotEmpty(t, parsed.DeviceID) + require.Equal(t, "acc-uuid-42", parsed.AccountUUID) + require.NotEmpty(t, parsed.SessionID) + + sessionHeader := getHeaderRaw(upstreamReq.Header, "X-Claude-Code-Session-Id") + require.NotEmpty(t, sessionHeader) + require.Equal(t, parsed.SessionID, sessionHeader) +} + +func TestBuildUpstreamRequest_OAuth_SessionStablePerAPIKey(t *testing.T) { + gin.SetMode(gin.TestMode) + account := &Account{ + ID: 5001, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "account_uuid": "acc-uuid-sticky", + }, + } + svc := &GatewayService{} + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + + makeReq := func(apiKey string) *http.Request { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + r := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + r.Header.Set("x-api-key", apiKey) + c.Request = r + req, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-sonnet-4-6", false, false) + require.NoError(t, err) + return req + } + + reqA1 := makeReq("stable-api-key-a") + reqA2 := makeReq("stable-api-key-a") + reqB := makeReq("stable-api-key-b") + + sessionA1 := getHeaderRaw(reqA1.Header, "X-Claude-Code-Session-Id") + sessionA2 := getHeaderRaw(reqA2.Header, "X-Claude-Code-Session-Id") + sessionB := getHeaderRaw(reqB.Header, "X-Claude-Code-Session-Id") + + require.NotEmpty(t, sessionA1) + require.Equal(t, sessionA1, sessionA2) + require.NotEqual(t, sessionA1, sessionB) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 2d16ad9429..6d9815a9da 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -82,6 +82,10 @@ func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key s return nil, nil } +func (m *mockAccountRepoForPlatform) FindByAPIKey(ctx context.Context, platform, apiKey, baseURL string) (*Account, error) { + return nil, nil +} + func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 94e04d286d..4e801df2b3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -16,9 +17,11 @@ import ( "os" "path/filepath" "regexp" + "runtime" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -54,6 +57,17 @@ const ( defaultModelsListCacheTTL = 15 * time.Second postUsageBillingTimeout = 15 * time.Second debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY" + // Claude Code billing header fingerprint salt (mirrors Claude Code validator requirement). + claudeCodeBillingFingerprintSalt = "59cf53e54c78" + defaultClaudeCodeEntrypoint = "cli" + betaRedactThinking = "redact-thinking-2026-02-12" + betaContextManagement = "context-management-2025-06-27" + betaPromptCachingScope = "prompt-caching-scope-2026-01-05" + betaAdvancedToolUse = "advanced-tool-use-2025-11-20" + betaEffort = "effort-2025-11-24" + betaStructuredOutputs = "structured-outputs-2025-12-15" + claudeCodeSessionTTLMin = 30 * time.Minute + claudeCodeSessionTTLMax = 300 * time.Minute ) const ( @@ -88,8 +102,17 @@ var ( modelsListCacheHitTotal atomic.Int64 modelsListCacheMissTotal atomic.Int64 modelsListCacheStoreTotal atomic.Int64 + + claudeCodeSessionMu sync.Mutex + claudeCodeSessionCache = map[string]claudeCodeSessionEntry{} ) +type claudeCodeSessionEntry struct { + id string + lastUsed time.Time + ttl time.Duration +} + func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { return windowCostPrefetchCacheHitTotal.Load(), windowCostPrefetchCacheMissTotal.Load(), @@ -215,6 +238,111 @@ func safeHeaderValueForLog(key string, v string) string { } } +func extractDownstreamAPIKey(headers http.Header) string { + if headers == nil { + return "" + } + auth := strings.TrimSpace(headers.Get("authorization")) + if strings.HasPrefix(strings.ToLower(auth), "bearer ") { + return strings.TrimSpace(auth[7:]) + } + if k := strings.TrimSpace(headers.Get("x-api-key")); k != "" { + return k + } + if k := strings.TrimSpace(headers.Get("x-goog-api-key")); k != "" { + return k + } + return "" +} + +func hashAPIKeyForSession(apiKey string) string { + if strings.TrimSpace(apiKey) == "" { + return "default" + } + sum := sha256.Sum256([]byte(apiKey)) + return hex.EncodeToString(sum[:]) +} + +func randomClaudeCodeSessionTTL() time.Duration { + if claudeCodeSessionTTLMax <= claudeCodeSessionTTLMin { + return claudeCodeSessionTTLMin + } + jitter := mathrand.Int63n(int64(claudeCodeSessionTTLMax - claudeCodeSessionTTLMin)) + return claudeCodeSessionTTLMin + time.Duration(jitter) +} + +func getOrCreateClaudeCodeSessionID(apiKeyHash string) string { + if strings.TrimSpace(apiKeyHash) == "" { + apiKeyHash = "default" + } + now := time.Now() + claudeCodeSessionMu.Lock() + defer claudeCodeSessionMu.Unlock() + + if entry, ok := claudeCodeSessionCache[apiKeyHash]; ok { + if now.Sub(entry.lastUsed) < entry.ttl { + entry.lastUsed = now + claudeCodeSessionCache[apiKeyHash] = entry + return entry.id + } + } + + for key, entry := range claudeCodeSessionCache { + if now.Sub(entry.lastUsed) >= entry.ttl { + delete(claudeCodeSessionCache, key) + } + } + + id := generateRandomUUID() + claudeCodeSessionCache[apiKeyHash] = claudeCodeSessionEntry{ + id: id, + lastUsed: now, + ttl: randomClaudeCodeSessionTTL(), + } + return id +} + +func formatClaudeOAuthMetadataUserID(deviceID, accountUUID, sessionID string) string { + if strings.TrimSpace(deviceID) == "" || strings.TrimSpace(sessionID) == "" { + return "" + } + raw, err := json.Marshal(jsonUserID{ + DeviceID: strings.TrimSpace(deviceID), + AccountUUID: strings.TrimSpace(accountUUID), + SessionID: strings.TrimSpace(sessionID), + }) + if err != nil { + return "" + } + return string(raw) +} + +func forceClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { + if len(body) == 0 || strings.TrimSpace(userID) == "" { + return body, false + } + + metadata := gjson.GetBytes(body, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) + } + + trimmedRaw := strings.TrimSpace(metadata.Raw) + if strings.HasPrefix(trimmedRaw, "{") { + return setJSONValueBytes(body, "metadata.user_id", userID) + } + + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) +} + func extractSystemPreviewFromBody(body []byte) string { if len(body) == 0 { return "" @@ -521,6 +649,14 @@ func (e *UpstreamFailoverError) Error() string { return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) } +// streamSSEError is returned by processSSEEvent when an SSE "event: error" frame is received. +// It carries the raw data line so the caller can classify the error and update account state. +type streamSSEError struct { + body []byte +} + +func (e *streamSSEError) Error() string { return "have error in stream" } + // TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 // 由 handler 层在同账号重试全部用尽、切换账号时调用。 func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { @@ -883,6 +1019,9 @@ type claudeOAuthNormalizeOptions struct { injectMetadata bool metadataUserID string stripSystemCacheControl bool + applySystemCloaking bool + cloakingCLIVersion string + cloakingEntrypoint string } // sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). @@ -1045,6 +1184,257 @@ func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) return setJSONRawBytes(body, "metadata", raw) } +func resolveClaudeCodeCLIVersion(ua string) string { + if v := ExtractCLIVersion(strings.TrimSpace(ua)); v != "" { + return v + } + if v := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"]); v != "" { + return v + } + return "2.1.88" +} + +func extractFirstAnthropicUserMessageText(body []byte) string { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return "" + } + + text := "" + messages.ForEach(func(_, item gjson.Result) bool { + if item.Get("role").String() != "user" { + return true + } + content := item.Get("content") + if !content.Exists() { + return false + } + if content.Type == gjson.String { + text = content.String() + return false + } + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + if part.Get("type").String() != "text" { + return true + } + t := part.Get("text") + if t.Exists() && t.Type == gjson.String { + text = t.String() + return false + } + return true + }) + } + return false + }) + return text +} + +func computeClaudeCodeBillingFingerprint(messageText, version string) string { + runes := []rune(messageText) + indices := []int{4, 7, 20} + chars := make([]rune, 0, len(indices)) + for _, idx := range indices { + if idx >= 0 && idx < len(runes) { + chars = append(chars, runes[idx]) + continue + } + chars = append(chars, '0') + } + + input := claudeCodeBillingFingerprintSalt + string(chars) + version + sum := sha256.Sum256([]byte(input)) + return hex.EncodeToString(sum[:])[:3] +} + +func generateClaudeCodeBillingHeader(body []byte, cliVersion, entrypoint, workload string) string { + msgText := extractFirstAnthropicUserMessageText(body) + fp := computeClaudeCodeBillingFingerprint(msgText, cliVersion) + workloadPair := "" + if strings.TrimSpace(workload) != "" { + workloadPair = " cc_workload=" + strings.TrimSpace(workload) + ";" + } + return "x-anthropic-billing-header: cc_version=" + cliVersion + "." + fp + "; cc_entrypoint=" + entrypoint + ";" + workloadPair +} + +func isBillingHeaderSystemText(text string) bool { + return strings.Contains(text, "x-anthropic-billing-header") +} + +func isClaudeCodePrefixSystemText(text string) bool { + return strings.Contains(text, "You are Claude Code") +} + +func ensureClaudeOAuthSystemCloaking(body []byte, cliVersion, entrypoint string) ([]byte, bool) { + if len(body) == 0 { + return body, false + } + if cliVersion == "" { + cliVersion = resolveClaudeCodeCLIVersion("") + } + if entrypoint == "" { + entrypoint = defaultClaudeCodeEntrypoint + } + + sys := gjson.GetBytes(body, "system") + var remaining [][]byte + var billingRaw []byte + var prefixRaw []byte + + collectTextBlock := func(raw []byte, text string) { + if billingRaw == nil && isBillingHeaderSystemText(text) { + billingRaw = raw + return + } + if prefixRaw == nil && isClaudeCodePrefixSystemText(text) { + prefixRaw = raw + return + } + remaining = append(remaining, raw) + } + + switch { + case !sys.Exists() || sys.Type == gjson.Null: + // no-op: we will synthesize billing + prefix blocks below. + case sys.Type == gjson.String: + text := sys.String() + raw, err := marshalAnthropicSystemTextBlock(text, false) + if err == nil { + collectTextBlock(raw, text) + } + case sys.IsArray(): + sys.ForEach(func(_, item gjson.Result) bool { + itemText := "" + if item.Type == gjson.String { + itemText = item.String() + raw, err := marshalAnthropicSystemTextBlock(itemText, false) + if err == nil { + collectTextBlock(raw, itemText) + return true + } + } + if item.Get("type").String() == "text" { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String { + itemText = textResult.String() + } + } + raw := []byte(item.Raw) + if itemText != "" { + collectTextBlock(raw, itemText) + return true + } + remaining = append(remaining, raw) + return true + }) + default: + remaining = append(remaining, []byte(sys.Raw)) + } + + if billingRaw == nil { + billingHeader := generateClaudeCodeBillingHeader(body, cliVersion, entrypoint, "") + raw, err := marshalAnthropicSystemTextBlock(billingHeader, false) + if err == nil { + billingRaw = raw + } + } + if prefixRaw == nil { + raw, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if err == nil { + prefixRaw = raw + } + } + if billingRaw == nil || prefixRaw == nil { + return body, false + } + + items := make([][]byte, 0, len(remaining)+2) + items = append(items, billingRaw, prefixRaw) + items = append(items, remaining...) + next, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items)) + if !ok { + return body, false + } + return next, true +} + +func isStructuredOutputRequest(body []byte) bool { + outputType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "output_config.type").String())) + if outputType == "json_object" || outputType == "json_schema" { + return true + } + if gjson.GetBytes(body, "output_config.json_schema").Exists() { + return true + } + // Keep compatibility with OpenAI-style payloads that may pass through unified gateways. + responseFormatType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format.type").String())) + if responseFormatType == "json_object" || responseFormatType == "json_schema" { + return true + } + return false +} + +func buildAuth2APIDynamicBetaTokens(modelID string, structured bool) []string { + isHaiku := strings.Contains(strings.ToLower(modelID), "haiku") + if isHaiku { + if structured { + return []string{ + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaStructuredOutputs, + } + } + return []string{ + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + claude.BetaClaudeCode, + } + } + + if structured { + return []string{ + claude.BetaClaudeCode, + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaAdvancedToolUse, + betaEffort, + betaStructuredOutputs, + } + } + return []string{ + claude.BetaClaudeCode, + claude.BetaOAuth, + claude.BetaInterleavedThinking, + betaRedactThinking, + betaContextManagement, + betaPromptCachingScope, + betaAdvancedToolUse, + betaEffort, + } +} + +func joinBetaTokens(tokens []string) string { + filtered := make([]string, 0, len(tokens)) + for _, token := range tokens { + token = strings.TrimSpace(token) + if token == "" { + continue + } + filtered = append(filtered, token) + } + return strings.Join(filtered, ",") +} + func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { if len(body) == 0 { return body, modelID @@ -1053,6 +1443,13 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu out := body modified := false + if opts.applySystemCloaking { + if next, changed := ensureClaudeOAuthSystemCloaking(out, opts.cloakingCLIVersion, opts.cloakingEntrypoint); changed { + out = next + modified = true + } + } + if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { out = next modified = true @@ -4157,29 +4554,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) - // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - if !strings.Contains(strings.ToLower(reqModel), "haiku") && - !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) - } - - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} - if s.identityService != nil { - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) - if err == nil && fp != nil { - // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) - if !mimicMPT { - if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { - normalizeOpts.injectMetadata = true - normalizeOpts.metadataUserID = metadataUserID - } - } - } - } - - body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, claudeOAuthNormalizeOptions{ + stripSystemCacheControl: true, + }) } // 强制执行 cache_control 块数量限制(最多 4 个) @@ -4546,6 +4923,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } + // API Key 账号:重试耗尽后,对永久性错误(402等)也触发 failover + if account.Type == AccountTypeAPIKey { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + action := ClassifyAPIKeyStatusAction(account, resp.StatusCode, respBody) + if action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown { + kind := "apikey_permanent_retry_exhausted_failover" + if action == APIKeyStatusActionTemporaryCooldown { + kind = "apikey_temporary_retry_exhausted_failover" + } + logger.LegacyPrintf("service.gateway", "[APIKey] Account %d: error %d after retries (%s), triggering failover", + account.ID, resp.StatusCode, kind) + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4581,6 +4988,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } } if resp.StatusCode >= 400 { + // API Key 账号的 400 错误:永久禁用直接 failover,临时冷却也 failover 切换让用户无感。 + if resp.StatusCode == 400 && account.Type == AccountTypeAPIKey { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + action := ClassifyAPIKeyStatusAction(account, resp.StatusCode, respBody) + if action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown { + kind := "apikey_permanent_400_failover" + if action == APIKeyStatusActionTemporaryCooldown { + kind = "apikey_temporary_400_failover" + } + logger.LegacyPrintf("service.gateway", "[APIKey] Account %d: 400 error (%s), triggering failover: %s", + account.ID, kind, truncateString(string(respBody), 500)) + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + // 可选:对部分 400 触发 failover(默认关闭以保持语义) if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -4642,9 +5080,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode) if err != nil { - if err.Error() == "have error in stream" { + var sseErr *streamSSEError + if errors.As(err, &sseErr) { + // Classify and persist account state before failing over. + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusForbidden, resp.Header, sseErr.body) return nil, &UpstreamFailoverError{ - StatusCode: 403, + StatusCode: http.StatusForbidden, + ResponseBody: sseErr.body, } } return nil, err @@ -4851,6 +5293,32 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } + // API Key 账号:重试耗尽后,对永久性错误(402等)也触发 failover + if account.Type == AccountTypeAPIKey { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if ShouldDisableAPIKeyStatus(account, resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough][APIKey] Account %d: permanent error %d after retries, triggering failover", + account.ID, resp.StatusCode) + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "apikey_permanent_retry_exhausted_failover", + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4886,6 +5354,33 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( } } + // API Key 账号的永久性 400(余额不足、账号禁用等)触发 failover + if resp.StatusCode == 400 && account.Type == AccountTypeAPIKey { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if ShouldDisableAPIKeyStatus(account, resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough][APIKey] Account %d: permanent 400 error, triggering failover: %s", + account.ID, truncateString(string(respBody), 500)) + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "apikey_permanent_400_failover", + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } + if resp.StatusCode >= 400 { return s.handleErrorResponse(ctx, resp, c, account) } @@ -5713,35 +6208,50 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex clientHeaders = c.Request.Header } - // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) + downstreamAPIKeyHash := hashAPIKeyForSession(extractDownstreamAPIKey(clientHeaders)) + sessionID := strings.TrimSpace(getHeaderRaw(clientHeaders, "x-claude-code-session-id")) + if sessionID == "" { + sessionID = getOrCreateClaudeCodeSessionID(downstreamAPIKeyHash) + } + + // OAuth账号:对齐 Claude Code 请求身份(指纹 + system cloaking + metadata.user_id)。 var fingerprint *Fingerprint enableFP, enableMPT := true, false if s.settingService != nil { enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) } - if account.IsOAuth() && s.identityService != nil { - // 1. 获取或创建指纹(包含随机生成的ClientID) - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) - if err != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) - // 失败时降级为透传原始headers - } else { - if enableFP { + if account.IsOAuth() { + cliVersion := resolveClaudeCodeCLIVersion("") + deviceID := "" + if s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) + } else if fp != nil { fingerprint = fp - } - - // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) - // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - // 当 metadata 透传开启时跳过重写 - if !enableMPT { - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { - body = newBody - } + if strings.TrimSpace(fp.ClientID) == "" { + fp.ClientID = generateClientID() } + deviceID = strings.TrimSpace(fp.ClientID) + cliVersion = resolveClaudeCodeCLIVersion(fp.UserAgent) } } + if deviceID == "" { + deviceID = generateClientID() + } + + if next, changed := ensureClaudeOAuthSystemCloaking(body, cliVersion, defaultClaudeCodeEntrypoint); changed { + body = next + } + + metadataUserID := formatClaudeOAuthMetadataUserID( + deviceID, + strings.TrimSpace(account.GetExtraString("account_uuid")), + sessionID, + ) + if next, changed := forceClaudeOAuthMetadataUserID(body, metadataUserID); changed { + body = next + } } req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) @@ -5768,7 +6278,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头) - if fingerprint != nil { + if fingerprint != nil && s.identityService != nil { s.identityService.ApplyFingerprint(req, fingerprint) } @@ -5780,33 +6290,24 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } if tokenType == "oauth" { - applyClaudeOAuthHeaderDefaults(req) + applyClaudeOAuthHeaderDefaults(req, reqStream) + if sessionID != "" { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID) + } } // Build effective drop set: merge static defaults with dynamic beta policy filter rules policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) effectiveDropSet := mergeDropSets(policyFilterSet) - effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { if mimicClaudeCode { - // 非 Claude Code 客户端:按 opencode 的策略处理: - // - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app) - // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 + // 非 Claude Code 客户端:强制 Claude Code 风格请求头。 applyClaudeCodeMimicHeaders(req, reqStream) - - incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Match real Claude CLI traffic (per mitmproxy reports): - // messages requests typically use only oauth + interleaved-thinking. - // Also drop claude-code beta if a downstream client added it. - requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) - } else { - // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta - clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") - setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) } + clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, body, clientBetaHeader), effectiveDropSet)) } else { // API-key accounts: apply beta policy filter to strip controlled tokens if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { @@ -5821,15 +6322,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -5852,51 +6344,18 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex return req, nil } -// getBetaHeader 处理anthropic-beta header -// 对于OAuth账号,需要确保包含oauth-2025-04-20 -func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { - // 如果客户端传了anthropic-beta +// getBetaHeader 处理 anthropic-beta header(按 auth2api 动态策略) +// 1) 客户端提供 anthropic-beta:保持原值并补齐 oauth beta(在首位) +// 2) 客户端未提供:按模型 + structured 输出动态生成默认值 +func (s *GatewayService) getBetaHeader(modelID string, body []byte, clientBetaHeader string) string { + clientBetaHeader = strings.TrimSpace(clientBetaHeader) if clientBetaHeader != "" { - // 已包含oauth beta则直接返回 if strings.Contains(clientBetaHeader, claude.BetaOAuth) { return clientBetaHeader } - - // 需要添加oauth beta - parts := strings.Split(clientBetaHeader, ",") - for i, p := range parts { - parts[i] = strings.TrimSpace(p) - } - - // 在claude-code-20250219后面插入oauth beta - claudeCodeIdx := -1 - for i, p := range parts { - if p == claude.BetaClaudeCode { - claudeCodeIdx = i - break - } - } - - if claudeCodeIdx >= 0 { - // 在claude-code后面插入 - newParts := make([]string, 0, len(parts)+1) - newParts = append(newParts, parts[:claudeCodeIdx+1]...) - newParts = append(newParts, claude.BetaOAuth) - newParts = append(newParts, parts[claudeCodeIdx+1:]...) - return strings.Join(newParts, ",") - } - - // 没有claude-code,放在第一位 return claude.BetaOAuth + "," + clientBetaHeader } - - // 客户端没传,根据模型生成 - // haiku 模型不需要 claude-code beta - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.HaikuBetaHeader - } - - return claude.DefaultBetaHeader + return joinBetaTokens(buildAuth2APIDynamicBetaTokens(modelID, isStructuredOutputRequest(body))) } func requestNeedsBetaFeatures(body []byte) bool { @@ -5919,12 +6378,16 @@ func defaultAPIKeyBetaHeader(body []byte) string { return claude.APIKeyBetaHeader } -func applyClaudeOAuthHeaderDefaults(req *http.Request) { +func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { if req == nil { return } if getHeaderRaw(req.Header, "Accept") == "" { - setHeaderRaw(req.Header, "Accept", "application/json") + if isStream { + setHeaderRaw(req.Header, "Accept", "text/event-stream") + } else { + setHeaderRaw(req.Header, "Accept", "application/json") + } } for key, value := range claude.DefaultHeaders { if value == "" { @@ -5934,6 +6397,39 @@ func applyClaudeOAuthHeaderDefaults(req *http.Request) { setHeaderRaw(req.Header, resolveWireCasing(key), value) } } + if getHeaderRaw(req.Header, "X-Stainless-Arch") == "" { + setHeaderRaw(req.Header, "X-Stainless-Arch", resolveStainlessArch()) + } + if getHeaderRaw(req.Header, "X-Stainless-Os") == "" { + setHeaderRaw(req.Header, "X-Stainless-Os", resolveStainlessOS()) + } + if getHeaderRaw(req.Header, "x-client-request-id") == "" { + setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString()) + } +} + +func resolveStainlessArch() string { + switch runtime.GOARCH { + case "arm64": + return "arm64" + case "amd64": + return "x64" + default: + return "x86" + } +} + +func resolveStainlessOS() string { + switch runtime.GOOS { + case "darwin": + return "MacOS" + case "windows": + return "Windows" + case "freebsd": + return "FreeBSD" + default: + return "Linux" + } } func mergeAnthropicBeta(required []string, incoming string) string { @@ -6228,7 +6724,7 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { return } // Start with the standard defaults (fill missing). - applyClaudeOAuthHeaderDefaults(req) + applyClaudeOAuthHeaderDefaults(req, isStream) // Then force key headers to match Claude Code fingerprint regardless of what the client sent. // 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App") for key, value := range claude.DefaultHeaders { @@ -6237,10 +6733,12 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { } setHeaderRaw(req.Header, resolveWireCasing(key), value) } - // Real Claude CLI uses Accept: application/json (even for streaming). - setHeaderRaw(req.Header, "Accept", "application/json") + setHeaderRaw(req.Header, "X-Stainless-Arch", resolveStainlessArch()) + setHeaderRaw(req.Header, "X-Stainless-Os", resolveStainlessOS()) if isStream { - setHeaderRaw(req.Header, "x-stainless-helper-method", "stream") + setHeaderRaw(req.Header, "Accept", "text/event-stream") + } else { + setHeaderRaw(req.Header, "Accept", "application/json") } } @@ -6429,6 +6927,12 @@ func extractUpstreamErrorCode(body []byte) string { return "" } +// extractUpstreamErrorType extracts the error type field from Anthropic-style responses: +// {"type":"error","error":{"type":"authentication_error","message":"..."}} +func extractUpstreamErrorType(body []byte) string { + return strings.TrimSpace(gjson.GetBytes(body, "error.type").String()) +} + func isCountTokensUnsupported404(statusCode int, body []byte) bool { if statusCode != http.StatusNotFound { return false @@ -6600,10 +7104,17 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) - } else { - // API Key 未配置错误码:不标记账号状态 - logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + return + } + + // API Key 账号:标记账号状态(永久禁用或临时冷却) + if account.Type == AccountTypeAPIKey { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) + logger.LegacyPrintf("service.gateway", "Account %d: apikey error %d after %d retries, status marked", account.ID, statusCode, maxRetryAttempts) + return } + + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { @@ -6851,7 +7362,12 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if eventName == "error" { - return nil, dataLine, nil, errors.New("have error in stream") + // Wrap dataLine as a synthetic error body so callers can classify the error. + syntheticBody := []byte(dataLine) + if len(syntheticBody) == 0 { + syntheticBody = []byte(`{"error":{"type":"stream_error","message":"error event in stream"}}`) + } + return nil, dataLine, nil, &streamSSEError{body: syntheticBody} } if dataLine == "" { @@ -8451,25 +8967,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con clientHeaders = c.Request.Header } - // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) - // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT := true, false - if s.settingService != nil { - ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + downstreamAPIKeyHash := hashAPIKeyForSession(extractDownstreamAPIKey(clientHeaders)) + sessionID := strings.TrimSpace(getHeaderRaw(clientHeaders, "x-claude-code-session-id")) + if sessionID == "" { + sessionID = getOrCreateClaudeCodeSessionID(downstreamAPIKeyHash) } + var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) - if err == nil { + if err == nil && fp != nil { ctFingerprint = fp - if !ctEnableMPT { - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { - body = newBody - } - } - } + } else if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for count_tokens account %d: %v", account.ID, err) } } @@ -8496,8 +9006,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // OAuth 账号:应用指纹到请求头(受设置开关控制) - if ctEnableFP && ctFingerprint != nil { + // OAuth 账号:应用指纹到请求头 + if ctFingerprint != nil && s.identityService != nil { s.identityService.ApplyFingerprint(req, ctFingerprint) } @@ -8509,7 +9019,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } if tokenType == "oauth" { - applyClaudeOAuthHeaderDefaults(req) + applyClaudeOAuthHeaderDefaults(req, false) + if sessionID != "" { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID) + } } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules @@ -8519,22 +9032,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if tokenType == "oauth" { if mimicClaudeCode { applyClaudeCodeMimicHeaders(req, false) - - incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) - } else { - clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") - if clientBetaHeader == "" { - setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader) - } else { - beta := s.getBetaHeader(modelID, clientBetaHeader) - if !strings.Contains(beta, claude.BetaTokenCounting) { - beta = beta + "," + claude.BetaTokenCounting - } - setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) - } } + clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") + beta := s.getBetaHeader(modelID, body, clientBetaHeader) + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) } else { // API-key accounts: apply beta policy filter to strip controlled tokens if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { @@ -8549,15 +9050,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 5b1abc119f..57557df841 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -917,6 +917,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex // ErrorPolicyNone → 原有逻辑 s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + if account.Type == AccountTypeAPIKey { + action := ClassifyAPIKeyStatusAction(account, resp.StatusCode, respBody) + if action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + } // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 if resp.StatusCode == http.StatusBadRequest { msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) @@ -1403,6 +1433,33 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. // ErrorPolicyNone → 原有逻辑 s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + if account.Type == AccountTypeAPIKey { + action := ClassifyAPIKeyStatusAction(account, resp.StatusCode, respBody) + if action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown { + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} + } + } // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 if resp.StatusCode == http.StatusBadRequest { msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) @@ -2725,6 +2782,13 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont if !account.ShouldHandleErrorCode(statusCode) { return } + if s.rateLimitService != nil && account.Type == AccountTypeAPIKey { + action := ClassifyAPIKeyStatusAction(account, statusCode, body) + if action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + return + } + } if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) return diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 5e09b95af2..f3d1dcff0a 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -71,6 +71,10 @@ func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key str return nil, nil } +func (m *mockAccountRepoForGemini) FindByAPIKey(ctx context.Context, platform, apiKey, baseURL string) (*Account, error) { + return nil, nil +} + func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/header_util.go b/backend/internal/service/header_util.go index 1091070df0..8001ae422e 100644 --- a/backend/internal/service/header_util.go +++ b/backend/internal/service/header_util.go @@ -20,7 +20,7 @@ var headerWireCasing = map[string]string{ "x-stainless-timeout": "X-Stainless-Timeout", "x-stainless-lang": "X-Stainless-Lang", "x-stainless-package-version": "X-Stainless-Package-Version", - "x-stainless-os": "X-Stainless-OS", + "x-stainless-os": "X-Stainless-Os", "x-stainless-arch": "X-Stainless-Arch", "x-stainless-runtime": "X-Stainless-Runtime", "x-stainless-runtime-version": "X-Stainless-Runtime-Version", @@ -51,7 +51,7 @@ var headerWireOrder = []string{ "X-Stainless-Timeout", "X-Stainless-Lang", "X-Stainless-Package-Version", - "X-Stainless-OS", + "X-Stainless-Os", "X-Stainless-Arch", "X-Stainless-Runtime", "X-Stainless-Runtime-Version", @@ -81,7 +81,7 @@ func init() { } } -// resolveWireCasing 将 Go canonical key(如 X-Stainless-Os)映射为真实 wire casing(如 X-Stainless-OS)。 +// resolveWireCasing 将 Go canonical key(如 X-Stainless-Os)映射为真实 wire casing(如 X-Stainless-Os)。 // 如果 map 中没有对应条目,返回原始 key 不变。 func resolveWireCasing(key string) string { if wk, ok := headerWireCasing[strings.ToLower(key)]; ok { diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 3d7065083e..e3bb1221ab 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -26,13 +26,13 @@ var ( // 默认指纹值(当客户端未提供时使用) var defaultFingerprint = Fingerprint{ - UserAgent: "claude-cli/2.1.22 (external, cli)", + UserAgent: "claude-cli/2.1.88 (external, cli)", StainlessLang: "js", - StainlessPackageVersion: "0.70.0", - StainlessOS: "Linux", - StainlessArch: "arm64", + StainlessPackageVersion: "0.74.0", + StainlessOS: resolveStainlessOS(), + StainlessArch: resolveStainlessArch(), StainlessRuntime: "node", - StainlessRuntimeVersion: "v24.13.0", + StainlessRuntimeVersion: "v22.13.0", } // Fingerprint represents account fingerprint data @@ -132,7 +132,7 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin // 获取x-stainless-*头,如果没有则使用默认值 fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang) fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion) - fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS) + fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-Os", defaultFingerprint.StainlessOS) fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch) fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime) fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion) @@ -152,7 +152,7 @@ func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) { // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值 mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang) mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion) - mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS) + mergeHeader(headers, "X-Stainless-Os", &fp.StainlessOS) mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch) mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime) mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion) @@ -174,7 +174,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { } // ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头) -// 使用 setHeaderRaw 保持原始大小写(如 X-Stainless-OS 而非 X-Stainless-Os) +// 使用 setHeaderRaw 保持原始大小写(如 X-Stainless-Os) func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { if fp == nil { return @@ -193,7 +193,7 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { setHeaderRaw(req.Header, "X-Stainless-Package-Version", fp.StainlessPackageVersion) } if fp.StainlessOS != "" { - setHeaderRaw(req.Header, "X-Stainless-OS", fp.StainlessOS) + setHeaderRaw(req.Header, "X-Stainless-Os", fp.StainlessOS) } if fp.StainlessArch != "" { setHeaderRaw(req.Header, "X-Stainless-Arch", fp.StainlessArch) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 21b4874eb3..4ec038e068 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -275,6 +275,13 @@ func normalizeCodexModel(model string) string { return "gpt-5.1" } +func normalizeOpenAIModelForUpstream(account *Account, model string) string { + if account == nil || account.Type == AccountTypeOAuth { + return normalizeCodexModel(model) + } + return strings.TrimSpace(model) +} + func SupportsVerbosity(model string) bool { if !strings.HasPrefix(model, "gpt-") { return true diff --git a/backend/internal/service/openai_content_session_seed.go b/backend/internal/service/openai_content_session_seed.go new file mode 100644 index 0000000000..ab0413289f --- /dev/null +++ b/backend/internal/service/openai_content_session_seed.go @@ -0,0 +1,101 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" +) + +// contentSessionSeedPrefix prevents collisions between content-derived seeds +// and explicit session IDs like sess_* or compat_cc_*. +const contentSessionSeedPrefix = "compat_cs_" + +// deriveOpenAIContentSessionSeed builds a stable fallback seed from fields that +// should stay constant across turns for non-Codex clients. +func deriveOpenAIContentSessionSeed(body []byte) string { + if len(body) == 0 { + return "" + } + + var b strings.Builder + + if model := gjson.GetBytes(body, "model").String(); model != "" { + b.WriteString("model=") + b.WriteString(model) + } + + if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() && tools.Raw != "[]" { + b.WriteString("|tools=") + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(tools.Raw))) + } + + if funcs := gjson.GetBytes(body, "functions"); funcs.Exists() && funcs.IsArray() && funcs.Raw != "[]" { + b.WriteString("|functions=") + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(funcs.Raw))) + } + + if instructions := gjson.GetBytes(body, "instructions").String(); instructions != "" { + b.WriteString("|instructions=") + b.WriteString(instructions) + } + + firstUserCaptured := false + msgs := gjson.GetBytes(body, "messages") + if msgs.Exists() && msgs.IsArray() { + msgs.ForEach(func(_, msg gjson.Result) bool { + switch msg.Get("role").String() { + case "system", "developer": + b.WriteString("|system=") + if c := msg.Get("content"); c.Exists() { + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + case "user": + if !firstUserCaptured { + b.WriteString("|first_user=") + if c := msg.Get("content"); c.Exists() { + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + firstUserCaptured = true + } + } + return true + }) + } else if input := gjson.GetBytes(body, "input"); input.Exists() { + if input.Type == gjson.String { + b.WriteString("|input=") + b.WriteString(input.String()) + } else if input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + switch item.Get("role").String() { + case "system", "developer": + b.WriteString("|system=") + if c := item.Get("content"); c.Exists() { + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + case "user": + if !firstUserCaptured { + b.WriteString("|first_user=") + if c := item.Get("content"); c.Exists() { + b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + firstUserCaptured = true + } + } + if !firstUserCaptured && item.Get("type").String() == "input_text" { + b.WriteString("|first_user=") + if text := item.Get("text").String(); text != "" { + b.WriteString(text) + } + firstUserCaptured = true + } + return true + }) + } + } + + if b.Len() == 0 { + return "" + } + return contentSessionSeedPrefix + b.String() +} diff --git a/backend/internal/service/openai_content_session_seed_test.go b/backend/internal/service/openai_content_session_seed_test.go new file mode 100644 index 0000000000..1379f8e204 --- /dev/null +++ b/backend/internal/service/openai_content_session_seed_test.go @@ -0,0 +1,41 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDeriveOpenAIContentSessionSeed_EmptyInputs(t *testing.T) { + require.Empty(t, deriveOpenAIContentSessionSeed(nil)) + require.Empty(t, deriveOpenAIContentSessionSeed([]byte{})) + require.Empty(t, deriveOpenAIContentSessionSeed([]byte(`{}`))) +} + +func TestDeriveOpenAIContentSessionSeed_StableAcrossLaterTurns(t *testing.T) { + turn1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"}]}`) + turn2 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi"},{"role":"user","content":"How are you?"}]}`) + require.Equal(t, deriveOpenAIContentSessionSeed(turn1), deriveOpenAIContentSessionSeed(turn2)) +} + +func TestDeriveOpenAIContentSessionSeed_UsesToolsAndFunctions(t *testing.T) { + withTools := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","function":{"name":"get_weather"}}],"messages":[{"role":"user","content":"Hello"}]}`) + withoutTools := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hello"}]}`) + require.NotEqual(t, deriveOpenAIContentSessionSeed(withTools), deriveOpenAIContentSessionSeed(withoutTools)) + + withFunctions := []byte(`{"model":"gpt-5.4","functions":[{"name":"get_weather","parameters":{}}],"messages":[{"role":"user","content":"Hello"}]}`) + require.Contains(t, deriveOpenAIContentSessionSeed(withFunctions), "|functions=") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesInputAndInstructions(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","instructions":"You are a coding assistant.","input":"Write hello world"}`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|instructions=You are a coding assistant.") + require.Contains(t, seed, "|input=Write hello world") +} + +func TestDeriveOpenAIContentSessionSeed_JSONCanonicalisation(t *testing.T) { + compact := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather"}}],"messages":[{"role":"user","content":"Hi"}]}`) + spaced := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","function":{"description":"Get weather","name":"get_weather"}}],"messages":[{"role":"user","content":"Hi"}]}`) + require.Equal(t, deriveOpenAIContentSessionSeed(compact), deriveOpenAIContentSessionSeed(spaced)) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 1d5bf0d0a4..dd8078de1a 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) - upstreamModel := resolveOpenAIUpstreamModel(billingModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) promptCacheKey = strings.TrimSpace(promptCacheKey) compatPromptCacheInjected := false @@ -219,7 +219,13 @@ func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( c *gin.Context, account *Account, ) (*OpenAIForwardResult, error) { - return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) + return s.handleCompatErrorResponse( + resp, + c, + account, + writeChatCompletionsError, + UpstreamFaultFormatOpenAIChatCompletions, + ) } // handleChatBufferedStreamingResponse reads all Responses SSE events from the diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 8c389556f2..58b076b9c2 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( // 3. Model mapping billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) - upstreamModel := resolveOpenAIUpstreamModel(billingModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", @@ -224,7 +224,13 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse( c *gin.Context, account *Account, ) (*OpenAIForwardResult, error) { - return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) + return s.handleCompatErrorResponse( + resp, + c, + account, + writeAnthropicError, + UpstreamFaultFormatAnthropicMessages, + ) } // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index e85f0705aa..249782e280 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -322,6 +322,7 @@ type OpenAIGatewayService struct { openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver + upstreamFaultMapper *UpstreamFaultMapper openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -384,6 +385,7 @@ func NewOpenAIGatewayService( openAITokenProvider: openAITokenProvider, toolCorrector: NewCodexToolCorrector(), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + upstreamFaultMapper: NewUpstreamFaultMapper(), responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -398,6 +400,13 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle return defaultOpenAICodexSnapshotPersistThrottle } +func (s *OpenAIGatewayService) getUpstreamFaultMapper() *UpstreamFaultMapper { + if s != nil && s.upstreamFaultMapper != nil { + return s.upstreamFaultMapper + } + return defaultUpstreamFaultMapper +} + func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ accountRepo: s.accountRepo, @@ -1044,6 +1053,7 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str // 1. Header: session_id // 2. Header: conversation_id // 3. Body: prompt_cache_key (opencode) +// 4. Body: content-based fallback (model + system + tools + first user message) func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { if c == nil { return "" @@ -1056,6 +1066,9 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) if sessionID == "" && len(body) > 0 { sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) } + if sessionID == "" && len(body) > 0 { + sessionID = deriveOpenAIContentSessionSeed(body) + } if sessionID == "" { return "" } @@ -2430,7 +2443,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + // 透传模式仅保留“请求透传”,错误出口统一走网关语义(含 failover)。 return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) } @@ -2624,46 +2637,46 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - // Passthrough mode preserves the raw upstream error response, but runtime - // account state still needs to be updated so sticky routing can stop - // reusing a freshly rate-limited account. - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Passthrough: true, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - UpstreamResponseBody: upstreamDetail, - }) - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - contentType := resp.Header.Get("Content-Type") - if contentType == "" { - contentType = "application/json" + // 统一 failover 判定:透传模式也应走账号切换,而不是原样回包上游故障。 + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, body) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + return &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, body)), + } } - c.Data(resp.StatusCode, contentType, body) - if upstreamMsg == "" { - return fmt.Errorf("upstream error: %d", resp.StatusCode) + // 非 failover 的错误统一复用网关错误映射逻辑,避免原样透传上游故障文本。 + resp.Body = io.NopCloser(bytes.NewReader(body)) + _, err := s.handleErrorResponse(ctx, resp, c, account, requestBody) + if err != nil { + return err } - return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + return fmt.Errorf("upstream error: %d", resp.StatusCode) } func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { @@ -3110,37 +3123,17 @@ func (s *OpenAIGatewayService) handleErrorResponse( } } - // Return appropriate error response - var errType, errMsg string - var statusCode int - - switch resp.StatusCode { - case 401: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream authentication failed, please contact administrator" - case 402: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream payment required: insufficient balance or billing issue" - case 403: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream access forbidden, please contact administrator" - case 429: - statusCode = http.StatusTooManyRequests - errType = "rate_limit_error" - errMsg = "Upstream rate limit exceeded, please retry later" - default: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream request failed" - } - - c.JSON(statusCode, gin.H{ + mapped := s.getUpstreamFaultMapper().Map( + resp.StatusCode, + extractUpstreamErrorCode(body), + extractUpstreamErrorType(body), + body, + UpstreamFaultFormatOpenAIResponses, + ) + c.JSON(mapped.StatusCode, gin.H{ "error": gin.H{ - "type": errType, - "message": errMsg, + "type": mapped.ErrorType, + "message": mapped.Message, }, }) @@ -3164,6 +3157,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( c *gin.Context, account *Account, writeError compatErrorWriter, + faultFormat UpstreamFaultFormat, ) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -3247,20 +3241,14 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( } } - // Map status code to error type and write response - errType := "api_error" - switch { - case resp.StatusCode == 400: - errType = "invalid_request_error" - case resp.StatusCode == 404: - errType = "not_found_error" - case resp.StatusCode == 429: - errType = "rate_limit_error" - case resp.StatusCode >= 500: - errType = "api_error" - } - - writeError(c, resp.StatusCode, errType, upstreamMsg) + mapped := s.getUpstreamFaultMapper().Map( + resp.StatusCode, + extractUpstreamErrorCode(body), + extractUpstreamErrorType(body), + body, + faultFormat, + ) + writeError(c, mapped.StatusCode, mapped.ErrorType, mapped.Message) return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 9e2f33f22a..d8f62a8f78 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -182,10 +182,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { t.Fatalf("expected different hashes for different keys") } - // 4) empty when no signals - h4 := svc.GenerateSessionHash(c, []byte(`{}`)) - if h4 != "" { - t.Fatalf("expected empty hash when no signals") + // 4) content-based fallback when no explicit signals + h4 := svc.GenerateSessionHash(c, []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"}]}`)) + if h4 == "" { + t.Fatalf("expected non-empty hash when content fallback is available") + } + if h3 == h4 { + t.Fatalf("expected content fallback hash to differ from prompt_cache_key hash") + } + + // 5) still empty when body also lacks stable content + h5 := svc.GenerateSessionHash(c, []byte(`{}`)) + if h5 != "" { + t.Fatalf("expected empty hash when no signals and no content seed") } } diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 42f58b3741..0e4c8ef863 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -101,3 +101,17 @@ func TestResolveOpenAIUpstreamModel(t *testing.T) { } } } + +func TestNormalizeOpenAIModelForUpstream_PreservesAPIKeyModels(t *testing.T) { + account := &Account{Type: AccountTypeAPIKey} + if got := normalizeOpenAIModelForUpstream(account, "gpt-5.3"); got != "gpt-5.3" { + t.Fatalf("normalizeOpenAIModelForUpstream(api key) = %q, want %q", got, "gpt-5.3") + } +} + +func TestNormalizeOpenAIModelForUpstream_NormalizesOAuthModels(t *testing.T) { + account := &Account{Type: AccountTypeOAuth} + if got := normalizeOpenAIModelForUpstream(account, "gpt-5.3"); got != "gpt-5.3-codex" { + t.Fatalf("normalizeOpenAIModelForUpstream(oauth) = %q, want %q", got, "gpt-5.3-codex") + } +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 97fa218d92..c4aa86f78e 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -528,13 +528,17 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF _, err := svc.Forward(context.Background(), c, account, originalBody) require.Error(t, err) - // should append an upstream error event with passthrough=true + // Non-failover passthrough errors now use unified gateway error mapping (no raw passthrough body). + require.Equal(t, http.StatusBadGateway, rec.Code) + require.NotContains(t, rec.Body.String(), `"message":"bad"`) + + // unified path currently records non-passthrough ops events. v, ok := c.Get(OpsUpstreamErrorsKey) require.True(t, ok) arr, ok := v.([]*OpsUpstreamErrorEvent) require.True(t, ok) require.NotEmpty(t, arr) - require.True(t, arr[len(arr)-1].Passthrough) + require.False(t, arr[len(arr)-1].Passthrough) } func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) { @@ -579,9 +583,10 @@ func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T } _, err := svc.Forward(context.Background(), c, account, originalBody) - require.Error(t, err) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - require.Contains(t, rec.Body.String(), "usage_limit_reached") + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode) + require.Empty(t, rec.Body.String(), "service layer should not passthrough upstream fault body") require.Len(t, repo.rateLimitCalls, 1) require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) } diff --git a/backend/internal/service/openai_upstream_fault_mapper_test.go b/backend/internal/service/openai_upstream_fault_mapper_test.go new file mode 100644 index 0000000000..dc35d4eaed --- /dev/null +++ b/backend/internal/service/openai_upstream_fault_mapper_test.go @@ -0,0 +1,108 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestUpstreamFaultMapper_OpenAIResponsesContract(t *testing.T) { + t.Parallel() + + mapper := NewUpstreamFaultMapper() + mapped := mapper.Map( + http.StatusUnauthorized, + "invalid_api_key", + "authentication_error", + []byte(`{"error":{"message":"SECRET"}}`), + UpstreamFaultFormatOpenAIResponses, + ) + + require.Equal(t, UpstreamFaultCodeAuth, mapped.Code) + require.Equal(t, http.StatusBadGateway, mapped.StatusCode) + require.Equal(t, "upstream_error", mapped.ErrorType) + require.Equal(t, "Upstream authentication failed, please contact administrator", mapped.Message) + require.NotContains(t, mapped.Message, "SECRET") +} + +func TestOpenAIGatewayService_ForwardAsChatCompletions_UsesMappedFault(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.2","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_chat_fault"}}, + Body: io.NopCloser(bytes.NewReader([]byte(`{"error":{"message":"SECRET_BILLING_TEXT"}}`))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + upstreamFaultMapper: NewUpstreamFaultMapper(), + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + } + + _, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "") + require.Error(t, err) + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), `"type":"invalid_request_error"`) + require.Contains(t, rec.Body.String(), `"message":"Upstream request was rejected"`) + require.NotContains(t, rec.Body.String(), "SECRET_BILLING_TEXT") +} + +func TestOpenAIGatewayService_ForwardAsAnthropic_UsesMappedFault(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.2","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusNotFound, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_msg_fault"}}, + Body: io.NopCloser(bytes.NewReader([]byte(`{"error":{"message":"SECRET_UPSTREAM_NOT_FOUND"}}`))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + upstreamFaultMapper: NewUpstreamFaultMapper(), + } + account := &Account{ + ID: 2, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + } + + _, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "") + require.Error(t, err) + require.Equal(t, http.StatusNotFound, rec.Code) + require.Contains(t, rec.Body.String(), `"type":"error"`) + require.Contains(t, rec.Body.String(), `"type":"not_found_error"`) + require.Contains(t, rec.Body.String(), `"message":"Upstream resource not found"`) + require.NotContains(t, rec.Body.String(), "SECRET_UPSTREAM_NOT_FOUND") +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 1ebe554236..83849bf35c 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) if upstreamModel != originalModel { next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { @@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + mappedModel = normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go index 4d5dc5f18b..4fa3c9273f 100644 --- a/backend/internal/service/openai_ws_protocol_resolver_test.go +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -43,7 +43,7 @@ func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) { account := *openAIOAuthEnabled account.Extra = map[string]any{ "openai_oauth_responses_websockets_v2_enabled": true, - "openai_passthrough": true, + "forward_passthrough_only": true, } decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index ffe7915262..8c8dfa64a6 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -34,6 +34,10 @@ type openAICodexExtraListRepo struct { rateLimitCh chan time.Time } +func (r *openAIWSRateLimitSignalRepo) SetError(_ context.Context, _ int64, _ string) error { + return nil +} + func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { r.rateLimitCalls = append(r.rateLimitCalls, resetAt) return nil @@ -55,6 +59,10 @@ func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64 return nil } +func (r *openAICodexSnapshotAsyncRepo) SetError(_ context.Context, _ int64, _ string) error { + return nil +} + func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { if r.updateExtraCh != nil { copied := make(map[string]any, len(updates)) @@ -73,6 +81,10 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re return nil } +func (r *openAICodexExtraListRepo) SetError(_ context.Context, _ int64, _ string) error { + return nil +} + func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { _ = platform _ = accountType @@ -159,13 +171,15 @@ func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit( } body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + before := time.Now() result, err := svc.Forward(context.Background(), c, &account, body) + after := time.Now() require.Error(t, err) require.Nil(t, result) require.Equal(t, http.StatusTooManyRequests, rec.Code) require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP") require.Len(t, repo.rateLimitCalls, 1) - require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + require.WithinDuration(t, before.Add(apiKey429Cooldown), repo.rateLimitCalls[0], after.Sub(before)+time.Second) } func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) { @@ -229,12 +243,15 @@ func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testi } body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + before := time.Now() result, err := svc.Forward(context.Background(), c, &account, body) + after := time.Now() require.Error(t, err) require.Nil(t, result) require.Equal(t, http.StatusTooManyRequests, rec.Code) require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP") require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, before.Add(apiKey429Cooldown), repo.rateLimitCalls[0], after.Sub(before)+time.Second) require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库") require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at") } @@ -335,11 +352,13 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL cancelWrite() require.NoError(t, err) + before := time.Now() select { case serverErr := <-serverErrCh: + after := time.Now() require.Error(t, serverErr) require.Len(t, repo.rateLimitCalls, 1) - require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + require.WithinDuration(t, before.Add(apiKey429Cooldown), repo.rateLimitCalls[0], after.Sub(before)+time.Second) case <-time.After(5 * time.Second): t.Fatal("等待 ingress websocket 结束超时") } diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 29f0aa8b50..cef0a0b47f 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -301,7 +301,7 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error { if out.UpstreamRequestBody != "" { // Reuse the same sanitization/trimming strategy as request body storage. // Keep it small so it is safe to persist in ops_error_logs JSON. - sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) + sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), opsMaxStoredRequestBodyBytes) if sanitizedBody != "" { out.UpstreamRequestBody = sanitizedBody if truncated { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 4f5b57cc97..efbff0b184 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -3,6 +3,7 @@ package service import ( "context" "encoding/json" + "fmt" "log/slog" "net/http" "strconv" @@ -52,6 +53,12 @@ type geminiUsageTotalsBatchProvider interface { const geminiPrecheckCacheTTL = time.Minute +const ( + apiKey429Cooldown = 60 * time.Minute + apiKey529Cooldown = 120 * time.Minute + apiKeyServerErrorCooldown = 60 * time.Minute +) + // NewRateLimitService 创建RateLimitService实例 func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { return &RateLimitService{ @@ -111,6 +118,29 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { + if account != nil && account.Type == AccountTypeAPIKey { + // Pool mode without custom error codes: upstream manages key state, skip local marking. + if account.IsPoolMode() && !account.IsCustomErrorCodesEnabled() { + slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode) + return false + } + action := ClassifyAPIKeyStatusAction(account, statusCode, responseBody) + if s == nil || s.accountRepo == nil { + return action == APIKeyStatusActionPermanentDisable || action == APIKeyStatusActionTemporaryCooldown + } + switch action { + case APIKeyStatusActionPermanentDisable: + s.handleAPIKeyPermanentDisable(ctx, account, statusCode, responseBody) + return true + case APIKeyStatusActionTemporaryCooldown: + s.handleAPIKeyTemporaryCooldown(ctx, account, statusCode, headers, responseBody) + return true + default: + // Ignore or Valid: do not fall through to OAuth-oriented logic below. + return false + } + } + customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() // 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。 @@ -147,6 +177,10 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc msg := "Organization disabled (400): " + upstreamMsg s.handleAuthError(ctx, account, msg) shouldDisable = true + } else if strings.Contains(strings.ToLower(upstreamMsg), "identity verification is required") { + msg := "Identity verification required (400): " + upstreamMsg + s.handleAuthError(ctx, account, msg) + shouldDisable = true } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: @@ -636,15 +670,81 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) return s.geminiQuotaService.CooldownForAccount(ctx, account) } -// handleAuthError 处理认证类错误(401/403),停止账号调度 +// handleAuthError 处理认证类错误(401/403),停止账号调度并关闭调度开关 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err) return } + if account.Schedulable { + if err := s.accountRepo.SetSchedulable(ctx, account.ID, false); err != nil { + slog.Warn("account_set_schedulable_false_failed", "account_id", account.ID, "error", err) + } + } slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +func buildAPIKeyRuntimeErrorMessage(statusCode int, responseBody []byte, prefix string) string { + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if upstreamMsg == "" { + upstreamMsg = http.StatusText(statusCode) + } + if strings.TrimSpace(prefix) == "" { + return upstreamMsg + } + return fmt.Sprintf("%s (%d): %s", prefix, statusCode, upstreamMsg) +} + +func (s *RateLimitService) handleAPIKeyPermanentDisable(ctx context.Context, account *Account, statusCode int, responseBody []byte) { + msg := buildAPIKeyRuntimeErrorMessage(statusCode, responseBody, "API key permanently disabled after runtime detection") + s.handleAuthError(ctx, account, msg) +} + +func (s *RateLimitService) handleAPIKeyTemporaryCooldown(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) { + switch statusCode { + case http.StatusTooManyRequests: + if account.Platform == PlatformOpenAI { + s.persistOpenAICodexSnapshot(ctx, account, headers) + } + cooldown := apiKey429Cooldown + if account.Platform == PlatformGemini { + cooldown = s.GeminiCooldown(ctx, account) + } + resetAt := time.Now().Add(cooldown) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { + slog.Warn("apikey_rate_limit_set_failed", "account_id", account.ID, "status_code", statusCode, "error", err) + return + } + slog.Info("apikey_rate_limited", "account_id", account.ID, "status_code", statusCode, "reset_at", resetAt) + case 529: + until := time.Now().Add(apiKey529Cooldown) + if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { + slog.Warn("apikey_overload_set_failed", "account_id", account.ID, "status_code", statusCode, "error", err) + return + } + slog.Info("apikey_overloaded", "account_id", account.ID, "status_code", statusCode, "until", until) + case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + until := time.Now().Add(apiKeyServerErrorCooldown) + reason := buildAPIKeyRuntimeErrorMessage(statusCode, responseBody, "API key temporary cooldown after upstream server error") + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Warn("apikey_temp_unsched_set_failed", "account_id", account.ID, "status_code", statusCode, "error", err) + return + } + slog.Info("apikey_temp_unschedulable", "account_id", account.ID, "status_code", statusCode, "until", until) + default: + // 400/402/403 等其他被分类为 TemporaryCooldown 的状态码(如模型权限不足、临时欠费等) + // 执行临时封禁,冷却后自动恢复调度。 + until := time.Now().Add(apiKeyServerErrorCooldown) + reason := buildAPIKeyRuntimeErrorMessage(statusCode, responseBody, "API key temporary cooldown") + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Warn("apikey_temp_unsched_set_failed", "account_id", account.ID, "status_code", statusCode, "error", err) + return + } + slog.Info("apikey_temp_unschedulable", "account_id", account.ID, "status_code", statusCode, "until", until) + } +} + // handle403 处理 403 Forbidden 错误 // Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; // 其他平台保持原有 SetError 行为。 @@ -652,7 +752,25 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst if account.Platform == PlatformAntigravity { return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) } - // 非 Antigravity 平台:保持原有行为 + if account.Type == AccountTypeAPIKey { + action := ClassifyAPIKeyStatusAction(account, http.StatusForbidden, responseBody) + switch action { + case APIKeyStatusActionPermanentDisable: + msg := "Access forbidden (403): account permanently disabled" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + case APIKeyStatusActionTemporaryCooldown: + s.handleAPIKeyTemporaryCooldown(ctx, account, http.StatusForbidden, nil, responseBody) + return true + default: + slog.Info("apikey_403_not_disabled", "account_id", account.ID, "platform", account.Platform) + return false + } + } + // 非 Antigravity 平台、非 APIKey:保持原有行为 msg := "Access forbidden (403): account may be suspended or lack permissions" if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 67b22e5212..526165c761 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -17,9 +17,14 @@ type rateLimitAccountRepoStub struct { mockAccountRepoForGemini setErrorCalls int tempCalls int + rateLimitedCalls int + overloadedCalls int updateCredentialsCalls int lastCredentials map[string]any lastErrorMsg string + lastTempUntil *time.Time + lastRateLimitResetAt *time.Time + lastOverloadedUntil *time.Time } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -30,6 +35,19 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ + r.lastTempUntil = &until + return nil +} + +func (r *rateLimitAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + r.rateLimitedCalls++ + r.lastRateLimitResetAt = &resetAt + return nil +} + +func (r *rateLimitAccountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + r.overloadedCalls++ + r.lastOverloadedUntil = &until return nil } diff --git a/backend/internal/service/ratelimit_service_apikey_test.go b/backend/internal/service/ratelimit_service_apikey_test.go new file mode 100644 index 0000000000..6da1d480ad --- /dev/null +++ b/backend/internal/service/ratelimit_service_apikey_test.go @@ -0,0 +1,391 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// rateLimitAccountRepoStubWithSchedulable extends rateLimitAccountRepoStub +// to track SetSchedulable calls. +type rateLimitAccountRepoStubWithSchedulable struct { + rateLimitAccountRepoStub + setSchedulableCalls int + lastSchedulable bool +} + +func (r *rateLimitAccountRepoStubWithSchedulable) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + r.setSchedulableCalls++ + r.lastSchedulable = schedulable + return nil +} + +func TestRateLimitService_HandleUpstreamError_OpenAIAPIKey403ModelAccessIgnored(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 104, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"model not allowed for this project","code":"forbidden"}}`), + ) + + require.False(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) +} + +func TestRateLimitService_HandleUpstreamError_GeminiAPIKey400InvalidDisables(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 105, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusBadRequest, + http.Header{}, + []byte(`{"error":{"message":"API key not valid. Please pass a valid API key.","status":"API_KEY_INVALID"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) +} + +func TestRateLimitService_HandleUpstreamError_APIKey429UsesTemporaryCooldown(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 106, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + } + + before := time.Now() + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusTooManyRequests, + http.Header{}, + []byte(`{"error":{"message":"rate limited"}}`), + ) + after := time.Now() + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.rateLimitedCalls) + require.Equal(t, 0, repo.overloadedCalls) + require.Equal(t, 0, repo.tempCalls) + require.NotNil(t, repo.lastRateLimitResetAt) + require.WithinDuration(t, before.Add(apiKey429Cooldown), *repo.lastRateLimitResetAt, after.Sub(before)+time.Second) +} + +func TestRateLimitService_HandleUpstreamError_APIKey529UsesTemporaryCooldown(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 108, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + before := time.Now() + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 529, + http.Header{}, + []byte(`{"error":{"message":"overloaded"}}`), + ) + after := time.Now() + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 0, repo.rateLimitedCalls) + require.Equal(t, 1, repo.overloadedCalls) + require.Equal(t, 0, repo.tempCalls) + require.NotNil(t, repo.lastOverloadedUntil) + require.WithinDuration(t, before.Add(apiKey529Cooldown), *repo.lastOverloadedUntil, after.Sub(before)+time.Second) +} + +func TestRateLimitService_HandleUpstreamError_APIKey503UsesTemporaryCooldown(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 107, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + before := time.Now() + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusServiceUnavailable, + http.Header{}, + []byte(`{"error":{"message":"service temporarily unavailable"}}`), + ) + after := time.Now() + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 0, repo.rateLimitedCalls) + require.Equal(t, 0, repo.overloadedCalls) + require.Equal(t, 1, repo.tempCalls) + require.NotNil(t, repo.lastTempUntil) + require.WithinDuration(t, before.Add(apiKeyServerErrorCooldown), *repo.lastTempUntil, after.Sub(before)+time.Second) +} + +// TestRateLimitService_HandleUpstreamError_OpenAIAPIKey402AccountNotActivePermanentDisable +// 验证 OpenAI API key 账单欠费/账号未激活(402)触发永久禁用 +func TestRateLimitService_HandleUpstreamError_OpenAIAPIKey402AccountNotActivePermanentDisable(t *testing.T) { + repo := &rateLimitAccountRepoStubWithSchedulable{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 201, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Schedulable: true, + } + + shouldDisable := svc.HandleUpstreamError( + context.Background(), + account, + http.StatusPaymentRequired, + http.Header{}, + []byte(`{"error":{"message":"Your account is not active, please check your billing details on our website.","type":"invalid_request_error","code":"account_inactive"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) +} + +// TestRateLimitService_HandleAuthError_ClosesSchedulingSwitch +// 验证 handleAuthError 在永久禁用 key 时同步关闭调度开关 +func TestRateLimitService_HandleAuthError_ClosesSchedulingSwitch(t *testing.T) { + repo := &rateLimitAccountRepoStubWithSchedulable{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Schedulable: true, + } + + svc.handleAuthError(context.Background(), account, "API key permanently disabled") + + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 1, repo.setSchedulableCalls) + require.False(t, repo.lastSchedulable) +} + +// TestRateLimitService_HandleAuthError_SkipsSchedulableIfAlreadyFalse +// 验证调度开关已关闭时不重复调用 SetSchedulable +func TestRateLimitService_HandleAuthError_SkipsSchedulableIfAlreadyFalse(t *testing.T) { + repo := &rateLimitAccountRepoStubWithSchedulable{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 203, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Schedulable: false, + } + + svc.handleAuthError(context.Background(), account, "already disabled") + + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.setSchedulableCalls) +} + +// TestClassifyAPIKeyStatusAction_OpenAIAccountNotActive +// 验证 "account is not active" 消息被正确识别为永久禁用 +func TestClassifyAPIKeyStatusAction_OpenAIAccountNotActive(t *testing.T) { + tests := []struct { + name string + statusCode int + body []byte + expected APIKeyStatusAction + }{ + { + name: "403 account is not active", + statusCode: http.StatusForbidden, + body: []byte(`{"error":{"message":"Your account is not active, please check your billing details on our website.","type":"invalid_request_error"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "400 account is not active", + statusCode: http.StatusBadRequest, + body: []byte(`{"error":{"message":"Your account is not active, please check your billing details.","type":"invalid_request_error"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "402 payment required", + statusCode: http.StatusPaymentRequired, + body: []byte(`{"error":{"message":"You exceeded your current quota","type":"insufficient_quota"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "403 billing_not_active code", + statusCode: http.StatusForbidden, + body: []byte(`{"error":{"message":"Billing not active","code":"billing_not_active"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "403 account suspended", + statusCode: http.StatusForbidden, + body: []byte(`{"error":{"message":"Your account has been suspended","type":"invalid_request_error"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "403 model access forbidden should be ignored", + statusCode: http.StatusForbidden, + body: []byte(`{"error":{"message":"model not allowed for this project","code":"forbidden"}}`), + expected: APIKeyStatusActionIgnore, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + ID: 999, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + result := ClassifyAPIKeyStatusAction(account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// TestClassifyAPIKeyStatusAction_OpenAI429InsufficientQuota +// 验证 OpenAI 429 insufficient_quota 被识别为永久禁用(余额耗尽)而非临时限速 +func TestClassifyAPIKeyStatusAction_OpenAI429InsufficientQuota(t *testing.T) { + account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + tests := []struct { + name string + body []byte + expected APIKeyStatusAction + }{ + { + name: "insufficient_quota code", + body: []byte(`{"error":{"message":"You exceeded your current quota, please check your plan and billing details.","type":"insufficient_quota","code":"insufficient_quota"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "regular rate limit should be cooldown", + body: []byte(`{"error":{"message":"Rate limit reached for model","type":"requests","code":"rate_limit_exceeded"}}`), + expected: APIKeyStatusActionTemporaryCooldown, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifyAPIKeyStatusAction(account, http.StatusTooManyRequests, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// TestClassifyAPIKeyStatusAction_AnthropicCreditBalance +// 验证 Anthropic 400 余额不足被正确识别为永久禁用 +func TestClassifyAPIKeyStatusAction_AnthropicCreditBalance(t *testing.T) { + account := &Account{ID: 2, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + tests := []struct { + name string + body []byte + expected APIKeyStatusAction + }{ + { + name: "credit balance too low", + body: []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"Your credit balance is too low to access the Anthropic API. Please go to Plans \u0026 Billing to upgrade or purchase credits."}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "401 invalid key", + body: []byte(`{"type":"error","error":{"type":"authentication_error","message":"Invalid API key"}}`), + expected: APIKeyStatusActionPermanentDisable, + }, + { + name: "400 unrelated bad request should be ignored", + body: []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"max_tokens is required"}}`), + expected: APIKeyStatusActionIgnore, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + statusCode := http.StatusBadRequest + if tt.name == "401 invalid key" { + statusCode = http.StatusUnauthorized + } + result := ClassifyAPIKeyStatusAction(account, statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// TestClassifyAPIKeyStatusAction_GeminiBillingDisabled +// 验证 Gemini 403 BILLING_DISABLED / CONSUMER_SUSPENDED 被识别为永久禁用 +func TestClassifyAPIKeyStatusAction_GeminiBillingDisabled(t *testing.T) { + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + billingDisabledBody := []byte(`{ + "error": { + "code": 403, + "message": "Billing is disabled for this project.", + "status": "PERMISSION_DENIED", + "details": [{"@type": "type.googleapis.com/google.rpc.ErrorInfo","reason": "BILLING_DISABLED","domain": "googleapis.com"}] + } + }`) + + consumerSuspendedBody := []byte(`{ + "error": { + "code": 403, + "message": "The caller does not have permission", + "status": "PERMISSION_DENIED", + "details": [{"@type": "type.googleapis.com/google.rpc.ErrorInfo","reason": "CONSUMER_SUSPENDED","domain": "googleapis.com"}] + } + }`) + + failedPreconditionBody := []byte(`{ + "error": { + "code": 400, + "message": "Gemini API free tier is not available in your country. Please enable billing on your project in Google AI Studio.", + "status": "FAILED_PRECONDITION" + } + }`) + + tests := []struct { + name string + statusCode int + body []byte + expected APIKeyStatusAction + }{ + {"403 BILLING_DISABLED", http.StatusForbidden, billingDisabledBody, APIKeyStatusActionPermanentDisable}, + {"403 CONSUMER_SUSPENDED", http.StatusForbidden, consumerSuspendedBody, APIKeyStatusActionPermanentDisable}, + {"400 FAILED_PRECONDITION free tier", http.StatusBadRequest, failedPreconditionBody, APIKeyStatusActionPermanentDisable}, + {"429 rate limit is cooldown", http.StatusTooManyRequests, []byte(`{"error":{"code":429,"message":"Resource exhausted","status":"RESOURCE_EXHAUSTED"}}`), APIKeyStatusActionTemporaryCooldown}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifyAPIKeyStatusAction(account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index 7796a85e76..010b939d63 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -73,6 +73,9 @@ func (m *sessionWindowMockRepo) GetByCRSAccountID(context.Context, string) (*Acc func (m *sessionWindowMockRepo) FindByExtraField(context.Context, string, any) ([]Account, error) { panic("unexpected") } +func (m *sessionWindowMockRepo) FindByAPIKey(context.Context, string, string, string) (*Account, error) { + panic("unexpected") +} func (m *sessionWindowMockRepo) ListCRSAccountIDs(context.Context) (map[string]int64, error) { panic("unexpected") } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 4f3f2e8f95..1983d8c583 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -303,6 +303,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc "error", setErr, ) } + if account.Schedulable { + if setErr := s.accountRepo.SetSchedulable(ctx, account.ID, false); setErr != nil { + slog.Error("token_refresh.set_schedulable_false_failed", + "account_id", account.ID, + "error", setErr, + ) + } + } // 刷新失败但 access_token 可能仍有效,尝试设置隐私 s.ensureOpenAIPrivacy(ctx, account) s.ensureAntigravityPrivacy(ctx, account) diff --git a/backend/internal/service/upstream_fault_mapper.go b/backend/internal/service/upstream_fault_mapper.go new file mode 100644 index 0000000000..78806b51b2 --- /dev/null +++ b/backend/internal/service/upstream_fault_mapper.go @@ -0,0 +1,189 @@ +package service + +import ( + "net/http" + "strings" +) + +// UpstreamFaultFormat identifies which downstream API contract to render. +type UpstreamFaultFormat string + +const ( + UpstreamFaultFormatOpenAIResponses UpstreamFaultFormat = "openai_responses" + UpstreamFaultFormatOpenAIChatCompletions UpstreamFaultFormat = "openai_chat_completions" + UpstreamFaultFormatAnthropicMessages UpstreamFaultFormat = "anthropic_messages" +) + +// UpstreamFaultCode is a provider-neutral internal fault code. +type UpstreamFaultCode string + +const ( + UpstreamFaultCodeUnknown UpstreamFaultCode = "upstream_unknown_error" + UpstreamFaultCodeAuth UpstreamFaultCode = "upstream_auth_error" + UpstreamFaultCodePayment UpstreamFaultCode = "upstream_payment_error" + UpstreamFaultCodeForbidden UpstreamFaultCode = "upstream_forbidden_error" + UpstreamFaultCodeRateLimited UpstreamFaultCode = "upstream_rate_limited" + UpstreamFaultCodeInvalid UpstreamFaultCode = "upstream_invalid_request" + UpstreamFaultCodeNotFound UpstreamFaultCode = "upstream_not_found" +) + +// UpstreamMappedFault is the normalized output contract used by handlers. +type UpstreamMappedFault struct { + Code UpstreamFaultCode + StatusCode int + ErrorType string + Message string +} + +// UpstreamFaultMapper maps upstream status/code/type/body into unified client-facing errors. +// It must never expose raw upstream message text to end users. +type UpstreamFaultMapper struct{} + +var defaultUpstreamFaultMapper = &UpstreamFaultMapper{} + +func NewUpstreamFaultMapper() *UpstreamFaultMapper { + return &UpstreamFaultMapper{} +} + +func (m *UpstreamFaultMapper) Map( + statusCode int, + upstreamCode string, + upstreamType string, + responseBody []byte, + format UpstreamFaultFormat, +) UpstreamMappedFault { + normalizedCode := strings.ToLower(strings.TrimSpace(upstreamCode)) + normalizedType := strings.ToLower(strings.TrimSpace(upstreamType)) + normalizedBody := strings.ToLower(string(responseBody)) + quotaExhausted := statusCode == http.StatusTooManyRequests && + (normalizedCode == "insufficient_quota" || + normalizedType == "insufficient_quota" || + strings.Contains(normalizedBody, "insufficient_quota")) + + switch format { + case UpstreamFaultFormatOpenAIResponses: + return mapOpenAIResponsesFault(statusCode, quotaExhausted) + case UpstreamFaultFormatAnthropicMessages: + return mapAnthropicMessagesFault(statusCode, quotaExhausted) + default: + return mapOpenAIChatCompletionsFault(statusCode, quotaExhausted) + } +} + +func mapOpenAIResponsesFault(statusCode int, quotaExhausted bool) UpstreamMappedFault { + switch statusCode { + case http.StatusUnauthorized: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeAuth, + StatusCode: http.StatusBadGateway, + ErrorType: "upstream_error", + Message: "Upstream authentication failed, please contact administrator", + } + case http.StatusPaymentRequired: + return UpstreamMappedFault{ + Code: UpstreamFaultCodePayment, + StatusCode: http.StatusBadGateway, + ErrorType: "upstream_error", + Message: "Upstream payment required: insufficient balance or billing issue", + } + case http.StatusForbidden: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeForbidden, + StatusCode: http.StatusBadGateway, + ErrorType: "upstream_error", + Message: "Upstream access forbidden, please contact administrator", + } + case http.StatusTooManyRequests: + internalCode := UpstreamFaultCodeRateLimited + if quotaExhausted { + internalCode = UpstreamFaultCodePayment + } + return UpstreamMappedFault{ + Code: internalCode, + StatusCode: http.StatusTooManyRequests, + ErrorType: "rate_limit_error", + Message: "Upstream rate limit exceeded, please retry later", + } + default: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeUnknown, + StatusCode: http.StatusBadGateway, + ErrorType: "upstream_error", + Message: "Upstream request failed", + } + } +} + +func mapOpenAIChatCompletionsFault(statusCode int, quotaExhausted bool) UpstreamMappedFault { + switch statusCode { + case http.StatusBadRequest: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeInvalid, + StatusCode: http.StatusBadRequest, + ErrorType: "invalid_request_error", + Message: "Upstream request was rejected", + } + case http.StatusNotFound: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeNotFound, + StatusCode: http.StatusNotFound, + ErrorType: "not_found_error", + Message: "Upstream resource not found", + } + case http.StatusTooManyRequests: + internalCode := UpstreamFaultCodeRateLimited + if quotaExhausted { + internalCode = UpstreamFaultCodePayment + } + return UpstreamMappedFault{ + Code: internalCode, + StatusCode: http.StatusTooManyRequests, + ErrorType: "rate_limit_error", + Message: "Upstream rate limit exceeded, please retry later", + } + default: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeUnknown, + StatusCode: statusCode, + ErrorType: "api_error", + Message: "Upstream request failed", + } + } +} + +func mapAnthropicMessagesFault(statusCode int, quotaExhausted bool) UpstreamMappedFault { + switch statusCode { + case http.StatusBadRequest: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeInvalid, + StatusCode: http.StatusBadRequest, + ErrorType: "invalid_request_error", + Message: "Upstream request was rejected", + } + case http.StatusNotFound: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeNotFound, + StatusCode: http.StatusNotFound, + ErrorType: "not_found_error", + Message: "Upstream resource not found", + } + case http.StatusTooManyRequests: + internalCode := UpstreamFaultCodeRateLimited + if quotaExhausted { + internalCode = UpstreamFaultCodePayment + } + return UpstreamMappedFault{ + Code: internalCode, + StatusCode: http.StatusTooManyRequests, + ErrorType: "rate_limit_error", + Message: "Upstream rate limit exceeded, please retry later", + } + default: + return UpstreamMappedFault{ + Code: UpstreamFaultCodeUnknown, + StatusCode: statusCode, + ErrorType: "api_error", + Message: "Upstream request failed", + } + } +} diff --git a/backend/migrations/082_add_apikey_unique_constraint_notx.sql b/backend/migrations/082_add_apikey_unique_constraint_notx.sql new file mode 100644 index 0000000000..2becffca39 --- /dev/null +++ b/backend/migrations/082_add_apikey_unique_constraint_notx.sql @@ -0,0 +1,21 @@ +-- 082_add_apikey_unique_constraint.sql +-- Add a partial unique index on accounts to prevent duplicate API key entries. +-- +-- Uniqueness is defined as: (platform, api_key, normalised base_url) must be +-- unique among non-deleted apikey-type accounts. +-- +-- base_url is normalised by stripping the trailing slash so that +-- "https://api.openai.com" and "https://api.openai.com/" are treated as equal. +-- Empty base_url is stored as '' and compared as ''. +-- +-- The index is CONCURRENT (no-transaction) to avoid locking a busy table. +-- It uses IF NOT EXISTS for idempotency. + +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS accounts_apikey_unique_active + ON accounts ( + platform, + (credentials->>'api_key'), + (COALESCE(NULLIF(TRIM(TRAILING '/' FROM credentials->>'base_url'), ''), '')) + ) + WHERE type = 'apikey' + AND deleted_at IS NULL; diff --git a/backend/migrations/083_disable_error_passthrough_body.sql b/backend/migrations/083_disable_error_passthrough_body.sql new file mode 100644 index 0000000000..db136d1181 --- /dev/null +++ b/backend/migrations/083_disable_error_passthrough_body.sql @@ -0,0 +1,9 @@ +-- P2: disable error_passthrough_rules.passthrough_body capability. +-- Runtime no longer allows forwarding upstream raw fault text to clients. + +ALTER TABLE error_passthrough_rules +ALTER COLUMN passthrough_body SET DEFAULT false; + +UPDATE error_passthrough_rules +SET passthrough_body = false +WHERE passthrough_body = true; diff --git a/backend/migrations/084_migrate_openai_passthrough_key.sql b/backend/migrations/084_migrate_openai_passthrough_key.sql new file mode 100644 index 0000000000..9b410139f0 --- /dev/null +++ b/backend/migrations/084_migrate_openai_passthrough_key.sql @@ -0,0 +1,27 @@ +-- P2: migrate OpenAI account extra key from openai_passthrough/openai_oauth_passthrough +-- to forward_passthrough_only. + +UPDATE accounts +SET extra = CASE + WHEN extra ? 'forward_passthrough_only' THEN + extra - 'openai_passthrough' - 'openai_oauth_passthrough' + WHEN extra ? 'openai_passthrough' THEN + jsonb_set( + extra - 'openai_passthrough' - 'openai_oauth_passthrough', + '{forward_passthrough_only}', + extra->'openai_passthrough', + true + ) + WHEN extra ? 'openai_oauth_passthrough' THEN + jsonb_set( + extra - 'openai_oauth_passthrough', + '{forward_passthrough_only}', + extra->'openai_oauth_passthrough', + true + ) + ELSE + extra +END +WHERE platform = 'openai' + AND extra IS NOT NULL + AND (extra ? 'openai_passthrough' OR extra ? 'openai_oauth_passthrough'); diff --git a/deploy/DOCKER.md b/deploy/DOCKER.md index 156b6c97a9..3794f6dc23 100644 --- a/deploy/DOCKER.md +++ b/deploy/DOCKER.md @@ -10,7 +10,7 @@ docker run -d \ -p 8080:8080 \ -e DATABASE_URL="postgres://user:pass@host:5432/sub2api" \ -e REDIS_URL="redis://host:6379" \ - weishaw/sub2api:latest + ghcr.io/kw0ngr/sub2api:latest ``` ## Docker Compose @@ -20,7 +20,7 @@ version: '3.8' services: sub2api: - image: weishaw/sub2api:latest + image: ghcr.io/kw0ngr/sub2api:latest ports: - "8080:8080" environment: @@ -72,5 +72,5 @@ volumes: ## Links -- [GitHub Repository](https://github.com/weishaw/sub2api) -- [Documentation](https://github.com/weishaw/sub2api#readme) +- [GitHub Repository](https://github.com/kw0ngr/sub2api) +- [Documentation](https://github.com/kw0ngr/sub2api#readme) diff --git a/deploy/Dockerfile b/deploy/Dockerfile index 7caa5ca63e..68b51d1638 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26.1-alpine +ARG GOLANG_IMAGE=golang:1.26.2-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -73,9 +73,9 @@ RUN CGO_ENABLED=0 GOOS=linux go build \ FROM ${ALPINE_IMAGE} # Labels -LABEL maintainer="Wei-Shaw " +LABEL maintainer="kw0ngr " LABEL description="Sub2API - AI API Gateway Platform" -LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" +LABEL org.opencontainers.image.source="https://github.com/kw0ngr/sub2api" # Install runtime dependencies RUN apk add --no-cache \ diff --git a/deploy/README.md b/deploy/README.md index dd311721d9..cdf5248203 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -35,10 +35,10 @@ Use the automated preparation script for the easiest setup: ```bash # Download and run the preparation script -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/docker-deploy.sh | bash # Or download first, then run -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh -o docker-deploy.sh +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/docker-deploy.sh -o docker-deploy.sh chmod +x docker-deploy.sh ./docker-deploy.sh ``` @@ -71,7 +71,7 @@ If you prefer manual control: ```bash # Clone repository -git clone https://github.com/Wei-Shaw/sub2api.git +git clone https://github.com/kw0ngr/sub2api.git cd sub2api/deploy # Configure environment @@ -353,12 +353,12 @@ For production servers using systemd. ### One-Line Installation ```bash -curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash +curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | sudo bash ``` ### Manual Installation -1. Download the latest release from [GitHub Releases](https://github.com/Wei-Shaw/sub2api/releases) +1. Download the latest release from [GitHub Releases](https://github.com/kw0ngr/sub2api/releases) 2. Extract and copy the binary to `/opt/sub2api/` 3. Copy `sub2api.service` to `/etc/systemd/system/` 4. Run: diff --git a/deploy/docker-deploy.sh b/deploy/docker-deploy.sh index a07f4f417a..9048e4c96c 100644 --- a/deploy/docker-deploy.sh +++ b/deploy/docker-deploy.sh @@ -21,7 +21,7 @@ BLUE='\033[0;34m' NC='\033[0m' # No Color # GitHub raw content base URL -GITHUB_RAW_URL="https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy" +GITHUB_RAW_URL="https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy" # Print colored message print_info() { diff --git a/deploy/install.sh b/deploy/install.sh index 6dcf41238e..2facc73090 100644 --- a/deploy/install.sh +++ b/deploy/install.sh @@ -2,7 +2,7 @@ # # Sub2API Installation Script # Sub2API 安装脚本 -# Usage: curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | bash +# Usage: curl -sSL https://raw.githubusercontent.com/kw0ngr/sub2api/main/deploy/install.sh | bash # set -e @@ -16,7 +16,7 @@ CYAN='\033[0;36m' NC='\033[0m' # No Color # Configuration -GITHUB_REPO="Wei-Shaw/sub2api" +GITHUB_REPO="kw0ngr/sub2api" INSTALL_DIR="/opt/sub2api" SERVICE_NAME="sub2api" SERVICE_USER="sub2api" @@ -112,6 +112,9 @@ declare -A MSG_ZH=( ["fetching_versions"]="正在获取可用版本..." ["not_installed"]="Sub2API 尚未安装,请先执行全新安装" ["fresh_install_hint"]="用法" + ["fork_no_releases"]="Fork 仓库尚未发布任何 GitHub Releases" + ["fork_only_release_source"]="当前安装脚本仅从该 Fork 仓库下载发布包" + ["fork_publish_release_hint"]="请先在 GitHub 上创建如 vX.Y.Z 的 tag,并等待 release workflow 发布资产后重试" # Uninstall ["uninstall_confirm"]="这将从系统中移除 Sub2API。" @@ -237,6 +240,9 @@ declare -A MSG_EN=( ["fetching_versions"]="Fetching available versions..." ["not_installed"]="Sub2API is not installed. Please run a fresh install first" ["fresh_install_hint"]="Usage" + ["fork_no_releases"]="No GitHub Releases have been published for this fork yet" + ["fork_only_release_source"]="This installer only downloads release artifacts from this fork" + ["fork_publish_release_hint"]="Create a tag such as vX.Y.Z on GitHub and wait for the release workflow to publish assets before retrying" # Uninstall ["uninstall_confirm"]="This will remove Sub2API from your system." @@ -465,9 +471,37 @@ check_dependencies() { fi } +get_releases_json() { + curl -s --connect-timeout 10 --max-time 30 "https://api.github.com/repos/${GITHUB_REPO}/releases?per_page=20" 2>/dev/null +} + +print_no_releases_guidance() { + print_error "$(msg 'fork_no_releases'): ${GITHUB_REPO}" + print_info "$(msg 'fork_only_release_source')" + print_info "$(msg 'fork_publish_release_hint')" +} + +ensure_releases_available() { + local releases_json="$1" + + if echo "$releases_json" | grep -q '"message"[[:space:]]*:[[:space:]]*"Not Found"'; then + print_no_releases_guidance + exit 1 + fi + + if ! echo "$releases_json" | grep -q '"tag_name"'; then + print_no_releases_guidance + exit 1 + fi +} + # Get latest release version get_latest_version() { print_info "$(msg 'fetching_version')" + local releases_json + releases_json=$(get_releases_json) + ensure_releases_available "$releases_json" + LATEST_VERSION=$(curl -s --connect-timeout 10 --max-time 30 "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | grep '"tag_name"' | sed -E 's/.*"([^"]+)".*/\1/') if [ -z "$LATEST_VERSION" ]; then @@ -484,7 +518,10 @@ list_versions() { print_info "$(msg 'fetching_versions')" local versions - versions=$(curl -s --connect-timeout 10 --max-time 30 "https://api.github.com/repos/${GITHUB_REPO}/releases" 2>/dev/null | grep '"tag_name"' | sed -E 's/.*"([^"]+)".*/\1/' | head -20) + local releases_json + releases_json=$(get_releases_json) + ensure_releases_available "$releases_json" + versions=$(echo "$releases_json" | grep '"tag_name"' | sed -E 's/.*"([^"]+)".*/\1/' | head -20) if [ -z "$versions" ]; then print_error "$(msg 'failed_get_version')" @@ -517,6 +554,17 @@ validate_version() { version="v$version" fi + local releases_json + releases_json=$(get_releases_json) + if echo "$releases_json" | grep -q '"message"[[:space:]]*:[[:space:]]*"Not Found"'; then + print_no_releases_guidance >&2 + exit 1 + fi + if ! echo "$releases_json" | grep -q '"tag_name"'; then + print_no_releases_guidance >&2 + exit 1 + fi + print_info "$(msg 'validating_version') $version" >&2 # Check if the release exists @@ -655,7 +703,7 @@ install_service() { cat > /etc/systemd/system/sub2api.service << EOF [Unit] Description=Sub2API - AI API Gateway Platform -Documentation=https://github.com/Wei-Shaw/sub2api +Documentation=https://github.com/kw0ngr/sub2api After=network.target postgresql.service redis.service Wants=postgresql.service redis.service diff --git a/deploy/sub2api.service b/deploy/sub2api.service index 1a59ad032c..2444f77828 100644 --- a/deploy/sub2api.service +++ b/deploy/sub2api.service @@ -1,6 +1,6 @@ [Unit] Description=Sub2API - AI API Gateway Platform -Documentation=https://github.com/Wei-Shaw/sub2api +Documentation=https://github.com/kw0ngr/sub2api After=network.target postgresql.service redis.service Wants=postgresql.service redis.service diff --git a/frontend/package.json b/frontend/package.json index d2a6deded7..49868cd903 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,7 +18,7 @@ "@lobehub/icons": "^4.0.2", "@tanstack/vue-virtual": "^3.13.23", "@vueuse/core": "^10.7.0", - "axios": "^1.13.5", + "axios": "^1.15.0", "chart.js": "^4.4.1", "dompurify": "^3.3.1", "driver.js": "^1.4.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 505b72f388..fe147c7f7e 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -18,8 +18,8 @@ importers: specifier: ^10.7.0 version: 10.11.1(vue@3.5.26(typescript@5.6.3)) axios: - specifier: ^1.13.5 - version: 1.13.5 + specifier: ^1.15.0 + version: 1.15.0 chart.js: specifier: ^4.4.1 version: 4.5.1 @@ -1260,67 +1260,56 @@ packages: resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} cpu: [arm] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.54.0': resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} cpu: [arm] os: [linux] - libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.54.0': resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} cpu: [arm64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.54.0': resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} cpu: [arm64] os: [linux] - libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.54.0': resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} cpu: [loong64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-ppc64-gnu@4.54.0': resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} cpu: [ppc64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-gnu@4.54.0': resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} cpu: [riscv64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.54.0': resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} cpu: [riscv64] os: [linux] - libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.54.0': resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} cpu: [s390x] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.54.0': resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} cpu: [x64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-musl@4.54.0': resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} cpu: [x64] os: [linux] - libc: [musl] '@rollup/rollup-openharmony-arm64@4.54.0': resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} @@ -1827,8 +1816,8 @@ packages: peerDependencies: postcss: ^8.1.0 - axios@1.13.5: - resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==} + axios@1.15.0: + resolution: {integrity: sha512-wWyJDlAatxk30ZJer+GeCWS209sA42X+N5jU2jy6oHTp7ufw8uzUTVFBX9+wTfAlhiJXGS0Bq7X6efruWjuK9Q==} babel-plugin-macros@3.1.0: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} @@ -3503,8 +3492,9 @@ packages: proto-list@1.2.4: resolution: {integrity: sha512-vtK/94akxsTMhe0/cbfpR+syPuszcuwhqVjJq26CuNDgFGj682oRBXOP5MJpv2r7JtE8MsiepGIqvvOTBwn2vA==} - proxy-from-env@1.1.0: - resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} + proxy-from-env@2.1.0: + resolution: {integrity: sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==} + engines: {node: '>=10'} psl@1.15.0: resolution: {integrity: sha512-JZd3gMVBAVQkSs6HdNZo9Sdo0LNcQeMNP3CozBJb3JYC/QUYZTnKxP+f8oWRX4rHP5EurWxqAHTSwUCjlNKa1w==} @@ -6416,11 +6406,11 @@ snapshots: postcss: 8.5.6 postcss-value-parser: 4.2.0 - axios@1.13.5: + axios@1.15.0: dependencies: follow-redirects: 1.15.11 form-data: 4.0.5 - proxy-from-env: 1.1.0 + proxy-from-env: 2.1.0 transitivePeerDependencies: - debug @@ -8530,7 +8520,7 @@ snapshots: proto-list@1.2.4: {} - proxy-from-env@1.1.0: {} + proxy-from-env@2.1.0: {} psl@1.15.0: dependencies: diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index fd93fe7ef0..034e92fca1 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -402,6 +402,50 @@ export interface BatchTodayStatsResponse { stats: Record } +export interface RawAPIKeyImportLineResult { + line: number + key_preview?: string + platform?: string + account_id?: number + created: boolean + checked: boolean + valid: boolean + invalid_disabled: boolean + error?: string + message?: string + status_code?: number +} + +export interface RawAPIKeyImportResult { + total_lines: number + created: number + checked: number + valid: number + invalid_disabled: number + failed: number + results: RawAPIKeyImportLineResult[] +} + +export interface APIKeyHealthCheckItem { + account_id: number + name: string + platform: string + status_code?: number + valid: boolean + invalid_disabled: boolean + error?: string + message?: string +} + +export interface APIKeyHealthCheckResult { + total: number + checked: number + valid: number + invalid_disabled: number + failed: number + results: APIKeyHealthCheckItem[] +} + /** * 批量获取多个账号的今日统计 * @param accountIds - 账号 ID 列表 @@ -532,6 +576,25 @@ export async function importData(payload: { return data } +export async function importRawAPIKeys(payload: { + raw_text: string + validate_after_import?: boolean + skip_default_group_bind?: boolean +}): Promise { + const { data } = await apiClient.post('/admin/accounts/raw-import', payload, { + timeout: 120000 + }) + return data +} + +export async function checkAPIKeysHealth(accountIds?: number[]): Promise { + const payload = accountIds && accountIds.length > 0 ? { account_ids: accountIds } : {} + const { data } = await apiClient.post('/admin/accounts/apikey-health-check', payload, { + timeout: 120000 + }) + return data +} + /** * Get Antigravity default model mapping from backend * @returns Default model mapping (from -> to) @@ -671,6 +734,8 @@ export const accountsAPI = { syncFromCrs, exportData, importData, + importRawAPIKeys, + checkAPIKeysHealth, getAntigravityDefaultModelMapping, batchClearError, batchRefresh, diff --git a/frontend/src/components/admin/account/RawKeyImportModal.vue b/frontend/src/components/admin/account/RawKeyImportModal.vue new file mode 100644 index 0000000000..5bb4dfbd87 --- /dev/null +++ b/frontend/src/components/admin/account/RawKeyImportModal.vue @@ -0,0 +1,172 @@ +