Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/core/cli/cmd/kagent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"os/signal"
"syscall"
"time"
Expand Down Expand Up @@ -64,6 +65,10 @@ func main() {
_ = installCmd.RegisterFlagCompletionFunc("profile", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return profiles.Profiles, cobra.ShellCompDirectiveNoFileComp
})
installCmd.Flags().StringVar(&installCfg.Provider, "provider", "", fmt.Sprintf("LLM provider to use (%s). Overrides KAGENT_DEFAULT_MODEL_PROVIDER.", strings.Join(cli.ValidProviders(), ", ")))
_ = installCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return cli.ValidProviders(), cobra.ShellCompDirectiveNoFileComp
})

uninstallCmd := &cobra.Command{
Use: "uninstall",
Expand Down
23 changes: 23 additions & 0 deletions go/core/cli/internal/cli/agent/const.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"fmt"
"os"
"strings"

Expand Down Expand Up @@ -63,3 +64,25 @@ func GetEnvVarWithDefault(envVar, defaultValue string) string {
}
return defaultValue
}

// ValidProviders returns the accepted --provider flag values (helm key format).
func ValidProviders() []string {
return []string{
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOpenAI),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAzureOpenAI),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOllama),
}
}

// applyProviderFlag validates the --provider value and sets KAGENT_DEFAULT_MODEL_PROVIDER so
// that GetModelProvider() picks it up. This lets users avoid setting the env var manually.
func applyProviderFlag(provider string) error {
valid := ValidProviders()
for _, v := range valid {
if provider == v {
return os.Setenv(env.KagentDefaultModelProvider.Name(), provider)
}
}
return fmt.Errorf("unknown provider %q: valid values: %s", provider, strings.Join(valid, ", "))
}
24 changes: 20 additions & 4 deletions go/core/cli/internal/cli/agent/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ import (
)

type InstallCfg struct {
Config *config.Config
Profile string
Config *config.Config
Profile string
Provider string
}

// installChart installs or upgrades a Helm chart with the given parameters
Expand Down Expand Up @@ -76,16 +77,28 @@ func InstallCmd(ctx context.Context, cfg *InstallCfg) *PortForward {
return nil
}

// --provider flag takes precedence over KAGENT_DEFAULT_MODEL_PROVIDER env var
if cfg.Provider != "" {
if err := applyProviderFlag(cfg.Provider); err != nil {
fmt.Fprintln(os.Stderr, err)
return nil
}
}
Comment thread
tjorourke marked this conversation as resolved.

// get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider
modelProvider := GetModelProvider()

// If model provider is openai, check if the API key is set
// Check if the required API key is set for this provider
apiKeyName := GetProviderAPIKey(modelProvider)
apiKeyValue := os.Getenv(apiKeyName)

if apiKeyName != "" && apiKeyValue == "" {
fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName)
if cfg.Provider == "" && modelProvider == DefaultModelProvider && apiKeyName == env.OpenAIAPIKey.Name() {
fmt.Fprintf(os.Stderr, "Tip: use --provider to select a different LLM provider (e.g. --provider anthropic)\n")
fmt.Fprintf(os.Stderr, " or set %s=%s before running install\n", env.KagentDefaultModelProvider.Name(), GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic))
}
return nil
}

Expand Down Expand Up @@ -120,13 +133,16 @@ func InteractiveInstallCmd(ctx context.Context, c *ishell.Context) *PortForward
// get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider
modelProvider := GetModelProvider()

// if model provider is openai, check if the api key is set
// Check if the required API key is set for this provider
apiKeyName := GetProviderAPIKey(modelProvider)
apiKeyValue := os.Getenv(apiKeyName)

if apiKeyName != "" && apiKeyValue == "" {
fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName)
fmt.Fprintf(os.Stderr, "Tip: set %s to select a different provider (e.g. %s=%s)\n",
env.KagentDefaultModelProvider.Name(), env.KagentDefaultModelProvider.Name(),
GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic))
return nil
}

Expand Down
70 changes: 68 additions & 2 deletions go/core/internal/utils/config_map.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,86 @@
package utils

import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"fmt"
"io"

"github.com/klauspost/compress/zstd"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// GetConfigMapData fetches all data from a ConfigMap.
const (
// CompressionAnnotation specifies the compression algorithm used for ConfigMap
// values. Supported values: "gzip", "zstd". When set, all values in the
// ConfigMap are expected to be base64-encoded compressed data and will be
// transparently decompressed when read via GetConfigMapData.
CompressionAnnotation = "kagent.dev/compression"
)

// GetConfigMapData fetches all data from a ConfigMap. If the ConfigMap carries
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This data gets rewritten to a secret, so I don't think really makes sense unless we also compress that outgoing data as well

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok, i can fix that - but i think we should yes. Just compress it here too. Similar pattern to helm data in secret.

// the kagent.dev/compression annotation, values are transparently decompressed.
// Compressed values must be base64-encoded in the ConfigMap's Data field (not BinaryData).
func GetConfigMapData(ctx context.Context, c client.Client, ref client.ObjectKey) (map[string]string, error) {
configMap := &corev1.ConfigMap{}
if err := c.Get(ctx, ref, configMap); err != nil {
return nil, fmt.Errorf("failed to find ConfigMap %s: %v", ref.String(), err)
}
Comment thread
tjorourke marked this conversation as resolved.
return configMap.Data, nil

algo := configMap.Annotations[CompressionAnnotation]
if algo == "" {
return configMap.Data, nil
}
Comment thread
tjorourke marked this conversation as resolved.
Outdated

decompressed := make(map[string]string, len(configMap.Data))
for key, value := range configMap.Data {
plain, err := decompress(value, algo)
if err != nil {
return nil, fmt.Errorf("failed to decompress key %q in ConfigMap %s (algorithm=%s): %w", key, ref.String(), algo, err)
}
decompressed[key] = plain
}
return decompressed, nil
}

// decompress decodes base64 data and decompresses it with the given algorithm.
func decompress(encoded string, algo string) (string, error) {
raw, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
Comment thread
tjorourke marked this conversation as resolved.
Outdated

switch algo {
case "gzip":
r, err := gzip.NewReader(bytes.NewReader(raw))
if err != nil {
return "", fmt.Errorf("gzip reader: %w", err)
}
defer r.Close()
out, err := io.ReadAll(r)
if err != nil {
return "", fmt.Errorf("gzip read: %w", err)
}
Comment thread
tjorourke marked this conversation as resolved.
Outdated
return string(out), nil

case "zstd":
r, err := zstd.NewReader(bytes.NewReader(raw))
if err != nil {
return "", fmt.Errorf("zstd reader: %w", err)
}
defer r.Close()
out, err := io.ReadAll(r)
if err != nil {
return "", fmt.Errorf("zstd read: %w", err)
}
Comment thread
tjorourke marked this conversation as resolved.
Outdated
return string(out), nil

default:
return "", fmt.Errorf("unsupported compression algorithm %q (supported: gzip, zstd)", algo)
}
}

// GetConfigMapValue fetches a value from a ConfigMap
Expand Down
79 changes: 79 additions & 0 deletions go/core/internal/utils/config_map_compression_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package utils

import (
"bytes"
"compress/gzip"
"encoding/base64"
"testing"

"github.com/klauspost/compress/zstd"
)

func compressGzip(t *testing.T, data string) string {
t.Helper()
var buf bytes.Buffer
w := gzip.NewWriter(&buf)
if _, err := w.Write([]byte(data)); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return base64.StdEncoding.EncodeToString(buf.Bytes())
}

func compressZstd(t *testing.T, data string) string {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write([]byte(data)); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return base64.StdEncoding.EncodeToString(buf.Bytes())
}

func TestDecompressGzip(t *testing.T) {
original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F."
encoded := compressGzip(t, original)

result, err := decompress(encoded, "gzip")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != original {
t.Errorf("got %q, want %q", result, original)
}
}

func TestDecompressZstd(t *testing.T) {
original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F."
encoded := compressZstd(t, original)

result, err := decompress(encoded, "zstd")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != original {
t.Errorf("got %q, want %q", result, original)
}
}

func TestDecompressUnsupportedAlgorithm(t *testing.T) {
_, err := decompress(base64.StdEncoding.EncodeToString([]byte("test")), "lz4")
if err == nil {
t.Fatal("expected error for unsupported algorithm")
}
}

func TestDecompressInvalidBase64(t *testing.T) {
_, err := decompress("not-valid-base64!!!", "gzip")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
Comment thread
tjorourke marked this conversation as resolved.
4 changes: 2 additions & 2 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ require (
require (
github.com/aws/aws-sdk-go-v2 v1.41.5
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.4
github.com/google/jsonschema-go v0.4.2
github.com/jackc/pgx/v5 v5.9.1
github.com/klauspost/compress v1.18.5
github.com/ollama/ollama v0.20.5
github.com/testcontainers/testcontainers-go v0.42.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0
Expand Down Expand Up @@ -154,7 +156,6 @@ require (
github.com/google/cel-go v0.26.0 // indirect
github.com/google/gnostic-models v0.7.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/safehtml v0.1.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
Expand All @@ -168,7 +169,6 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.5 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
Expand Down
Loading