Compare commits

...

1 Commits

Author SHA1 Message Date
Swifty
50f39e55f7 added ability to disable cors at the backend 2025-09-11 09:59:47 +02:00
3 changed files with 29 additions and 15 deletions

View File

@@ -252,6 +252,8 @@ async def health():
class AgentServer(backend.util.service.AppProcess): class AgentServer(backend.util.service.AppProcess):
def run(self): def run(self):
if settings.config.enable_cors_all_origins:
server_app = starlette.middleware.cors.CORSMiddleware( server_app = starlette.middleware.cors.CORSMiddleware(
app=app, app=app,
allow_origins=settings.config.backend_cors_allow_origins, allow_origins=settings.config.backend_cors_allow_origins,
@@ -259,6 +261,10 @@ class AgentServer(backend.util.service.AppProcess):
allow_methods=["*"], # Allows all methods allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers allow_headers=["*"], # Allows all headers
) )
else:
logger.info("CORS is disabled")
server_app = app
uvicorn.run( uvicorn.run(
server_app, server_app,
host=backend.util.settings.Config().agent_api_host, host=backend.util.settings.Config().agent_api_host,

View File

@@ -295,7 +295,7 @@ async def health():
class WebsocketServer(AppProcess): class WebsocketServer(AppProcess):
def run(self): def run(self):
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}") if settings.config.enable_cors_all_origins:
server_app = CORSMiddleware( server_app = CORSMiddleware(
app=app, app=app,
allow_origins=settings.config.backend_cors_allow_origins, allow_origins=settings.config.backend_cors_allow_origins,
@@ -303,6 +303,9 @@ class WebsocketServer(AppProcess):
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
else:
logger.info("CORS is disabled")
server_app = app
uvicorn.run( uvicorn.run(
server_app, server_app,

View File

@@ -368,6 +368,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum message size limit for communication with the message bus", description="Maximum message size limit for communication with the message bus",
) )
enable_cors_all_origins: bool = Field(
default=True,
description="Whether to enable all CORS origins",
)
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"]) backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
@field_validator("backend_cors_allow_origins") @field_validator("backend_cors_allow_origins")