diff --git a/autogpt_platform/backend/backend/copilot/sdk/openrouter_compat_proxy.py b/autogpt_platform/backend/backend/copilot/sdk/openrouter_compat_proxy.py index 5001c575a5..d940a83f73 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/openrouter_compat_proxy.py +++ b/autogpt_platform/backend/backend/copilot/sdk/openrouter_compat_proxy.py @@ -213,12 +213,23 @@ def clean_request_body_bytes(body_bytes: bytes) -> bytes: return json.dumps(payload, separators=(",", ":")).encode("utf-8") +def _parse_connection_tokens(headers: dict[str, str]) -> set[str]: + """Extract hop-by-hop header names from the ``Connection`` field.""" + connection_header = next( + (value for name, value in headers.items() if name.lower() == "connection"), + "", + ) + return { + token.strip().lower() for token in connection_header.split(",") if token.strip() + } + + def clean_request_headers(headers: dict[str, str]) -> dict[str, str]: """Drop hop-by-hop headers and rewrite ``anthropic-beta`` to remove forbidden tokens. Returns a fresh dict the caller can pass through to the upstream client without further mutation. - Per RFC 7230 §6.1, intermediaries must drop the static hop-by-hop + Per RFC 7230 section 6.1, intermediaries must drop the static hop-by-hop set above **and** every header name listed in the incoming ``Connection`` field value (case-insensitive). The latter is how extension hop-by-hop headers are signalled per-connection. @@ -226,17 +237,7 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]: Callers should pass an already-materialised ``dict`` (e.g. ``dict(request.headers)``) so this function stays simple. """ - # Parse ``Connection: a, b, c`` into a lowercase token set so we - # can drop any header the sender explicitly marked as hop-by-hop - # on this connection. This is separate from the static set - # above — extension headers can be anything. - connection_header = next( - (value for name, value in headers.items() if name.lower() == "connection"), - "", - ) - connection_tokens: set[str] = { - token.strip().lower() for token in connection_header.split(",") if token.strip() - } + connection_tokens = _parse_connection_tokens(headers) cleaned: dict[str, str] = {} for name, value in headers.items(): @@ -253,6 +254,40 @@ def clean_request_headers(headers: dict[str, str]) -> dict[str, str]: return cleaned +def clean_response_headers( + headers: "Any", +) -> list[tuple[str, str]]: + """Like :func:`clean_request_headers` but preserves multi-valued + headers (e.g. ``Set-Cookie``). Accepts any mapping-like object + whose ``.items()`` yields ``(name, value)`` pairs — including + aiohttp's ``CIMultiDictProxy`` which can have duplicate keys. + + Returns a list of ``(name, value)`` tuples suitable for passing + to ``web.StreamResponse(headers=...)`` via ``CIMultiDict``. + """ + connection_tokens: set[str] = set() + for name, value in headers.items(): + if name.lower() == "connection": + connection_tokens = { + t.strip().lower() for t in value.split(",") if t.strip() + } + break + + cleaned: list[tuple[str, str]] = [] + for name, value in headers.items(): + lower_name = name.lower() + if lower_name in _HOP_BY_HOP_HEADERS or lower_name in connection_tokens: + continue + if lower_name == "anthropic-beta": + stripped = strip_forbidden_anthropic_beta_header(value) + if stripped is None: + continue + cleaned.append((name, stripped)) + continue + cleaned.append((name, value)) + return cleaned + + # --------------------------------------------------------------------------- # The proxy server # --------------------------------------------------------------------------- @@ -312,8 +347,16 @@ class OpenRouterCompatProxy: """ if self._runner is not None: return # already started + # Use sock_connect + sock_read instead of total so long-lived + # SSE / streaming responses aren't killed after request_timeout. + # total=None means no cumulative limit; sock_read is the per-chunk + # idle timeout (time between data arriving on the socket). client = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=self._request_timeout) + timeout=aiohttp.ClientTimeout( + total=None, + sock_connect=self._request_timeout, + sock_read=self._request_timeout, + ) ) app = web.Application() # Catch every method + path so we can also forward GETs @@ -358,7 +401,11 @@ class OpenRouterCompatProxy: # endpoint is anyway discoverable from the config the operator # already has access to. The detailed upstream is exposed via # the ``target_base_url`` property for callers that need it. - logger.info("OpenRouter compat proxy listening on 127.0.0.1:%d", self._port) + logger.info( + "OpenRouter compat proxy listening on %s:%d", + self._bind_host, + self._port, + ) async def stop(self) -> None: """Stop accepting connections and release the port.""" @@ -433,10 +480,13 @@ class OpenRouterCompatProxy: return web.Response(status=502, text="upstream error") # Stream the response back unchanged (apart from hop-by-hop - # header filtering). + # header filtering). Use clean_response_headers to preserve + # multi-valued headers like Set-Cookie that dict() would drop. + from multidict import CIMultiDict + downstream = web.StreamResponse( status=upstream_response.status, - headers=clean_request_headers(dict(upstream_response.headers)), + headers=CIMultiDict(clean_response_headers(upstream_response.headers)), ) await downstream.prepare(request) # Track whether the stream terminated cleanly. A mid-stream