Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions pkg/runner/aggregate_sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,13 @@ func (as aggregateSender) send(fileName string, ts time.Time, duration time.Dura
if err != nil {
return fmt.Errorf("sendAggregateFile: unable to send request, elapsed time %s: %w", elapsedTime, err)
}
defer res.Body.Close()

bodyData, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("sendAggregateFile: unable to read response body: %w", err)
}

err = res.Body.Close()
if err != nil {
return fmt.Errorf("sendAggregateFile: unable to close HTTP body: %w", err)
}

if res.StatusCode != http.StatusCreated {
as.log.Error(string(bodyData))
return fmt.Errorf("sendAggregateFile: unexpected status code: %d", res.StatusCode)
Expand Down
95 changes: 95 additions & 0 deletions pkg/runner/aggregate_sender_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package runner

import (
"crypto/ed25519"
"crypto/rand"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
)

func TestAggregateSenderClosesBodyOnReadError(t *testing.T) {
closed := make(chan struct{})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := io.Copy(io.Discard, r.Body); err != nil {
t.Errorf("draining request body: %s", err)
return
}

hj, ok := w.(http.Hijacker)
if !ok {
t.Error("ResponseWriter does not implement Hijacker")
return
}
conn, buf, err := hj.Hijack()
if err != nil {
t.Errorf("Hijack: %s", err)
return
}
_, _ = buf.WriteString("HTTP/1.1 201 Created\r\nContent-Length: 1000\r\nLocation: /stored\r\n\r\ntruncated")
_ = buf.Flush()
_ = conn.Close()
close(closed)
}))
t.Cleanup(server.Close)

file, err := os.CreateTemp(t.TempDir(), "aggregate-*.parquet")
if err != nil {
t.Fatalf("CreateTemp: %s", err)
}
if _, err := file.WriteString("payload"); err != nil {
t.Fatalf("write temp aggregate: %s", err)
}
if err := file.Close(); err != nil {
t.Fatalf("close temp aggregate: %s", err)
}

_, signingKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %s", err)
}
signingJWK, err := jwk.FromRaw(signingKey)
if err != nil {
t.Fatalf("FromRaw: %s", err)
}
if err := signingJWK.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
t.Fatalf("set Algorithm: %s", err)
}
if err := signingJWK.Set(jwk.KeyIDKey, "test-key"); err != nil {
t.Fatalf("set KeyID: %s", err)
}

aggrecURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("parse server URL: %s", err)
}

edm := &dnstapMinimiser{
log: slog.New(slog.NewTextHandler(io.Discard, nil)),
httpClientCertStore: newCertStore(),
}
as, err := edm.newAggregateSender(aggrecURL, signingJWK, nil)
if err != nil {
t.Fatalf("newAggregateSender: %s", err)
}

start := time.Date(2026, 4, 29, 12, 34, 45, 0, time.UTC)
err = as.send(file.Name(), start, 45*time.Second)
if err == nil {
t.Fatal("expected error from send when response body is truncated")
}

select {
case <-closed:
case <-time.After(2 * time.Second):
t.Fatal("server handler did not run")
}
}