mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): allow regex on CORS allowed origins (#11336)
## Changes 🏗️ Allow dynamic URLs in the CORS config, to match them via regex. This helps because currently we have Front-end preview deployments which are isolated ( _nice they don't pollute or overrride other domains_ ) like: ``` https://autogpt-git-{branch_name}-{commit}-significant-gravitas.vercel.app ``` The Front-end builds and works there, but as soon as you login, any API requests to endpoints that need auth will fail due to CORS, given our current CORS config does not support dynamically generated domains. ### Changes After these changes we can specify dynamic domains to be allowed under CORS. I also made `localhost` disabled if the API is in production for safety... ### Before ```yml cors: allowOrigin: "https://dev-builder.agpt.co" # could only specify full URL strings, not dyamic ones ``` ### After ```yml cors: allowOrigins: - "https://dev-builder.agpt.co" - "regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app" # dynamic domains supported via regex ``` ### Files - add `build_cors_params` utility to parse literal/regex origins and block localhost in production (`backend/server/utils/cors.py`) - apply the helper in both `AgentServer` and `WebsocketServer` so CORS logic and validations remain consistent - add reusable `override_config` testing helper and update existing WebSocket tests to cover the shared CORS behavior - introduce targeted unit tests for the new CORS helper (`backend/server/utils/cors_test.py`) ## Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] We will know once we made the origin config changes on infra and test with this...
This commit is contained in:
@@ -43,6 +43,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import (
|
||||
@@ -303,9 +304,14 @@ async def health():
|
||||
|
||||
class AgentServer(backend.util.service.AppProcess):
|
||||
def run(self):
|
||||
cors_params = build_cors_params(
|
||||
settings.config.backend_cors_allow_origins,
|
||||
settings.config.app_env,
|
||||
)
|
||||
|
||||
server_app = starlette.middleware.cors.CORSMiddleware(
|
||||
app=app,
|
||||
allow_origins=settings.config.backend_cors_allow_origins,
|
||||
**cors_params,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Helper functions for improved test assertions and error handling."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, Optional
|
||||
|
||||
|
||||
def assert_response_status(
|
||||
@@ -107,3 +108,24 @@ def assert_mock_called_with_partial(mock_obj: Any, **expected_kwargs: Any) -> No
|
||||
assert (
|
||||
actual_kwargs[key] == expected_value
|
||||
), f"Mock called with {key}={actual_kwargs[key]}, expected {expected_value}"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_config(settings: Any, attribute: str, value: Any) -> Iterator[None]:
|
||||
"""Temporarily override a config attribute for testing.
|
||||
|
||||
Warning: Directly mutates settings.config. If config is reloaded or cached
|
||||
elsewhere during the test, side effects may leak. Use with caution in
|
||||
parallel tests or when config is accessed globally.
|
||||
|
||||
Args:
|
||||
settings: The settings object containing .config
|
||||
attribute: The config attribute name to override
|
||||
value: The temporary value to set
|
||||
"""
|
||||
original = getattr(settings.config, attribute)
|
||||
setattr(settings.config, attribute, value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
setattr(settings.config, attribute, original)
|
||||
|
||||
67
autogpt_platform/backend/backend/server/utils/cors.py
Normal file
67
autogpt_platform/backend/backend/server/utils/cors.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import List, Sequence, TypedDict
|
||||
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
class CorsParams(TypedDict):
|
||||
allow_origins: List[str]
|
||||
allow_origin_regex: str | None
|
||||
|
||||
|
||||
def build_cors_params(origins: Sequence[str], app_env: AppEnvironment) -> CorsParams:
|
||||
allow_origins: List[str] = []
|
||||
regex_patterns: List[str] = []
|
||||
|
||||
if app_env == AppEnvironment.PRODUCTION:
|
||||
for origin in origins:
|
||||
if origin.startswith("regex:"):
|
||||
pattern = origin[len("regex:") :]
|
||||
pattern_lower = pattern.lower()
|
||||
if "localhost" in pattern_lower or "127.0.0.1" in pattern_lower:
|
||||
raise ValueError(
|
||||
f"Production environment cannot allow localhost origins via regex: {pattern}"
|
||||
)
|
||||
try:
|
||||
compiled = re.compile(pattern)
|
||||
test_urls = [
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
"https://localhost:8000",
|
||||
"https://127.0.0.1:8000",
|
||||
]
|
||||
for test_url in test_urls:
|
||||
if compiled.search(test_url):
|
||||
raise ValueError(
|
||||
f"Production regex pattern matches localhost/127.0.0.1: {pattern}"
|
||||
)
|
||||
except re.error:
|
||||
pass
|
||||
continue
|
||||
|
||||
lowered = origin.lower()
|
||||
if "localhost" in lowered or "127.0.0.1" in lowered:
|
||||
raise ValueError(
|
||||
"Production environment cannot allow localhost origins"
|
||||
)
|
||||
|
||||
for origin in origins:
|
||||
if origin.startswith("regex:"):
|
||||
regex_patterns.append(origin[len("regex:") :])
|
||||
else:
|
||||
allow_origins.append(origin)
|
||||
|
||||
allow_origin_regex = None
|
||||
if regex_patterns:
|
||||
if len(regex_patterns) == 1:
|
||||
allow_origin_regex = f"^(?:{regex_patterns[0]})$"
|
||||
else:
|
||||
combined_pattern = "|".join(f"(?:{pattern})" for pattern in regex_patterns)
|
||||
allow_origin_regex = f"^(?:{combined_pattern})$"
|
||||
|
||||
return {
|
||||
"allow_origins": allow_origins,
|
||||
"allow_origin_regex": allow_origin_regex,
|
||||
}
|
||||
62
autogpt_platform/backend/backend/server/utils/cors_test.py
Normal file
62
autogpt_platform/backend/backend/server/utils/cors_test.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
def test_build_cors_params_splits_regex_patterns() -> None:
|
||||
origins = [
|
||||
"https://app.example.com",
|
||||
"regex:https://.*\\.example\\.com",
|
||||
]
|
||||
|
||||
result = build_cors_params(origins, AppEnvironment.LOCAL)
|
||||
|
||||
assert result["allow_origins"] == ["https://app.example.com"]
|
||||
assert result["allow_origin_regex"] == "^(?:https://.*\\.example\\.com)$"
|
||||
|
||||
|
||||
def test_build_cors_params_combines_multiple_regex_patterns() -> None:
|
||||
origins = [
|
||||
"regex:https://alpha.example.com",
|
||||
"regex:https://beta.example.com",
|
||||
]
|
||||
|
||||
result = build_cors_params(origins, AppEnvironment.DEVELOPMENT)
|
||||
|
||||
assert result["allow_origins"] == []
|
||||
assert result["allow_origin_regex"] == (
|
||||
"^(?:(?:https://alpha.example.com)|(?:https://beta.example.com))$"
|
||||
)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_localhost_literal_in_production() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
build_cors_params(["http://localhost:3000"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_localhost_regex_in_production() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
build_cors_params(["regex:https://.*localhost.*"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_case_insensitive_localhost_regex() -> None:
|
||||
with pytest.raises(ValueError, match="localhost origins via regex"):
|
||||
build_cors_params(["regex:https://(?i)LOCALHOST.*"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_blocks_regex_matching_localhost_at_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="matches localhost"):
|
||||
build_cors_params(["regex:https?://.*:3000"], AppEnvironment.PRODUCTION)
|
||||
|
||||
|
||||
def test_build_cors_params_allows_vercel_preview_regex() -> None:
|
||||
result = build_cors_params(
|
||||
["regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app"],
|
||||
AppEnvironment.PRODUCTION,
|
||||
)
|
||||
|
||||
assert result["allow_origins"] == []
|
||||
assert result["allow_origin_regex"] == (
|
||||
"^(?:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app)$"
|
||||
)
|
||||
@@ -22,6 +22,7 @@ from backend.server.model import (
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.server.utils.cors import build_cors_params
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
@@ -315,9 +316,13 @@ async def health():
|
||||
class WebsocketServer(AppProcess):
|
||||
def run(self):
|
||||
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
|
||||
cors_params = build_cors_params(
|
||||
settings.config.backend_cors_allow_origins,
|
||||
settings.config.app_env,
|
||||
)
|
||||
server_app = CORSMiddleware(
|
||||
app=app,
|
||||
allow_origins=settings.config.backend_cors_allow_origins,
|
||||
**cors_params,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,11 +8,13 @@ from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.test_helpers import override_config
|
||||
from backend.server.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
|
||||
from backend.server.ws_api import app as websocket_app
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
handle_subscribe,
|
||||
handle_unsubscribe,
|
||||
settings,
|
||||
websocket_router,
|
||||
)
|
||||
|
||||
@@ -29,6 +31,47 @@ def mock_manager() -> AsyncMock:
|
||||
return AsyncMock(spec=ConnectionManager)
|
||||
|
||||
|
||||
def test_websocket_server_uses_cors_helper(mocker) -> None:
|
||||
cors_params = {
|
||||
"allow_origins": ["https://app.example.com"],
|
||||
"allow_origin_regex": None,
|
||||
}
|
||||
mocker.patch("backend.server.ws_api.uvicorn.run")
|
||||
cors_middleware = mocker.patch(
|
||||
"backend.server.ws_api.CORSMiddleware", return_value=object()
|
||||
)
|
||||
build_cors = mocker.patch(
|
||||
"backend.server.ws_api.build_cors_params", return_value=cors_params
|
||||
)
|
||||
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
|
||||
), override_config(settings, "app_env", AppEnvironment.LOCAL):
|
||||
WebsocketServer().run()
|
||||
|
||||
build_cors.assert_called_once_with(
|
||||
cors_params["allow_origins"], AppEnvironment.LOCAL
|
||||
)
|
||||
cors_middleware.assert_called_once_with(
|
||||
app=websocket_app,
|
||||
allow_origins=cors_params["allow_origins"],
|
||||
allow_origin_regex=cors_params["allow_origin_regex"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
|
||||
mocker.patch("backend.server.ws_api.uvicorn.run")
|
||||
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
|
||||
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
|
||||
with pytest.raises(ValueError):
|
||||
WebsocketServer().run()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
||||
|
||||
@@ -427,34 +428,62 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum message size limit for communication with the message bus",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
|
||||
backend_cors_allow_origins: List[str] = Field(
|
||||
default=["http://localhost:3000"],
|
||||
description="Allowed Origins for CORS. Supports exact URLs (http/https) or entries prefixed with "
|
||||
'"regex:" to match via regular expression.',
|
||||
)
|
||||
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
@classmethod
|
||||
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
|
||||
out = []
|
||||
port = None
|
||||
has_localhost = False
|
||||
has_127_0_0_1 = False
|
||||
for url in v:
|
||||
url = url.strip()
|
||||
if url.startswith(("http://", "https://")):
|
||||
if "localhost" in url:
|
||||
port = url.split(":")[2]
|
||||
has_localhost = True
|
||||
if "127.0.0.1" in url:
|
||||
port = url.split(":")[2]
|
||||
has_127_0_0_1 = True
|
||||
out.append(url)
|
||||
else:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
validated: List[str] = []
|
||||
localhost_ports: set[str] = set()
|
||||
ip127_ports: set[str] = set()
|
||||
|
||||
if has_127_0_0_1 and not has_localhost:
|
||||
out.append(f"http://localhost:{port}")
|
||||
if has_localhost and not has_127_0_0_1:
|
||||
out.append(f"http://127.0.0.1:{port}")
|
||||
for raw_origin in v:
|
||||
origin = raw_origin.strip()
|
||||
if origin.startswith("regex:"):
|
||||
pattern = origin[len("regex:") :]
|
||||
if not pattern:
|
||||
raise ValueError("Invalid regex pattern: pattern cannot be empty")
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
raise ValueError(
|
||||
f"Invalid regex pattern '{pattern}': {exc}"
|
||||
) from exc
|
||||
validated.append(origin)
|
||||
continue
|
||||
|
||||
return out
|
||||
if origin.startswith(("http://", "https://")):
|
||||
if "localhost" in origin:
|
||||
try:
|
||||
port = origin.split(":")[2]
|
||||
localhost_ports.add(port)
|
||||
except IndexError as exc:
|
||||
raise ValueError(
|
||||
"localhost origins must include an explicit port, e.g. http://localhost:3000"
|
||||
) from exc
|
||||
if "127.0.0.1" in origin:
|
||||
try:
|
||||
port = origin.split(":")[2]
|
||||
ip127_ports.add(port)
|
||||
except IndexError as exc:
|
||||
raise ValueError(
|
||||
"127.0.0.1 origins must include an explicit port, e.g. http://127.0.0.1:3000"
|
||||
) from exc
|
||||
validated.append(origin)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Invalid URL or regex origin: {origin}")
|
||||
|
||||
for port in ip127_ports - localhost_ports:
|
||||
validated.append(f"http://localhost:{port}")
|
||||
for port in localhost_ports - ip127_ports:
|
||||
validated.append(f"http://127.0.0.1:{port}")
|
||||
|
||||
return validated
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
|
||||
Reference in New Issue
Block a user