mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(copilot): address third CodeRabbit review cycle on proxy
- Preserve multi-valued response headers (e.g. Set-Cookie) by using clean_response_headers -> CIMultiDict instead of dict(headers)
- Use sock_connect + sock_read timeouts instead of total so long-lived SSE streaming responses aren't killed after 600s
- Log the configured bind_host instead of hardcoded 127.0.0.1
This commit is contained in:
@@ -213,12 +213,23 @@ def clean_request_body_bytes(body_bytes: bytes) -> bytes:
|
||||
return json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def _parse_connection_tokens(headers: dict[str, str]) -> set[str]:
|
||||
"""Extract hop-by-hop header names from the ``Connection`` field."""
|
||||
connection_header = next(
|
||||
(value for name, value in headers.items() if name.lower() == "connection"),
|
||||
"",
|
||||
)
|
||||
return {
|
||||
token.strip().lower() for token in connection_header.split(",") if token.strip()
|
||||
}
|
||||
|
||||
|
||||
def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Drop hop-by-hop headers and rewrite ``anthropic-beta`` to remove
|
||||
forbidden tokens. Returns a fresh dict the caller can pass through
|
||||
to the upstream client without further mutation.
|
||||
|
||||
Per RFC 7230 §6.1, intermediaries must drop the static hop-by-hop
|
||||
Per RFC 7230 section 6.1, intermediaries must drop the static hop-by-hop
|
||||
set above **and** every header name listed in the incoming
|
||||
``Connection`` field value (case-insensitive). The latter is how
|
||||
extension hop-by-hop headers are signalled per-connection.
|
||||
@@ -226,17 +237,7 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
Callers should pass an already-materialised ``dict`` (e.g.
|
||||
``dict(request.headers)``) so this function stays simple.
|
||||
"""
|
||||
# Parse ``Connection: a, b, c`` into a lowercase token set so we
|
||||
# can drop any header the sender explicitly marked as hop-by-hop
|
||||
# on this connection. This is separate from the static set
|
||||
# above — extension headers can be anything.
|
||||
connection_header = next(
|
||||
(value for name, value in headers.items() if name.lower() == "connection"),
|
||||
"",
|
||||
)
|
||||
connection_tokens: set[str] = {
|
||||
token.strip().lower() for token in connection_header.split(",") if token.strip()
|
||||
}
|
||||
connection_tokens = _parse_connection_tokens(headers)
|
||||
|
||||
cleaned: dict[str, str] = {}
|
||||
for name, value in headers.items():
|
||||
@@ -253,6 +254,40 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
return cleaned
|
||||
|
||||
|
||||
def clean_response_headers(
    headers: "Any",
) -> list[tuple[str, str]]:
    """Like :func:`clean_request_headers` but preserves multi-valued
    headers (e.g. ``Set-Cookie``).

    Accepts any mapping-like object whose ``.items()`` yields
    ``(name, value)`` pairs, including aiohttp's ``CIMultiDictProxy``
    which can have duplicate keys.

    A header is dropped when its lowercase name is in
    ``_HOP_BY_HOP_HEADERS`` or is named by a ``Connection`` field.
    Tokens from *every* ``Connection`` field are merged, because a
    multi-valued mapping may carry the field more than once.
    ``anthropic-beta`` values are passed through
    ``strip_forbidden_anthropic_beta_header``; the header is omitted
    when that helper returns ``None``.

    Returns a list of ``(name, value)`` tuples, in the input order,
    suitable for passing to ``web.StreamResponse(headers=...)`` via
    ``CIMultiDict``.
    """
    # Materialise once so both passes below see the same pairs.
    pairs = list(headers.items())

    connection_tokens: set[str] = set()
    for name, value in pairs:
        if name.lower() == "connection":
            connection_tokens.update(
                t.strip().lower() for t in value.split(",") if t.strip()
            )

    cleaned: list[tuple[str, str]] = []
    for name, value in pairs:
        lower_name = name.lower()
        if lower_name in _HOP_BY_HOP_HEADERS or lower_name in connection_tokens:
            continue
        if lower_name == "anthropic-beta":
            stripped = strip_forbidden_anthropic_beta_header(value)
            # ``None`` means nothing is left to forward for this header.
            if stripped is not None:
                cleaned.append((name, stripped))
            continue
        cleaned.append((name, value))
    return cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# The proxy server
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -312,8 +347,16 @@ class OpenRouterCompatProxy:
|
||||
"""
|
||||
if self._runner is not None:
|
||||
return # already started
|
||||
# Use sock_connect + sock_read instead of total so long-lived
|
||||
# SSE / streaming responses aren't killed after request_timeout.
|
||||
# total=None means no cumulative limit; sock_read is the per-chunk
|
||||
# idle timeout (time between data arriving on the socket).
|
||||
client = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self._request_timeout)
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=None,
|
||||
sock_connect=self._request_timeout,
|
||||
sock_read=self._request_timeout,
|
||||
)
|
||||
)
|
||||
app = web.Application()
|
||||
# Catch every method + path so we can also forward GETs
|
||||
@@ -358,7 +401,11 @@ class OpenRouterCompatProxy:
|
||||
# endpoint is anyway discoverable from the config the operator
|
||||
# already has access to. The detailed upstream is exposed via
|
||||
# the ``target_base_url`` property for callers that need it.
|
||||
logger.info("OpenRouter compat proxy listening on 127.0.0.1:%d", self._port)
|
||||
logger.info(
|
||||
"OpenRouter compat proxy listening on %s:%d",
|
||||
self._bind_host,
|
||||
self._port,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop accepting connections and release the port."""
|
||||
@@ -433,10 +480,13 @@ class OpenRouterCompatProxy:
|
||||
return web.Response(status=502, text="upstream error")
|
||||
|
||||
# Stream the response back unchanged (apart from hop-by-hop
|
||||
# header filtering).
|
||||
# header filtering). Use clean_response_headers to preserve
|
||||
# multi-valued headers like Set-Cookie that dict() would drop.
|
||||
from multidict import CIMultiDict
|
||||
|
||||
downstream = web.StreamResponse(
|
||||
status=upstream_response.status,
|
||||
headers=clean_request_headers(dict(upstream_response.headers)),
|
||||
headers=CIMultiDict(clean_response_headers(upstream_response.headers)),
|
||||
)
|
||||
await downstream.prepare(request)
|
||||
# Track whether the stream terminated cleanly. A mid-stream
|
||||
|
||||
Reference in New Issue
Block a user