-
Notifications
You must be signed in to change notification settings - Fork 536
feat: transparent zstd/gzip decompression for ConfigMap prompt data sources #1696
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c84d87b
0c9b512
3ba8e3b
8921406
ff6416d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,28 +1,116 @@ | ||
| package utils | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "compress/gzip" | ||
| "context" | ||
| "encoding/base64" | ||
| "fmt" | ||
| "io" | ||
| "strings" | ||
|
|
||
| "github.com/klauspost/compress/zstd" | ||
| corev1 "k8s.io/api/core/v1" | ||
| "sigs.k8s.io/controller-runtime/pkg/client" | ||
| ) | ||
|
|
||
| // GetConfigMapData fetches all data from a ConfigMap. | ||
| const ( | ||
| // CompressionAnnotation specifies the compression algorithm used for ConfigMap | ||
| // values. Supported values: "gzip", "zstd". When set, all values in the | ||
| // ConfigMap are expected to be base64-encoded compressed data and will be | ||
| // transparently decompressed when read via GetConfigMapData. | ||
| CompressionAnnotation = "kagent.dev/compression" | ||
|
|
||
| // maxDecompressedSize is the upper bound on the decompressed output of a single value (10 MiB). | ||
| // This prevents a small compressed payload from expanding into an | ||
| // arbitrarily large allocation that could OOM the controller. | ||
| maxDecompressedSize = 10 << 20 // 10 MiB | ||
| ) | ||
|
|
||
| // GetConfigMapData fetches all data from a ConfigMap. If the ConfigMap carries | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This data gets rewritten to a Secret, so I don't think this really makes sense unless we also compress that outgoing data as well.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, OK, I can fix that — but yes, I think we should. We can just compress it here too, similar to how Helm stores its data in a Secret. |
||
| // the kagent.dev/compression annotation, values are transparently decompressed. | ||
| // Compressed values must be base64-encoded in the ConfigMap's Data field (not BinaryData). | ||
| func GetConfigMapData(ctx context.Context, c client.Client, ref client.ObjectKey) (map[string]string, error) { | ||
| configMap := &corev1.ConfigMap{} | ||
| if err := c.Get(ctx, ref, configMap); err != nil { | ||
| return nil, fmt.Errorf("failed to find ConfigMap %s: %v", ref.String(), err) | ||
| return nil, fmt.Errorf("failed to find ConfigMap %s: %w", ref.String(), err) | ||
| } | ||
|
tjorourke marked this conversation as resolved.
|
||
|
|
||
| algo := strings.ToLower(strings.TrimSpace(configMap.Annotations[CompressionAnnotation])) | ||
| if algo == "" { | ||
| return configMap.Data, nil | ||
| } | ||
|
|
||
| decompressed := make(map[string]string, len(configMap.Data)) | ||
| for key, value := range configMap.Data { | ||
| plain, err := decompress(value, algo) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to decompress key %q in ConfigMap %s (algorithm=%s): %w", key, ref.String(), algo, err) | ||
| } | ||
| decompressed[key] = plain | ||
| } | ||
| return decompressed, nil | ||
| } | ||
|
|
||
| // decompress decodes base64 data and decompresses it with the given algorithm. | ||
| // The encoded payload is whitespace-tolerant (newlines and spaces are stripped | ||
| // before decoding) and decompressed output is capped at maxDecompressedSize. | ||
| func decompress(encoded string, algo string) (string, error) { | ||
| // Strip whitespace/newlines that commonly appear in pasted base64 | ||
| cleaned := strings.Map(func(r rune) rune { | ||
| if r == ' ' || r == '\n' || r == '\r' || r == '\t' { | ||
| return -1 | ||
| } | ||
| return r | ||
| }, encoded) | ||
|
|
||
| raw, err := base64.StdEncoding.DecodeString(cleaned) | ||
| if err != nil { | ||
| return "", fmt.Errorf("base64 decode: %w", err) | ||
| } | ||
|
|
||
| switch algo { | ||
| case "gzip": | ||
| r, err := gzip.NewReader(bytes.NewReader(raw)) | ||
| if err != nil { | ||
| return "", fmt.Errorf("gzip reader: %w", err) | ||
| } | ||
| defer r.Close() | ||
| out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1)) | ||
| if err != nil { | ||
| return "", fmt.Errorf("gzip read: %w", err) | ||
| } | ||
| if len(out) > maxDecompressedSize { | ||
| return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize) | ||
| } | ||
| return string(out), nil | ||
|
|
||
| case "zstd": | ||
| r, err := zstd.NewReader(bytes.NewReader(raw)) | ||
| if err != nil { | ||
| return "", fmt.Errorf("zstd reader: %w", err) | ||
| } | ||
| defer r.Close() | ||
| out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1)) | ||
| if err != nil { | ||
| return "", fmt.Errorf("zstd read: %w", err) | ||
| } | ||
| if len(out) > maxDecompressedSize { | ||
| return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize) | ||
| } | ||
| return string(out), nil | ||
|
|
||
| default: | ||
| return "", fmt.Errorf("unsupported compression algorithm %q (supported: gzip, zstd)", algo) | ||
| } | ||
| return configMap.Data, nil | ||
| } | ||
|
|
||
| // GetConfigMapValue fetches a value from a ConfigMap | ||
| func GetConfigMapValue(ctx context.Context, c client.Client, ref client.ObjectKey, key string) (string, error) { | ||
| configMap := &corev1.ConfigMap{} | ||
| err := c.Get(ctx, ref, configMap) | ||
| if err != nil { | ||
| return "", fmt.Errorf("failed to find ConfigMap for %s: %v", ref.String(), err) | ||
| return "", fmt.Errorf("failed to find ConfigMap for %s: %w", ref.String(), err) | ||
| } | ||
|
|
||
| value, exists := configMap.Data[key] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| package utils | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "compress/gzip" | ||
| "encoding/base64" | ||
| "strings" | ||
| "testing" | ||
|
|
||
| "github.com/klauspost/compress/zstd" | ||
| ) | ||
|
|
||
// compressGzip gzip-compresses data and returns the result as standard
// base64 text. Any write or close error fails the calling test immediately.
func compressGzip(t *testing.T, data string) string {
	t.Helper()

	compressed := new(bytes.Buffer)
	zw := gzip.NewWriter(compressed)

	_, writeErr := zw.Write([]byte(data))
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	// The writer must be closed before the buffer is read so that all
	// pending compressed bytes have been written to it.
	closeErr := zw.Close()
	if closeErr != nil {
		t.Fatal(closeErr)
	}

	return base64.StdEncoding.EncodeToString(compressed.Bytes())
}
|
|
||
| func compressZstd(t *testing.T, data string) string { | ||
| t.Helper() | ||
| var buf bytes.Buffer | ||
| w, err := zstd.NewWriter(&buf) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| if _, err := w.Write([]byte(data)); err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| if err := w.Close(); err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| return base64.StdEncoding.EncodeToString(buf.Bytes()) | ||
| } | ||
|
|
||
| func TestDecompressGzip(t *testing.T) { | ||
| original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F." | ||
| encoded := compressGzip(t, original) | ||
|
|
||
| result, err := decompress(encoded, "gzip") | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
| if result != original { | ||
| t.Errorf("got %q, want %q", result, original) | ||
| } | ||
| } | ||
|
|
||
| func TestDecompressZstd(t *testing.T) { | ||
| original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F." | ||
| encoded := compressZstd(t, original) | ||
|
|
||
| result, err := decompress(encoded, "zstd") | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
| if result != original { | ||
| t.Errorf("got %q, want %q", result, original) | ||
| } | ||
| } | ||
|
|
||
| func TestDecompressUnsupportedAlgorithm(t *testing.T) { | ||
| _, err := decompress(base64.StdEncoding.EncodeToString([]byte("test")), "lz4") | ||
| if err == nil { | ||
| t.Fatal("expected error for unsupported algorithm") | ||
| } | ||
| } | ||
|
|
||
| func TestDecompressInvalidBase64(t *testing.T) { | ||
| _, err := decompress("not-valid-base64!!!", "gzip") | ||
| if err == nil { | ||
| t.Fatal("expected error for invalid base64") | ||
| } | ||
| } | ||
|
tjorourke marked this conversation as resolved.
|
||
|
|
||
| func TestDecompressBase64WithWhitespace(t *testing.T) { | ||
| original := "Whitespace in base64 is common when users paste wrapped output." | ||
| clean := compressGzip(t, original) | ||
|
|
||
| // Insert newlines and spaces to simulate wrapped base64 | ||
| wrapped := clean[:20] + "\n" + clean[20:40] + " " + clean[40:60] + "\r\n" + clean[60:] | ||
|
|
||
| result, err := decompress(wrapped, "gzip") | ||
| if err != nil { | ||
| t.Fatalf("unexpected error with whitespace in base64: %v", err) | ||
| } | ||
| if result != original { | ||
| t.Errorf("got %q, want %q", result, original) | ||
| } | ||
| } | ||
|
|
||
| func TestDecompressExceedsSizeLimit(t *testing.T) { | ||
| // Create data larger than maxDecompressedSize (10MB) | ||
| // zstd compresses repeated data extremely well, so a small input can exceed the limit | ||
| huge := make([]byte, maxDecompressedSize+1) | ||
| for i := range huge { | ||
| huge[i] = 'A' | ||
| } | ||
|
|
||
| var buf bytes.Buffer | ||
| w, err := zstd.NewWriter(&buf) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| if _, err := w.Write(huge); err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| w.Close() | ||
| encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) | ||
|
|
||
| _, err = decompress(encoded, "zstd") | ||
| if err == nil { | ||
| t.Fatal("expected error for oversized decompressed output") | ||
| } | ||
| if !strings.Contains(err.Error(), "exceeds") { | ||
| t.Errorf("expected 'exceeds' in error message, got: %v", err) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.