Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions pkg/authz/response_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,35 +187,44 @@ func (rfw *ResponseFilteringWriter) processSSEResponse(rawResponse []byte) error
var written bool
if data, ok := bytes.CutPrefix(line, []byte("data:")); ok {
message, err := jsonrpc2.DecodeMessage(data)
if err != nil {
rfw.ResponseWriter.WriteHeader(rfw.statusCode)
_, err := rfw.ResponseWriter.Write(rawResponse)
return err
}

response, ok := message.(*jsonrpc2.Response)
if !ok {
rfw.ResponseWriter.WriteHeader(rfw.statusCode)
_, err := rfw.ResponseWriter.Write(rawResponse)
return err
}

filteredResponse, err := rfw.filterListResponse(response)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

filteredData, err := jsonrpc2.EncodeMessage(filteredResponse)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
switch {
case err != nil:
// Pass this line through unfiltered. Earlier revisions wrote
// rawResponse and returned here, which leaked every subsequent
// data line on the stream past the filter (issue #5257). The
// WARN fires for every filtered method (tools/list,
// prompts/list, resources/list, find_tool) because the bypass
// applies equally to all of them.
slog.Warn("SSE data line could not be decoded as JSON-RPC; passing through unfiltered",
"method", rfw.method, "error", err)
default:
if response, ok := message.(*jsonrpc2.Response); ok {
filteredResponse, err := rfw.filterListResponse(response)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

filteredData, err := jsonrpc2.EncodeMessage(filteredResponse)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

_, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n"))
if err != nil {
return fmt.Errorf("%w: %w", errBug, err)
}

written = true
} else {
// Non-Response message (e.g. a notifications/* frame
// interleaved on the stream). Pass through unfiltered for
// this line only; the next data line may still be the real
// response and must reach the filter. Logs at WARN for
// every filtered method, not just tools/list.
slog.Warn("SSE data line was not a JSON-RPC Response; passing through unfiltered",
"method", rfw.method)
}
}

_, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n"))
if err != nil {
return fmt.Errorf("%w: %w", errBug, err)
}

written = true
}

if !written {
Expand Down
197 changes: 197 additions & 0 deletions pkg/authz/response_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,3 +863,200 @@ func TestOptimizerPassThroughToolsInResponseFilter(t *testing.T) {
"admin_tool has no permit policy and is not a pass-through tool")
})
}

// TestResponseFilteringWriter_SSE_PerLineFallthrough is a regression test for
// issue #5257: when an SSE upstream interleaves a non-Response data line (e.g.
// an MCP notification) or an undecodable data line with a real list response,
// the filter previously wrote the entire raw upstream payload and returned,
// leaking the unfiltered list past Cedar. It must instead pass only the
// offending line through and continue filtering the rest of the stream.
//
// The same code path runs for every method covered by
// requiresResponseFiltering, so each of tools/list, prompts/list, and
// resources/list is exercised below.
func TestResponseFilteringWriter_SSE_PerLineFallthrough(t *testing.T) {
t.Parallel()

authorizer, err := cedar.NewCedarAuthorizer(cedar.ConfigOptions{
Policies: []string{
`permit(principal, action == Action::"call_tool", resource == Tool::"weather");`,
`permit(principal, action == Action::"get_prompt", resource == Prompt::"greeting");`,
`permit(principal, action == Action::"read_resource", resource == Resource::"data");`,
},
EntitiesJSON: `[]`,
}, "")
require.NoError(t, err)

identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{
Subject: "user1",
Claims: map[string]interface{}{"sub": "user1"},
}}

// encodeListResponse marshals a list result type into a JSON-RPC Response
// data line.
encodeListResponse := func(t *testing.T, result interface{}) string {
t.Helper()
resultJSON, err := json.Marshal(result)
require.NoError(t, err)
encoded, err := jsonrpc2.EncodeMessage(&jsonrpc2.Response{
ID: jsonrpc2.Int64ID(1),
Result: json.RawMessage(resultJSON),
})
require.NoError(t, err)
return "data: " + string(encoded)
}

// methodCase describes how to build a filterable response for one MCP
// list method and how to read the filtered names out of the wire output.
type methodCase struct {
name string
method string
respLine string
authorizedName string
unauthorizedName string
extractNames func(t *testing.T, result json.RawMessage) []string
}

methodCases := []methodCase{
{
name: "tools/list",
method: string(mcp.MethodToolsList),
respLine: encodeListResponse(t, mcp.ListToolsResult{
Tools: []mcp.Tool{
{Name: "weather", Description: "Get weather information"},
{Name: "admin_tool", Description: "Sensitive admin operations"},
},
}),
authorizedName: "weather",
unauthorizedName: "admin_tool",
extractNames: func(t *testing.T, result json.RawMessage) []string {
t.Helper()
var r mcp.ListToolsResult
require.NoError(t, json.Unmarshal(result, &r))
names := make([]string, len(r.Tools))
for i, tool := range r.Tools {
names[i] = tool.Name
}
return names
},
},
{
name: "prompts/list",
method: string(mcp.MethodPromptsList),
respLine: encodeListResponse(t, mcp.ListPromptsResult{
Prompts: []mcp.Prompt{
{Name: "greeting", Description: "Generate greetings"},
{Name: "admin_prompt", Description: "Sensitive admin prompt"},
},
}),
authorizedName: "greeting",
unauthorizedName: "admin_prompt",
extractNames: func(t *testing.T, result json.RawMessage) []string {
t.Helper()
var r mcp.ListPromptsResult
require.NoError(t, json.Unmarshal(result, &r))
names := make([]string, len(r.Prompts))
for i, p := range r.Prompts {
names[i] = p.Name
}
return names
},
},
{
name: "resources/list",
method: string(mcp.MethodResourcesList),
respLine: encodeListResponse(t, mcp.ListResourcesResult{
Resources: []mcp.Resource{
{URI: "data", Name: "Data Resource"},
{URI: "secret", Name: "Sensitive Resource"},
},
}),
authorizedName: "data",
unauthorizedName: "secret",
extractNames: func(t *testing.T, result json.RawMessage) []string {
t.Helper()
var r mcp.ListResourcesResult
require.NoError(t, json.Unmarshal(result, &r))
names := make([]string, len(r.Resources))
for i, res := range r.Resources {
names[i] = res.URI
}
return names
},
},
}

precedingLineCases := []struct {
name string
line string
}{
{
name: "non-response data line",
// A notifications/* frame is a valid JSON-RPC notification
// (no id), so jsonrpc2.DecodeMessage returns a non-Response
// message. The buggy path treated this as a signal to dump
// rawResponse and return.
line: `data: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"warming up"}}`,
},
{
name: "undecodable data line",
line: `data: this is not json at all`,
},
}

for _, mc := range methodCases {
for _, plc := range precedingLineCases {
mc, plc := mc, plc
t.Run(mc.name+"/"+plc.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodPost, "/messages", nil)
require.NoError(t, err)
req = req.WithContext(auth.WithIdentity(req.Context(), identity))

rr := httptest.NewRecorder()
rfw := NewResponseFilteringWriter(rr, authorizer, req, mc.method, nil, nil)
rfw.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

body := strings.Join([]string{plc.line, mc.respLine, ""}, "\n")
_, err = rfw.Write([]byte(body))
require.NoError(t, err)

require.NoError(t, rfw.FlushAndFilter())

out := rr.Body.String()

// The preceding line must still appear verbatim; pass-through
// is the whole point of the fix.
assert.Contains(t, out, plc.line,
"non-response/undecodable preceding line must pass through unchanged")

// The real list response must have been filtered. Pull the
// last JSON-RPC Response data line out and decode it.
var filteredLine string
for _, line := range strings.Split(out, "\n") {
if strings.HasPrefix(line, "data: {\"jsonrpc\"") && strings.Contains(line, `"result"`) {
filteredLine = line
}
}
require.NotEmpty(t, filteredLine, "no JSON-RPC Response data line found in output")

payload := strings.TrimPrefix(filteredLine, "data: ")
msg, err := jsonrpc2.DecodeMessage([]byte(payload))
require.NoError(t, err)
resp, ok := msg.(*jsonrpc2.Response)
require.True(t, ok)

names := mc.extractNames(t, resp.Result)
assert.Contains(t, names, mc.authorizedName, "authorized entry must be retained")
assert.NotContains(t, names, mc.unauthorizedName,
"unauthorized entry must be filtered; presence indicates the cedar bypass from #5257 is back")

// And the raw unfiltered payload (the bug used to dump it)
// must not appear in the wire output.
assert.NotContains(t, out, `"`+mc.unauthorizedName+`"`,
"unfiltered list payload leaked into SSE output")
})
}
}
}
Loading