fix(copilot): address third CodeRabbit review cycle on proxy

- Preserve multi-valued response headers (e.g. Set-Cookie) by using
  clean_response_headers -> CIMultiDict instead of dict(headers)
- Use sock_connect + sock_read timeouts instead of total so long-lived
  SSE streaming responses aren't killed after 600s
- Log the configured bind_host instead of hardcoded 127.0.0.1
This commit is contained in:
majdyz
2026-04-12 07:43:12 +00:00
parent cc3bac13c5
commit 05477f2daa

View File

@@ -213,12 +213,23 @@ def clean_request_body_bytes(body_bytes: bytes) -> bytes:
return json.dumps(payload, separators=(",", ":")).encode("utf-8")
def _parse_connection_tokens(headers: dict[str, str]) -> set[str]:
"""Extract hop-by-hop header names from the ``Connection`` field."""
connection_header = next(
(value for name, value in headers.items() if name.lower() == "connection"),
"",
)
return {
token.strip().lower() for token in connection_header.split(",") if token.strip()
}
def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
"""Drop hop-by-hop headers and rewrite ``anthropic-beta`` to remove
forbidden tokens. Returns a fresh dict the caller can pass through
to the upstream client without further mutation.
Per RFC 7230 §6.1, intermediaries must drop the static hop-by-hop
Per RFC 7230 section 6.1, intermediaries must drop the static hop-by-hop
set above **and** every header name listed in the incoming
``Connection`` field value (case-insensitive). The latter is how
extension hop-by-hop headers are signalled per-connection.
@@ -226,17 +237,7 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
Callers should pass an already-materialised ``dict`` (e.g.
``dict(request.headers)``) so this function stays simple.
"""
# Parse ``Connection: a, b, c`` into a lowercase token set so we
# can drop any header the sender explicitly marked as hop-by-hop
# on this connection. This is separate from the static set
# above — extension headers can be anything.
connection_header = next(
(value for name, value in headers.items() if name.lower() == "connection"),
"",
)
connection_tokens: set[str] = {
token.strip().lower() for token in connection_header.split(",") if token.strip()
}
connection_tokens = _parse_connection_tokens(headers)
cleaned: dict[str, str] = {}
for name, value in headers.items():
@@ -253,6 +254,40 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
return cleaned
def clean_response_headers(
    headers: "Any",
) -> list[tuple[str, str]]:
    """Like :func:`clean_request_headers` but preserves multi-valued
    headers (e.g. ``Set-Cookie``). Accepts any mapping-like object
    whose ``.items()`` yields ``(name, value)`` pairs — including
    aiohttp's ``CIMultiDictProxy`` which can have duplicate keys.

    Returns a list of ``(name, value)`` tuples suitable for passing
    to ``web.StreamResponse(headers=...)`` via ``CIMultiDict``.
    """
    # Per RFC 7230 section 6.1, every header named in *any* Connection
    # field is hop-by-hop for this connection, and repeated field
    # instances combine. Since the input may be a multidict carrying
    # several Connection entries, union the tokens from all of them
    # instead of stopping at the first match.
    connection_tokens: set[str] = set()
    for name, value in headers.items():
        if name.lower() == "connection":
            connection_tokens |= {
                t.strip().lower() for t in value.split(",") if t.strip()
            }

    cleaned: list[tuple[str, str]] = []
    for name, value in headers.items():
        lower_name = name.lower()
        # Drop the static hop-by-hop set plus anything the Connection
        # field flagged as hop-by-hop for this connection.
        if lower_name in _HOP_BY_HOP_HEADERS or lower_name in connection_tokens:
            continue
        if lower_name == "anthropic-beta":
            # Rewrite rather than drop: keep the header with forbidden
            # beta tokens removed; drop it only if nothing survives.
            stripped = strip_forbidden_anthropic_beta_header(value)
            if stripped is None:
                continue
            cleaned.append((name, stripped))
            continue
        cleaned.append((name, value))
    return cleaned
# ---------------------------------------------------------------------------
# The proxy server
# ---------------------------------------------------------------------------
@@ -312,8 +347,16 @@ class OpenRouterCompatProxy:
"""
if self._runner is not None:
return # already started
# Use sock_connect + sock_read instead of total so long-lived
# SSE / streaming responses aren't killed after request_timeout.
# total=None means no cumulative limit; sock_read is the per-chunk
# idle timeout (time between data arriving on the socket).
client = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self._request_timeout)
timeout=aiohttp.ClientTimeout(
total=None,
sock_connect=self._request_timeout,
sock_read=self._request_timeout,
)
)
app = web.Application()
# Catch every method + path so we can also forward GETs
@@ -358,7 +401,11 @@ class OpenRouterCompatProxy:
# endpoint is anyway discoverable from the config the operator
# already has access to. The detailed upstream is exposed via
# the ``target_base_url`` property for callers that need it.
logger.info("OpenRouter compat proxy listening on 127.0.0.1:%d", self._port)
logger.info(
"OpenRouter compat proxy listening on %s:%d",
self._bind_host,
self._port,
)
async def stop(self) -> None:
"""Stop accepting connections and release the port."""
@@ -433,10 +480,13 @@ class OpenRouterCompatProxy:
return web.Response(status=502, text="upstream error")
# Stream the response back unchanged (apart from hop-by-hop
# header filtering).
# header filtering). Use clean_response_headers to preserve
# multi-valued headers like Set-Cookie that dict() would drop.
from multidict import CIMultiDict
downstream = web.StreamResponse(
status=upstream_response.status,
headers=clean_request_headers(dict(upstream_response.headers)),
headers=CIMultiDict(clean_response_headers(upstream_response.headers)),
)
await downstream.prepare(request)
# Track whether the stream terminated cleanly. A mid-stream