Compare commits

...

1 Commits

Author SHA1 Message Date
Swifty
50f39e55f7 added ability to disable cors at the backend 2025-09-11 09:59:47 +02:00
3 changed files with 29 additions and 15 deletions

View File

@@ -252,13 +252,19 @@ async def health():
class AgentServer(backend.util.service.AppProcess): class AgentServer(backend.util.service.AppProcess):
def run(self): def run(self):
server_app = starlette.middleware.cors.CORSMiddleware(
app=app, if settings.config.enable_cors_all_origins:
allow_origins=settings.config.backend_cors_allow_origins, server_app = starlette.middleware.cors.CORSMiddleware(
allow_credentials=True, app=app,
allow_methods=["*"], # Allows all methods allow_origins=settings.config.backend_cors_allow_origins,
allow_headers=["*"], # Allows all headers allow_credentials=True,
) allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
else:
logger.info("CORS is disabled")
server_app = app
uvicorn.run( uvicorn.run(
server_app, server_app,
host=backend.util.settings.Config().agent_api_host, host=backend.util.settings.Config().agent_api_host,

View File

@@ -295,14 +295,17 @@ async def health():
class WebsocketServer(AppProcess): class WebsocketServer(AppProcess):
def run(self): def run(self):
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}") if settings.config.enable_cors_all_origins:
server_app = CORSMiddleware( server_app = CORSMiddleware(
app=app, app=app,
allow_origins=settings.config.backend_cors_allow_origins, allow_origins=settings.config.backend_cors_allow_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
else:
logger.info("CORS is disabled")
server_app = app
uvicorn.run( uvicorn.run(
server_app, server_app,

View File

@@ -368,6 +368,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum message size limit for communication with the message bus", description="Maximum message size limit for communication with the message bus",
) )
enable_cors_all_origins: bool = Field(
default=True,
description="Whether to enable all CORS origins",
)
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"]) backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
@field_validator("backend_cors_allow_origins") @field_validator("backend_cors_allow_origins")