feat(backend): allow regex on CORS allowed origins (#11336)

## Changes 🏗️

Allow dynamic URLs in the CORS config, to match them via regex. This
helps because currently we have Front-end preview deployments which are
isolated ( _nice, they don't pollute or override other domains_ ) like:
```
https://autogpt-git-{branch_name}-{commit}-significant-gravitas.vercel.app
```
The Front-end builds and works there, but as soon as you login, any API
requests to endpoints that need auth will fail due to CORS, given our
current CORS config does not support dynamically generated domains.

### Changes

After these changes we can specify dynamic domains to be allowed under
CORS. I also disallowed `localhost` origins when the API runs in
production, for safety.

### Before

```yml
cors:
  allowOrigin: "https://dev-builder.agpt.co" # could only specify full URL strings, not dynamic ones
```

### After

```yml
cors:
  allowOrigins:
    - "https://dev-builder.agpt.co"
    - "regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app" # dynamic domains supported via regex
```

### Files

- add `build_cors_params` utility to parse literal/regex origins and
block localhost in production (`backend/server/utils/cors.py`)
- apply the helper in both `AgentServer` and `WebsocketServer` so CORS
logic and validations remain consistent
- add reusable `override_config` testing helper and update existing
WebSocket tests to cover the shared CORS behavior
- introduce targeted unit tests for the new CORS helper
(`backend/server/utils/cors_test.py`)

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] We will know once we have made the origin config changes on infra
and tested with this...
This commit is contained in:
Ubbe
2025-11-07 23:28:14 +07:00
committed by GitHub
parent dfed092869
commit e68896a25a
7 changed files with 261 additions and 27 deletions

View File

@@ -43,6 +43,7 @@ from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.server.utils.cors import build_cors_params
from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler
from backend.util.exceptions import (
@@ -303,9 +304,14 @@ async def health():
class AgentServer(backend.util.service.AppProcess):
def run(self):
cors_params = build_cors_params(
settings.config.backend_cors_allow_origins,
settings.config.app_env,
)
server_app = starlette.middleware.cors.CORSMiddleware(
app=app,
allow_origins=settings.config.backend_cors_allow_origins,
**cors_params,
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers

View File

@@ -1,7 +1,8 @@
"""Helper functions for improved test assertions and error handling."""
import json
from typing import Any, Dict, Optional
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional
def assert_response_status(
@@ -107,3 +108,24 @@ def assert_mock_called_with_partial(mock_obj: Any, **expected_kwargs: Any) -> No
assert (
actual_kwargs[key] == expected_value
), f"Mock called with {key}={actual_kwargs[key]}, expected {expected_value}"
@contextmanager
def override_config(settings: Any, attribute: str, value: Any) -> Iterator[None]:
    """Swap one ``settings.config`` attribute for the duration of a ``with`` block.

    The attribute is assigned directly on ``settings.config`` and the value it
    held on entry is written back on exit, whether or not the body raised.

    Warning: this mutates ``settings.config`` in place, so any other code that
    reads the same object while the block is active sees the temporary value.
    Be careful with parallel tests or globally shared config.

    Args:
        settings: Object exposing the configuration as ``.config``.
        attribute: Name of the config attribute to replace.
        value: Value the attribute holds inside the ``with`` block.
    """
    # Read the current value first; a missing attribute fails here, before
    # anything has been modified.
    previous_value = getattr(settings.config, attribute)
    setattr(settings.config, attribute, value)
    try:
        yield
    finally:
        # Always put the entry-time value back, even when the body raised.
        setattr(settings.config, attribute, previous_value)

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import re
from typing import List, Sequence, TypedDict
from backend.util.settings import AppEnvironment
class CorsParams(TypedDict):
    """Shape of the dict returned by ``build_cors_params``."""

    # Origins that were configured without the "regex:" prefix.
    allow_origins: List[str]
    # Single anchored pattern built from the "regex:" entries, or None when
    # no such entry was configured.
    allow_origin_regex: str | None
def _reject_localhost_origin(origin: str) -> None:
    """Raise ``ValueError`` if *origin* would allow a localhost origin.

    Literal origins are rejected when they mention ``localhost`` or
    ``127.0.0.1``. ``regex:`` entries are rejected when the pattern text
    mentions either name, or when the compiled pattern matches one of a few
    representative local URLs.
    """
    if not origin.startswith("regex:"):
        lowered = origin.lower()
        if "localhost" in lowered or "127.0.0.1" in lowered:
            raise ValueError("Production environment cannot allow localhost origins")
        return

    pattern = origin[len("regex:") :]
    pattern_lower = pattern.lower()
    if "localhost" in pattern_lower or "127.0.0.1" in pattern_lower:
        raise ValueError(
            f"Production environment cannot allow localhost origins via regex: {pattern}"
        )

    # Only the compile step can raise re.error. Surface it instead of
    # silently skipping the probe below for a pattern we cannot evaluate.
    try:
        compiled = re.compile(pattern)
    except re.error as exc:
        raise ValueError(f"Invalid regex pattern '{pattern}': {exc}") from exc

    probe_urls = (
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "https://localhost:8000",
        "https://127.0.0.1:8000",
    )
    # search() is deliberately broader than the anchored match used at
    # runtime, so the check errs on the side of rejecting.
    if any(compiled.search(url) for url in probe_urls):
        raise ValueError(
            f"Production regex pattern matches localhost/127.0.0.1: {pattern}"
        )


def build_cors_params(origins: Sequence[str], app_env: AppEnvironment) -> CorsParams:
    """Split configured CORS origins into literal origins and one combined regex.

    Entries prefixed with ``regex:`` are treated as regular expressions; every
    other entry is passed through as a literal origin. All regex entries are
    merged into a single pattern anchored with ``^`` and ``$``.

    Args:
        origins: Configured origins, literal URLs and/or ``regex:<pattern>``.
        app_env: Current environment; localhost origins are rejected when it
            is ``AppEnvironment.PRODUCTION``.

    Returns:
        A dict with ``allow_origins`` (the literal entries, in input order)
        and ``allow_origin_regex`` (the anchored pattern, or ``None`` when no
        regex entry was given).

    Raises:
        ValueError: In production, if any entry allows ``localhost`` or
            ``127.0.0.1`` or a regex entry does not compile. In any
            environment, if the final anchored pattern does not compile.
    """
    if app_env == AppEnvironment.PRODUCTION:
        for origin in origins:
            _reject_localhost_origin(origin)

    allow_origins: List[str] = []
    regex_patterns: List[str] = []
    for origin in origins:
        if origin.startswith("regex:"):
            regex_patterns.append(origin[len("regex:") :])
        else:
            allow_origins.append(origin)

    allow_origin_regex = None
    if regex_patterns:
        if len(regex_patterns) == 1:
            allow_origin_regex = f"^(?:{regex_patterns[0]})$"
        else:
            combined_pattern = "|".join(f"(?:{pattern})" for pattern in regex_patterns)
            allow_origin_regex = f"^(?:{combined_pattern})$"
        # A pattern can compile alone yet break once wrapped (e.g. a leading
        # inline global flag such as "(?i)" is no longer at the start of the
        # expression). Fail here with a clear message rather than later.
        try:
            re.compile(allow_origin_regex)
        except re.error as exc:
            raise ValueError(
                f"Invalid combined CORS origin regex '{allow_origin_regex}': {exc}"
            ) from exc

    return {
        "allow_origins": allow_origins,
        "allow_origin_regex": allow_origin_regex,
    }

View File

@@ -0,0 +1,62 @@
import pytest
from backend.server.utils.cors import build_cors_params
from backend.util.settings import AppEnvironment
def test_build_cors_params_splits_regex_patterns() -> None:
    """Literal and ``regex:`` entries are routed to separate CORS parameters."""
    params = build_cors_params(
        ["https://app.example.com", "regex:https://.*\\.example\\.com"],
        AppEnvironment.LOCAL,
    )

    # The literal entry stays as-is; the regex entry is anchored.
    assert params["allow_origins"] == ["https://app.example.com"]
    assert params["allow_origin_regex"] == "^(?:https://.*\\.example\\.com)$"
def test_build_cors_params_combines_multiple_regex_patterns() -> None:
    """Several ``regex:`` entries are joined into one anchored alternation."""
    params = build_cors_params(
        ["regex:https://alpha.example.com", "regex:https://beta.example.com"],
        AppEnvironment.DEVELOPMENT,
    )

    expected_regex = "^(?:(?:https://alpha.example.com)|(?:https://beta.example.com))$"
    assert params["allow_origins"] == []
    assert params["allow_origin_regex"] == expected_regex
def test_build_cors_params_blocks_localhost_literal_in_production() -> None:
    """A literal localhost origin is rejected in production."""
    local_origins = ["http://localhost:3000"]
    with pytest.raises(ValueError):
        build_cors_params(local_origins, AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_localhost_regex_in_production() -> None:
    """A regex whose text mentions localhost is rejected in production."""
    local_regex_origins = ["regex:https://.*localhost.*"]
    with pytest.raises(ValueError):
        build_cors_params(local_regex_origins, AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_case_insensitive_localhost_regex() -> None:
    """The localhost text check on regex entries ignores letter case."""
    upper_case_origins = ["regex:https://(?i)LOCALHOST.*"]
    with pytest.raises(ValueError, match="localhost origins via regex"):
        build_cors_params(upper_case_origins, AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_regex_matching_localhost_at_runtime() -> None:
    """A regex that never names localhost but still matches it is rejected."""
    wildcard_host_origins = ["regex:https?://.*:3000"]
    with pytest.raises(ValueError, match="matches localhost"):
        build_cors_params(wildcard_host_origins, AppEnvironment.PRODUCTION)
def test_build_cors_params_allows_vercel_preview_regex() -> None:
    """The Vercel preview-deployment pattern is accepted in production."""
    preview_origins = ["regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app"]

    params = build_cors_params(preview_origins, AppEnvironment.PRODUCTION)

    expected_regex = "^(?:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app)$"
    assert params["allow_origins"] == []
    assert params["allow_origin_regex"] == expected_regex

View File

@@ -22,6 +22,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.server.utils.cors import build_cors_params
from backend.util.retry import continuous_retry
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
@@ -315,9 +316,13 @@ async def health():
class WebsocketServer(AppProcess):
def run(self):
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
cors_params = build_cors_params(
settings.config.backend_cors_allow_origins,
settings.config.app_env,
)
server_app = CORSMiddleware(
app=app,
allow_origins=settings.config.backend_cors_allow_origins,
**cors_params,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

View File

@@ -8,11 +8,13 @@ from pytest_snapshot.plugin import Snapshot
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.test_helpers import override_config
from backend.server.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
from backend.server.ws_api import app as websocket_app
from backend.server.ws_api import (
WSMessage,
WSMethod,
handle_subscribe,
handle_unsubscribe,
settings,
websocket_router,
)
@@ -29,6 +31,47 @@ def mock_manager() -> AsyncMock:
return AsyncMock(spec=ConnectionManager)
def test_websocket_server_uses_cors_helper(mocker) -> None:
    """``WebsocketServer.run`` passes the ``build_cors_params`` result to ``CORSMiddleware``."""
    expected_origins = ["https://app.example.com"]
    cors_params = {"allow_origins": expected_origins, "allow_origin_regex": None}

    # Replace the server start-up, the middleware and the helper so that only
    # the wiring between them is exercised.
    mocker.patch("backend.server.ws_api.uvicorn.run")
    middleware_mock = mocker.patch(
        "backend.server.ws_api.CORSMiddleware", return_value=object()
    )
    helper_mock = mocker.patch(
        "backend.server.ws_api.build_cors_params", return_value=cors_params
    )

    with override_config(settings, "backend_cors_allow_origins", expected_origins):
        with override_config(settings, "app_env", AppEnvironment.LOCAL):
            WebsocketServer().run()

    helper_mock.assert_called_once_with(expected_origins, AppEnvironment.LOCAL)
    middleware_mock.assert_called_once_with(
        app=websocket_app,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        **cors_params,
    )
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
    """Starting the websocket server with a localhost origin in production raises."""
    mocker.patch("backend.server.ws_api.uvicorn.run")

    local_origins = ["http://localhost:3000"]
    with override_config(settings, "backend_cors_allow_origins", local_origins):
        with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
            with pytest.raises(ValueError):
                WebsocketServer().run()
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker

View File

@@ -1,5 +1,6 @@
import json
import os
import re
from enum import Enum
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
@@ -427,34 +428,62 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum message size limit for communication with the message bus",
)
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
backend_cors_allow_origins: List[str] = Field(
default=["http://localhost:3000"],
description="Allowed Origins for CORS. Supports exact URLs (http/https) or entries prefixed with "
'"regex:" to match via regular expression.',
)
@field_validator("backend_cors_allow_origins")
@classmethod
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
out = []
port = None
has_localhost = False
has_127_0_0_1 = False
for url in v:
url = url.strip()
if url.startswith(("http://", "https://")):
if "localhost" in url:
port = url.split(":")[2]
has_localhost = True
if "127.0.0.1" in url:
port = url.split(":")[2]
has_127_0_0_1 = True
out.append(url)
else:
raise ValueError(f"Invalid URL: {url}")
validated: List[str] = []
localhost_ports: set[str] = set()
ip127_ports: set[str] = set()
if has_127_0_0_1 and not has_localhost:
out.append(f"http://localhost:{port}")
if has_localhost and not has_127_0_0_1:
out.append(f"http://127.0.0.1:{port}")
for raw_origin in v:
origin = raw_origin.strip()
if origin.startswith("regex:"):
pattern = origin[len("regex:") :]
if not pattern:
raise ValueError("Invalid regex pattern: pattern cannot be empty")
try:
re.compile(pattern)
except re.error as exc:
raise ValueError(
f"Invalid regex pattern '{pattern}': {exc}"
) from exc
validated.append(origin)
continue
return out
if origin.startswith(("http://", "https://")):
if "localhost" in origin:
try:
port = origin.split(":")[2]
localhost_ports.add(port)
except IndexError as exc:
raise ValueError(
"localhost origins must include an explicit port, e.g. http://localhost:3000"
) from exc
if "127.0.0.1" in origin:
try:
port = origin.split(":")[2]
ip127_ports.add(port)
except IndexError as exc:
raise ValueError(
"127.0.0.1 origins must include an explicit port, e.g. http://127.0.0.1:3000"
) from exc
validated.append(origin)
continue
raise ValueError(f"Invalid URL or regex origin: {origin}")
for port in ip127_ports - localhost_ports:
validated.append(f"http://localhost:{port}")
for port in localhost_ports - ip127_ports:
validated.append(f"http://127.0.0.1:{port}")
return validated
@classmethod
def settings_customise_sources(