Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/core/cli/cmd/kagent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"os/signal"
"syscall"
"time"
Expand Down Expand Up @@ -64,6 +65,10 @@ func main() {
_ = installCmd.RegisterFlagCompletionFunc("profile", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return profiles.Profiles, cobra.ShellCompDirectiveNoFileComp
})
installCmd.Flags().StringVar(&installCfg.Provider, "provider", "", fmt.Sprintf("LLM provider to use (%s). Overrides KAGENT_DEFAULT_MODEL_PROVIDER.", strings.Join(cli.ValidProviders(), ", ")))
_ = installCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return cli.ValidProviders(), cobra.ShellCompDirectiveNoFileComp
})

uninstallCmd := &cobra.Command{
Use: "uninstall",
Expand Down
23 changes: 23 additions & 0 deletions go/core/cli/internal/cli/agent/const.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"fmt"
"os"
"strings"

Expand Down Expand Up @@ -63,3 +64,25 @@ func GetEnvVarWithDefault(envVar, defaultValue string) string {
}
return defaultValue
}

// ValidProviders returns the accepted --provider flag values (helm key format).
func ValidProviders() []string {
return []string{
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOpenAI),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAzureOpenAI),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOllama),
}
}

// applyProviderFlag validates the --provider value and sets KAGENT_DEFAULT_MODEL_PROVIDER so
// that GetModelProvider() picks it up. This lets users avoid setting the env var manually.
func applyProviderFlag(provider string) error {
valid := ValidProviders()
for _, v := range valid {
if provider == v {
return os.Setenv(env.KagentDefaultModelProvider.Name(), provider)
}
}
return fmt.Errorf("unknown provider %q: valid values: %s", provider, strings.Join(valid, ", "))
}
24 changes: 20 additions & 4 deletions go/core/cli/internal/cli/agent/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ import (
)

type InstallCfg struct {
Config *config.Config
Profile string
Config *config.Config
Profile string
Provider string
}

// installChart installs or upgrades a Helm chart with the given parameters
Expand Down Expand Up @@ -76,16 +77,28 @@ func InstallCmd(ctx context.Context, cfg *InstallCfg) *PortForward {
return nil
}

// --provider flag takes precedence over KAGENT_DEFAULT_MODEL_PROVIDER env var
if cfg.Provider != "" {
if err := applyProviderFlag(cfg.Provider); err != nil {
fmt.Fprintln(os.Stderr, err)
return nil
}
}
Comment thread
tjorourke marked this conversation as resolved.

// get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider
modelProvider := GetModelProvider()

// If model provider is openai, check if the API key is set
// Check if the required API key is set for this provider
apiKeyName := GetProviderAPIKey(modelProvider)
apiKeyValue := os.Getenv(apiKeyName)

if apiKeyName != "" && apiKeyValue == "" {
fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName)
if cfg.Provider == "" && modelProvider == DefaultModelProvider && apiKeyName == env.OpenAIAPIKey.Name() {
fmt.Fprintf(os.Stderr, "Tip: use --provider to select a different LLM provider (e.g. --provider anthropic)\n")
fmt.Fprintf(os.Stderr, " or set %s=%s before running install\n", env.KagentDefaultModelProvider.Name(), GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic))
}
return nil
}

Expand Down Expand Up @@ -120,13 +133,16 @@ func InteractiveInstallCmd(ctx context.Context, c *ishell.Context) *PortForward
// get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider
modelProvider := GetModelProvider()

// if model provider is openai, check if the api key is set
// Check if the required API key is set for this provider
apiKeyName := GetProviderAPIKey(modelProvider)
apiKeyValue := os.Getenv(apiKeyName)

if apiKeyName != "" && apiKeyValue == "" {
fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Tip: set %s to select a different provider (e.g. %s=%s)\n",
env.KagentDefaultModelProvider.Name(), env.KagentDefaultModelProvider.Name(),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic))
return nil
}

Expand Down
96 changes: 92 additions & 4 deletions go/core/internal/utils/config_map.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,116 @@
package utils

import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"fmt"
"io"
"strings"

"github.com/klauspost/compress/zstd"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// GetConfigMapData fetches all data from a ConfigMap.
const (
// CompressionAnnotation specifies the compression algorithm used for ConfigMap
// values. Supported values: "gzip", "zstd". When set, all values in the
// ConfigMap are expected to be base64-encoded compressed data and will be
// transparently decompressed when read via GetConfigMapData.
CompressionAnnotation = "kagent.dev/compression"

// maxDecompressedSize is the upper bound on decompressed output (10 MB).
// This prevents a small compressed payload from expanding into an
// arbitrarily large allocation that could OOM the controller.
maxDecompressedSize = 10 << 20 // 10 MiB
)

// GetConfigMapData fetches all data from a ConfigMap. If the ConfigMap carries
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This data gets rewritten to a secret, so I don't think really makes sense unless we also compress that outgoing data as well

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok, i can fix that - but i think we should yes. Just compress it here too. Similar pattern to helm data in secret.

// the kagent.dev/compression annotation, values are transparently decompressed.
// Compressed values must be base64-encoded in the ConfigMap's Data field (not BinaryData).
func GetConfigMapData(ctx context.Context, c client.Client, ref client.ObjectKey) (map[string]string, error) {
configMap := &corev1.ConfigMap{}
if err := c.Get(ctx, ref, configMap); err != nil {
return nil, fmt.Errorf("failed to find ConfigMap %s: %v", ref.String(), err)
return nil, fmt.Errorf("failed to find ConfigMap %s: %w", ref.String(), err)
}
Comment thread
tjorourke marked this conversation as resolved.

algo := strings.ToLower(strings.TrimSpace(configMap.Annotations[CompressionAnnotation]))
if algo == "" {
return configMap.Data, nil
}

decompressed := make(map[string]string, len(configMap.Data))
for key, value := range configMap.Data {
plain, err := decompress(value, algo)
if err != nil {
return nil, fmt.Errorf("failed to decompress key %q in ConfigMap %s (algorithm=%s): %w", key, ref.String(), algo, err)
}
decompressed[key] = plain
}
return decompressed, nil
}

// decompress decodes base64 data and decompresses it with the given algorithm.
// The encoded payload is whitespace-tolerant (newlines and spaces are stripped
// before decoding) and decompressed output is capped at maxDecompressedSize.
func decompress(encoded string, algo string) (string, error) {
// Strip whitespace/newlines that commonly appear in pasted base64
cleaned := strings.Map(func(r rune) rune {
if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
return -1
}
return r
}, encoded)

raw, err := base64.StdEncoding.DecodeString(cleaned)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}

switch algo {
case "gzip":
r, err := gzip.NewReader(bytes.NewReader(raw))
if err != nil {
return "", fmt.Errorf("gzip reader: %w", err)
}
defer r.Close()
out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1))
if err != nil {
return "", fmt.Errorf("gzip read: %w", err)
}
if len(out) > maxDecompressedSize {
return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize)
}
return string(out), nil

case "zstd":
r, err := zstd.NewReader(bytes.NewReader(raw))
if err != nil {
return "", fmt.Errorf("zstd reader: %w", err)
}
defer r.Close()
out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1))
if err != nil {
return "", fmt.Errorf("zstd read: %w", err)
}
if len(out) > maxDecompressedSize {
return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize)
}
return string(out), nil

default:
return "", fmt.Errorf("unsupported compression algorithm %q (supported: gzip, zstd)", algo)
}
return configMap.Data, nil
}

// GetConfigMapValue fetches a value from a ConfigMap
func GetConfigMapValue(ctx context.Context, c client.Client, ref client.ObjectKey, key string) (string, error) {
configMap := &corev1.ConfigMap{}
err := c.Get(ctx, ref, configMap)
if err != nil {
return "", fmt.Errorf("failed to find ConfigMap for %s: %v", ref.String(), err)
return "", fmt.Errorf("failed to find ConfigMap for %s: %w", ref.String(), err)
}

value, exists := configMap.Data[key]
Expand Down
124 changes: 124 additions & 0 deletions go/core/internal/utils/config_map_compression_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package utils

import (
"bytes"
"compress/gzip"
"encoding/base64"
"strings"
"testing"

"github.com/klauspost/compress/zstd"
)

func compressGzip(t *testing.T, data string) string {
t.Helper()
var buf bytes.Buffer
w := gzip.NewWriter(&buf)
if _, err := w.Write([]byte(data)); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return base64.StdEncoding.EncodeToString(buf.Bytes())
}

func compressZstd(t *testing.T, data string) string {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write([]byte(data)); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return base64.StdEncoding.EncodeToString(buf.Bytes())
}

func TestDecompressGzip(t *testing.T) {
original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F."
encoded := compressGzip(t, original)

result, err := decompress(encoded, "gzip")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != original {
t.Errorf("got %q, want %q", result, original)
}
}

func TestDecompressZstd(t *testing.T) {
original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F."
encoded := compressZstd(t, original)

result, err := decompress(encoded, "zstd")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != original {
t.Errorf("got %q, want %q", result, original)
}
}

func TestDecompressUnsupportedAlgorithm(t *testing.T) {
_, err := decompress(base64.StdEncoding.EncodeToString([]byte("test")), "lz4")
if err == nil {
t.Fatal("expected error for unsupported algorithm")
}
}

func TestDecompressInvalidBase64(t *testing.T) {
_, err := decompress("not-valid-base64!!!", "gzip")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
Comment thread
tjorourke marked this conversation as resolved.

func TestDecompressBase64WithWhitespace(t *testing.T) {
original := "Whitespace in base64 is common when users paste wrapped output."
clean := compressGzip(t, original)

// Insert newlines and spaces to simulate wrapped base64
wrapped := clean[:20] + "\n" + clean[20:40] + " " + clean[40:60] + "\r\n" + clean[60:]

result, err := decompress(wrapped, "gzip")
if err != nil {
t.Fatalf("unexpected error with whitespace in base64: %v", err)
}
if result != original {
t.Errorf("got %q, want %q", result, original)
}
}

func TestDecompressExceedsSizeLimit(t *testing.T) {
// Create data larger than maxDecompressedSize (10MB)
// zstd compresses repeated data extremely well, so a small input can exceed the limit
huge := make([]byte, maxDecompressedSize+1)
for i := range huge {
huge[i] = 'A'
}

var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(huge); err != nil {
t.Fatal(err)
}
w.Close()
encoded := base64.StdEncoding.EncodeToString(buf.Bytes())

_, err = decompress(encoded, "zstd")
if err == nil {
t.Fatal("expected error for oversized decompressed output")
}
if !strings.Contains(err.Error(), "exceeds") {
t.Errorf("expected 'exceeds' in error message, got: %v", err)
}
}
4 changes: 2 additions & 2 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ require (
require (
github.com/aws/aws-sdk-go-v2 v1.41.5
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.4
github.com/google/jsonschema-go v0.4.2
github.com/jackc/pgx/v5 v5.9.1
github.com/klauspost/compress v1.18.5
github.com/ollama/ollama v0.20.5
github.com/testcontainers/testcontainers-go v0.42.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0
Expand Down Expand Up @@ -154,7 +156,6 @@ require (
github.com/google/cel-go v0.26.0 // indirect
github.com/google/gnostic-models v0.7.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/safehtml v0.1.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
Expand All @@ -168,7 +169,6 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.5 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
Expand Down