diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e588f..1559290566 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -52,6 +52,11 @@ const ( ConnectionPoolIsolationAccountProxy = "account_proxy" ) +// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。 +// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。 +// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 +const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024 + type Config struct { Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` @@ -1407,7 +1412,7 @@ func setDefaults() { viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1ddb8ae2dc..854dca548b 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -206,6 +206,10 @@ type CreateOrderRequest struct { PaymentType string `json:"payment_type" binding:"required"` OrderType string `json:"order_type"` PlanID int64 `json:"plan_id"` + // IsMobile lets the frontend declare its mobile status directly. When + // nil we fall back to User-Agent heuristics (which miss iPadOS / some + // embedded browsers that strip the "Mobile" keyword). + IsMobile *bool `json:"is_mobile,omitempty"` } // CreateOrder creates a new payment order. @@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { return } + mobile := isMobile(c) + if req.IsMobile != nil { + mobile = *req.IsMobile + } result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ UserID: subject.UserID, Amount: req.Amount, PaymentType: req.PaymentType, ClientIP: c.ClientIP(), - IsMobile: isMobile(c), + IsMobile: mobile, SrcHost: c.Request.Host, SrcURL: c.Request.Referer(), OrderType: req.OrderType, diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index e39e957f4d..0581469d65 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -10,12 +10,20 @@ import ( "strings" ) +// AES256KeySize is the required key length (in bytes) for AES-256-GCM. +const AES256KeySize = 32 + // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function is kept only for seeding legacy ciphertext in tests and for +// the transitional Decrypt fallback. Scheduled for removal after all live +// deployments complete migration by re-saving their configs. func Encrypt(plaintext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } block, err := aes.NewCipher(key) @@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function remains only as a read-path fallback for pre-migration +// ciphertext records. Scheduled for removal once all deployments re-save +// their provider configs through the admin UI. func Decrypt(ciphertext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } parts := strings.SplitN(ciphertext, ":", 3) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index f035317384..ec244cd676 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns if err != nil { return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err) } + if config == nil { + config = map[string]string{} + } if selected.PaymentMode != "" { config["paymentMode"] = selected.PaymentMode @@ -275,16 +278,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns }, nil } -func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) { - plaintext, err := Decrypt(encrypted, lb.encryptionKey) - if err != nil { - return nil, err +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext. +// Unreadable values (legacy ciphertext without a valid key, or malformed data) +// are treated as empty so the service keeps running while the admin re-enters +// the config via the UI. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a +// transitional compatibility shim for pre-plaintext records. Remove it (and +// the encryptionKey field + the Decrypt import) after a few releases once all +// live deployments have re-saved their provider configs through the UI. +func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { + return nil, nil } var config map[string]string - if err := json.Unmarshal([]byte(plaintext), &config); err != nil { - return nil, fmt.Errorf("unmarshal config: %w", err) + if err := json.Unmarshal([]byte(stored), &config); err == nil { + return config, nil + } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. + if len(lb.encryptionKey) == AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal + if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &config); err == nil { + return config, nil + } + } } - return config, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } // GetInstanceDailyAmount returns the total completed order amount for an instance today. diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 04b3c25b47..2bf4f6ac41 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) { } } +func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) { + t.Parallel() + + key := make([]byte, AES256KeySize) + for i := range key { + key[i] = byte(i + 1) + } + wrongKey := make([]byte, AES256KeySize) + for i := range wrongKey { + wrongKey[i] = byte(0xFF - i) + } + + plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}` + + legacyEncrypted, err := Encrypt(plaintextJSON, key) + if err != nil { + t.Fatalf("seed Encrypt: %v", err) + } + + tests := []struct { + name string + stored string + key []byte + want map[string]string + }{ + { + name: "empty stored returns nil map", + stored: "", + key: key, + want: nil, + }, + { + name: "plaintext JSON parses directly", + stored: plaintextJSON, + key: nil, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "plaintext JSON works even with key present", + stored: plaintextJSON, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with correct key decrypts", + stored: legacyEncrypted, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with no key treated as empty", + stored: legacyEncrypted, + key: nil, + want: nil, + }, + { + name: "legacy ciphertext with wrong key treated as empty", + stored: legacyEncrypted, + key: wrongKey, + want: nil, + }, + { + name: "garbage data treated as empty", + stored: "not-json-and-not-ciphertext", + key: key, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + lb := NewDefaultLoadBalancer(nil, tt.key) + got, err := lb.decryptConfig(tt.stored) + if err != nil { + t.Fatalf("decryptConfig unexpected error: %v", err) + } + if !stringMapEqual(got, tt.want) { + t.Fatalf("decryptConfig = %v, want %v", got, tt.want) + } + }) + } +} + +// stringMapEqual compares two map[string]string values; nil and empty are equal. +func stringMapEqual(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if bv, ok := b[k]; !ok || bv != v { + return false + } + } + return true +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index af8a90c681..fe8ea89c0a 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -15,8 +15,8 @@ import ( // Alipay product codes. const ( - alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" ) // Alipay response constants. @@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } -// CreatePayment creates an Alipay payment page URL. +// CreatePayment creates an Alipay payment using redirect-only flow: +// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to. +// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a +// new window; Alipay's own page then shows login/QR. We intentionally do +// NOT encode the URL into a QR on the client (it isn't a scannable payload +// and would produce an invalid scan result). func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { @@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque } if req.IsMobile { - return a.createTrade(client, req, notifyURL, returnURL, true) + return a.createWapTrade(client, req, notifyURL, returnURL) } - return a.createTrade(client, req, notifyURL, returnURL, false) + return a.createPagePayTrade(client, req, notifyURL, returnURL) } -func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) { - if isMobile { - param := alipay.TradeWapPay{} - param.OutTradeNo = req.OrderID - param.TotalAmount = req.Amount - param.Subject = req.Subject - param.ProductCode = alipayProductCodeWapPay - param.NotifyURL = notifyURL - param.ReturnURL = returnURL - - payURL, err := client.TradeWapPay(param) - if err != nil { - return nil, fmt.Errorf("alipay TradeWapPay: %w", err) - } - return &payment.CreatePaymentResponse{ - TradeNo: req.OrderID, - PayURL: payURL.String(), - }, nil +func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradeWapPay{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodeWapPay + param.NotifyURL = notifyURL + param.ReturnURL = returnURL + + payURL, err := client.TradeWapPay(param) + if err != nil { + return nil, fmt.Errorf("alipay TradeWapPay: %w", err) } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + PayURL: payURL.String(), + }, nil +} +func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { param := alipay.TradePagePay{} param.OutTradeNo = req.OrderID param.TotalAmount = req.Amount @@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq return &payment.CreatePaymentResponse{ TradeNo: req.OrderID, PayURL: payURL.String(), - QRCode: payURL.String(), }, nil } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c62..701f3659ee 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + platform := input.Platform if platform == "" { platform = PlatformAnthropic @@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.Platform = input.Platform } if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { @@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro if s.userGroupRateRepo == nil { return nil } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index a4c6d0caba..41d2c26a0f 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformOpenAI, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeSubscription, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAntigravity, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &zero, }) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 32a54cbedf..c9f32b3b0c 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, }) } - if input.RateMultiplier <= 0 { - input.RateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。 + if input.RateMultiplier < 0 { + input.RateMultiplier = 0 } var breakdown *CostBreakdown @@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown( rateMultiplier float64, serviceTier string, applyLongCtx bool, ) *CostBreakdown { - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。 + if rateMultiplier < 0 { + rateMultiplier = 0 } inputPrice := pricing.InputPricePerToken @@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag // 计算总费用 totalCost := unitPrice * float64(imageCount) - // 应用倍率 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣) + if rateMultiplier < 0 { + rateMultiplier = 0 } actualCost := totalCost * rateMultiplier diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index fa90f6bba5..8d3ca98778 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { require.Equal(t, 0.0, cost.ActualCost) } -// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 +// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费 +// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。 func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) require.InDelta(t, 0.201, cost.TotalCost, 0.0001) - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go new file mode 100644 index 0000000000..8378819611 --- /dev/null +++ b/backend/internal/service/billing_service_rate_multiplier_test.go @@ -0,0 +1,63 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被 +// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。 +func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 // ActualCost / TotalCost + }{ + {"negative clamped to 0", -1.5, 0}, + {"zero passes through as 0 (defense in depth)", 0, 0}, + {"positive 2x applied", 2.0, 2.0}, + {"positive 0.5x applied", 0.5, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier) + require.NoError(t, err) + require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero") + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} + +// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径 +// 同样遵循"负数 → 0"语义。 +func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + price := 0.04 + cfg := &ImagePriceConfig{Price1K: &price} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 + }{ + {"negative clamped to 0", -0.5, 0}, + {"zero passes through", 0, 0}, + {"positive 3x applied", 3.0, 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier) + require.NotNil(t, cost) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 2cf134e2ba..fc8361c78d 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) } -func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) -} - -func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) -} - func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go index 694c338418..e6a92d1a8c 100644 --- a/backend/internal/service/billing_service_unified_test.go +++ b/backend/internal/service/billing_service_unified_test.go @@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { require.Equal(t, string(BillingModeImage), cost.BillingMode) } -func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为: +// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。 +func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} - costZero, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 0, // should default to 1.0 - Resolver: resolver, - }) - require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, - RateMultiplier: 1.0, + RateMultiplier: 0, Resolver: resolver, }) require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } -func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为: +// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。 +func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} - costNeg, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, @@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4b4fc0bf25..07a9e41c95 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill cost := p.Cost if p.IsSubscriptionBill { - if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + // Subscription usage tracked by ActualCost so group rate multiplier + // consumes the quota at the expected speed. + if cost.ActualCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } } @@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage } } + // Record subscription / balance cost using ActualCost so the group (and any + // user-specific) rate multiplier consumes subscription quota at the expected + // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards + // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0). if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { cmd.SubscriptionID = &p.Subscription.ID - cmd.SubscriptionCost = p.Cost.TotalCost + cmd.SubscriptionCost = p.Cost.ActualCost } else if p.Cost.ActualCost > 0 { cmd.BalanceCost = p.Cost.ActualCost } @@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu } if p.IsSubscriptionBill { - if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost) } } else if p.Cost.ActualCost > 0 && p.User != nil { deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go new file mode 100644 index 0000000000..42a81035c0 --- /dev/null +++ b/backend/internal/service/gateway_service_subscription_billing_test.go @@ -0,0 +1,85 @@ +//go:build unit + +package service + +import ( + "testing" +) + +// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix +// that subscription-mode billing honours the group (and any user-specific) rate +// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost * +// RateMultiplier), not raw TotalCost. +func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) { + t.Parallel() + + groupID := int64(7) + subID := int64(42) + + tests := []struct { + name string + totalCost float64 + actualCost float64 + isSubscription bool + wantSub float64 + wantBalance float64 + }{ + { + name: "subscription with 2x multiplier consumes 2x quota", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: true, + wantSub: 2.0, + wantBalance: 0, + }, + { + name: "subscription with 0.5x multiplier consumes 0.5x quota", + totalCost: 1.0, + actualCost: 0.5, + isSubscription: true, + wantSub: 0.5, + wantBalance: 0, + }, + { + name: "free subscription (multiplier 0) consumes no quota", + totalCost: 1.0, + actualCost: 0, + isSubscription: true, + wantSub: 0, + wantBalance: 0, + }, + { + name: "balance billing keeps using ActualCost (regression)", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: false, + wantSub: 0, + wantBalance: 2.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := &postUsageBillingParams{ + Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost}, + User: &User{ID: 1}, + APIKey: &APIKey{ID: 2, GroupID: &groupID}, + Account: &Account{ID: 3}, + Subscription: &UserSubscription{ID: subID}, + IsSubscriptionBill: tt.isSubscription, + } + + cmd := buildUsageBillingCommand("req-1", nil, p) + if cmd == nil { + t.Fatal("buildUsageBillingCommand returned nil") + } + if cmd.SubscriptionCost != tt.wantSub { + t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub) + } + if cmd.BalanceCost != tt.wantBalance { + t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance) + } + }) + } +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 1226261357..64434ae1d6 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { return g.SubscriptionType == SubscriptionTypeSubscription } -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - func (g *Group) HasDailyLimit() bool { return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index e6fa94aa3b..6fa8a5bdbf 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel Model: "gpt-5.1", Duration: time.Second, }, - APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}}, User: &User{ID: 200}, Account: &Account{ID: 300}, Subscription: subscription, diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 3c406b4534..3d1e4dc491 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "strconv" "strings" @@ -51,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte AllowUserRefund: inst.AllowUserRefund, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } - resp.Config, err = s.decryptAndMaskConfig(inst.Config) + resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config) if err != nil { return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) } @@ -60,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte return result, nil } -func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { - return s.decryptConfig(encrypted) +// decryptAndMaskConfig returns the stored config with sensitive fields omitted. +// Admin UIs display masked placeholders for these; the raw values never leave +// the server. Callers that need the full config (e.g. payment runtime) must +// use decryptConfig directly. +func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) { + cfg, err := s.decryptConfig(encrypted) + if err != nil { + return nil, err + } + if cfg == nil { + return nil, nil + } + masked := make(map[string]string, len(cfg)) + for k, v := range cfg { + if isSensitiveProviderConfigField(providerKey, k) { + continue + } + masked[k] = v + } + return masked, nil } // pendingOrderStatuses are order statuses considered "in progress". @@ -71,16 +90,27 @@ var pendingOrderStatuses = []string{ payment.OrderStatusRecharging, } -var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} +// providerSensitiveConfigFields is the authoritative list of config keys that +// are treated as secrets per provider. Must stay in sync with the frontend +// definition at frontend/src/components/payment/providerConfig.ts +// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true). +// +// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl, +// stripe publishableKey) are returned in plaintext by the admin GET API. +var providerSensitiveConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} -func isSensitiveConfigField(fieldName string) bool { - lower := strings.ToLower(fieldName) - for _, p := range sensitiveConfigPatterns { - if strings.Contains(lower, p) { - return true - } +func isSensitiveProviderConfigField(providerKey, fieldName string) bool { + fields, ok := providerSensitiveConfigFields[providerKey] + if !ok { + return false } - return false + _, found := fields[strings.ToLower(fieldName)] + return found } func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { @@ -136,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { // NOTE: This function exceeds 30 lines due to per-field nil-check patch update // boilerplate and pending-order safety checks. func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + var cachedInst *dbent.PaymentProviderInstance + loadInst := func() (*dbent.PaymentProviderInstance, error) { + if cachedInst != nil { + return cachedInst, nil + } + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("load provider instance: %w", err) + } + cachedInst = inst + return inst, nil + } if req.Config != nil { + inst, err := loadInst() + if err != nil { + return nil, err + } hasSensitive := false - for k := range req.Config { - if isSensitiveConfigField(k) && req.Config[k] != "" { + for k, v := range req.Config { + if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { hasSensitive = true break } @@ -282,27 +328,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) } if existing == nil { - return newConfig, nil + existing = map[string]string{} } for k, v := range newConfig { + // Preserve existing secrets when the client submits an empty value + // (admin UI omits the value to indicate "leave unchanged"). + if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { + continue + } existing[k] = v } return existing, nil } -func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { - if encrypted == "" { +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext +// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including +// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, +// letting the admin re-enter the config via the UI to complete the migration. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional +// shim for pre-plaintext records. Remove it (and the encryptionKey field) after +// a few releases once all live deployments have re-saved their provider configs. +func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { return nil, nil } - decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) - if err != nil { - return nil, fmt.Errorf("decrypt config: %w", err) + var cfg map[string]string + if err := json.Unmarshal([]byte(stored), &cfg); err == nil { + return cfg, nil } - var raw map[string]string - if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { - return nil, fmt.Errorf("unmarshal decrypted config: %w", err) + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. + if len(s.encryptionKey) == payment.AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal + if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { + return cfg, nil + } + } } - return raw, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { @@ -317,14 +384,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) } +// encryptConfig serialises a provider config for storage. +// New records are written as plaintext JSON; the historical AES-GCM wrapping +// has been dropped but decryptConfig still accepts old ciphertext during migration. func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { data, err := json.Marshal(cfg) if err != nil { return "", fmt.Errorf("marshal config: %w", err) } - enc, err := payment.Encrypt(string(data), s.encryptionKey) - if err != nil { - return "", fmt.Errorf("encrypt config: %w", err) - } - return enc, nil + return string(data), nil } diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2aaa874f95..bc2a9b18be 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) { } } -func TestIsSensitiveConfigField(t *testing.T) { +func TestIsSensitiveProviderConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string - wantSen bool + providerKey string + field string + wantSen bool }{ - // Sensitive fields (contain key/secret/private/password/pkey patterns) - {"secretKey", true}, - {"apiSecret", true}, - {"pkey", true}, - {"privateKey", true}, - {"apiPassword", true}, - {"appKey", true}, - {"SECRET_TOKEN", true}, - {"PrivateData", true}, - {"PASSWORD", true}, - {"mySecretValue", true}, - - // Non-sensitive fields - {"appId", false}, - {"mchId", false}, - {"apiBase", false}, - {"endpoint", false}, - {"merchantNo", false}, - {"paymentMode", false}, - {"notifyUrl", false}, + // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets + {"stripe", "secretKey", true}, + {"stripe", "webhookSecret", true}, + {"stripe", "SecretKey", true}, // case-insensitive + {"stripe", "publishableKey", false}, + {"stripe", "appId", false}, + + // Alipay + {"alipay", "privateKey", true}, + {"alipay", "publicKey", true}, + {"alipay", "alipayPublicKey", true}, + {"alipay", "appId", false}, + {"alipay", "notifyUrl", false}, + + // Wxpay + {"wxpay", "privateKey", true}, + {"wxpay", "apiV3Key", true}, + {"wxpay", "publicKey", true}, + {"wxpay", "publicKeyId", false}, + {"wxpay", "certSerial", false}, + {"wxpay", "mchId", false}, + + // EasyPay + {"easypay", "pkey", true}, + {"easypay", "pid", false}, + {"easypay", "apiBase", false}, + + // Unknown provider: never sensitive + {"unknown", "secretKey", false}, } for _, tc := range tests { - t.Run(tc.field, func(t *testing.T) { + tc := tc + t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) { t.Parallel() - got := isSensitiveConfigField(tc.field) - assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) + got := isSensitiveProviderConfigField(tc.providerKey, tc.field) + assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field) }) } } diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index a0444d5209..ddf0e81852 100644 --- a/backend/internal/service/upstream_response_limit.go +++ b/backend/internal/service/upstream_response_limit.go @@ -12,7 +12,9 @@ import ( var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") -const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 +// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes, +// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。 +const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1da32e2c29..59ca0b9c52 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -52,6 +52,10 @@ v-model="editApiKey" type="password" class="input font-mono" + autocomplete="new-password" + data-1p-ignore + data-lpignore="true" + data-bwignore="true" :placeholder=" account.platform === 'openai' ? 'sk-proj-...' diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index bf79bea200..41b2e63c09 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -166,7 +166,7 @@