diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 8aaa79485aa..c45046de029 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -484,7 +484,50 @@ func NewCustomReverseProxies(dialClient *http.Client, urls []url.URL) http.Handl return p } +// EnsureRewindableBody makes r's body safe to be retried by net/http +// Transport on connection loss. It is a no-op when the body is nil, +// http.NoBody, or r.GetBody is already set. Otherwise it drains the body +// into memory and wires up GetBody so the transport can rewind. +// +// Callers should ensure the body fits in memory; this helper buffers the +// entire payload. It guards against the +// "net/http: cannot rewind body after connection loss" error that can +// occur when a server-side request (GetBody == nil) is forwarded via +// http.Client.Do and the underlying keep-alive connection goes stale. +func EnsureRewindableBody(r *http.Request) error { + if r.Body == nil || r.Body == http.NoBody || r.GetBody != nil { + return nil + } + buf, err := io.ReadAll(r.Body) + _ = r.Body.Close() + if err != nil { + return err + } + // Restore NoBody semantics for empty payloads so that Transport's + // outgoingLength returns 0 and no body probing happens. + if len(buf) == 0 { + r.Body = http.NoBody + r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } + r.ContentLength = 0 + return nil + } + r.Body = io.NopCloser(bytes.NewReader(buf)) + r.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(buf)), nil + } + // We now know the exact length; set it so the transport can pick + // Content-Length framing over chunked encoding when forwarding. + r.ContentLength = int64(len(buf)) + return nil +} + func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := EnsureRewindableBody(r); err != nil { + log.Error("failed to read request body", zap.Error(err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + for _, url := range p.urls { r.RequestURI = "" r.URL.Host = url.Host diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index 9a9139d43f8..62724280bf0 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -16,6 +16,7 @@ package apiutil import ( "bytes" + "errors" "io" "net/http" "net/http/httptest" @@ -210,6 +211,108 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { } } +type errReader struct{ err error } + +func (e *errReader) Read(_ []byte) (int, error) { return 0, e.err } +func (*errReader) Close() error { return nil } + +type closeTracker struct { + io.Reader + closed bool +} + +func (c *closeTracker) Close() error { + c.closed = true + return nil +} + +func TestEnsureRewindableBody(t *testing.T) { + t.Run("nil body is a no-op", func(t *testing.T) { + re := require.New(t) + r := &http.Request{} + re.NoError(EnsureRewindableBody(r)) + re.Nil(r.Body) + re.Nil(r.GetBody) + }) + + t.Run("http.NoBody is a no-op", func(t *testing.T) { + re := require.New(t) + r := &http.Request{Body: http.NoBody} + re.NoError(EnsureRewindableBody(r)) + re.Equal(http.NoBody, r.Body) + re.Nil(r.GetBody) + }) + + t.Run("existing GetBody is preserved", func(t *testing.T) { + re := require.New(t) + orig := io.NopCloser(bytes.NewBufferString("payload")) + called := false + getBody := func() (io.ReadCloser, error) { + called = true + return io.NopCloser(bytes.NewBufferString("payload")), nil + } + r := &http.Request{Body: orig, GetBody: getBody} + re.NoError(EnsureRewindableBody(r)) + // Body untouched, GetBody untouched (we only check it wasn't replaced + // by invoking it and confirming our own sentinel). + re.Equal(orig, r.Body) + _, err := r.GetBody() + re.NoError(err) + re.True(called) + }) + + t.Run("empty body is restored to NoBody", func(t *testing.T) { + re := require.New(t) + tracker := &closeTracker{Reader: bytes.NewReader(nil)} + r := &http.Request{Body: tracker} + re.NoError(EnsureRewindableBody(r)) + re.True(tracker.closed, "original body should be closed") + re.Equal(http.NoBody, r.Body) + re.EqualValues(0, r.ContentLength) + re.NotNil(r.GetBody) + rc, err := r.GetBody() + re.NoError(err) + re.Equal(http.NoBody, rc) + }) + + t.Run("non-empty body becomes rewindable", func(t *testing.T) { + re := require.New(t) + payload := []byte(`{"hello":"world"}`) + tracker := &closeTracker{Reader: bytes.NewReader(payload)} + r := &http.Request{Body: tracker, ContentLength: -1} + re.NoError(EnsureRewindableBody(r)) + re.True(tracker.closed, "original body should be closed") + re.EqualValues(len(payload), r.ContentLength) + re.NotNil(r.GetBody) + + // Draining r.Body once should yield the payload. + got, err := io.ReadAll(r.Body) + re.NoError(err) + re.Equal(payload, got) + re.NoError(r.Body.Close()) + + // GetBody should be invokable multiple times and each returned + // ReadCloser should independently yield the same payload -- this is + // the actual rewindability guarantee we care about. + for range 3 { + rc, err := r.GetBody() + re.NoError(err) + got, err := io.ReadAll(rc) + re.NoError(err) + re.Equal(payload, got) + re.NoError(rc.Close()) + } + }) + + t.Run("read error is propagated", func(t *testing.T) { + re := require.New(t) + wantErr := errors.New("boom") + r := &http.Request{Body: &errReader{err: wantErr}} + err := EnsureRewindableBody(r) + re.ErrorIs(err, wantErr) + }) +} + func TestParseHexKeys(t *testing.T) { re := require.New(t) // Test for hex format diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index cc5403f7232..b74100713c4 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -15,7 +15,6 @@ package requestutil import ( - "bytes" "encoding/json" "fmt" "io" @@ -69,13 +68,16 @@ func getURLParam(r *http.Request) string { } func getBodyParam(r *http.Request) string { - if r.Body == nil { + // Make the body rewindable so downstream forwarding can retry on + // connection loss, then read it back via GetBody for audit logging. + if err := apiutil.EnsureRewindableBody(r); err != nil || r.GetBody == nil { return "" } - // http request body is a io.Reader between bytes.Reader and strings.Reader, it only has EOF error - buf, _ := io.ReadAll(r.Body) - r.Body.Close() - bodyParam := string(buf) - r.Body = io.NopCloser(bytes.NewBuffer(buf)) - return bodyParam + rc, err := r.GetBody() + if err != nil { + return "" + } + defer rc.Close() + buf, _ := io.ReadAll(rc) + return string(buf) }