Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/communication/websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ async def connect(
client_id: Optional[str] = None,
subprotocol: Optional[str] = None,
channel: Optional[str] = None,
limits_reserved: bool = False,
) -> bool:
"""
Accept new WebSocket connection.
Expand All @@ -300,6 +301,15 @@ async def connect(
``juniper_canopy_websocket_messages_total{channel, type}``.
Pass ``None`` (default) on the legacy ``/ws`` compat route
to skip metric emission and preserve closed-set discipline.
limits_reserved: True when the caller has already incremented the
per-IP/per-session counters via ``check_connection_limits``.
When the global cap rejects here, those reservations must be
released because ``disconnect()`` only decrements counters for
websockets that were added to ``active_connections``.

Returns:
True when the websocket was accepted and tracked; False when the
global cap rejected the connection.

Example:
await websocket_manager.connect(websocket, client_id='dashboard-1')
Expand All @@ -323,6 +333,9 @@ async def connect(
else:
close_ws = False
if close_ws:
if limits_reserved:
self._decrement_ip_count(websocket)
self._decrement_session_count(websocket)
await websocket.close(code=1013, reason="Max connections reached")
return False

Expand Down
25 changes: 25 additions & 0 deletions src/tests/unit/test_ws_connection_caps.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ async def test_global_cap_rejects_n_plus_1_stackwide(self):
assert mgr.get_connection_count() == 3
assert ws4 not in mgr.active_connections

async def test_global_reject_after_reserved_limits_releases_counters(self):
"""A global-cap rejection must not leak per-IP or per-session slots.

FastAPI endpoints reserve per-IP/per-session capacity via
``check_connection_limits`` before calling ``connect()``, where the
global stack-wide cap is enforced. If that later global check rejects,
the already-reserved slots must be released or future legitimate
connections can be blocked by stale counters.
"""
mgr = WebSocketManager()
mgr.max_connections = 1

admitted = _make_ws(ip="10.0.0.1", session="sess-A")
assert mgr.check_connection_limits(admitted, max_per_ip=5, max_per_session=5) is True
assert await mgr.connect(admitted, limits_reserved=True) is True

rejected = _make_ws(ip="10.0.0.2", session="sess-B")
assert mgr.check_connection_limits(rejected, max_per_ip=5, max_per_session=5) is True
assert await mgr.connect(rejected, limits_reserved=True) is False

rejected.close.assert_awaited_once_with(code=1013, reason="Max connections reached")
assert rejected not in mgr.active_connections
assert mgr._per_ip_counts == {"10.0.0.1": 1}
assert mgr._per_session_counts == {"sess-A": 1}

async def test_global_cap_rejection_releases_reserved_session_slots(self):
"""A connect-time global-cap reject must not strand per-IP/session slots."""
mgr = WebSocketManager()
Expand Down
Loading