diff --git a/openhands/server/listen.py b/openhands/server/listen.py index c3a6385345..fc740e8029 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -29,12 +29,10 @@ from fastapi import ( WebSocket, status, ) -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPBearer from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -from starlette.middleware.base import BaseHTTPMiddleware import openhands.agenthub # noqa F401 (we import this to get the agents registered) from openhands.controller.agent import Agent @@ -57,6 +55,7 @@ from openhands.events.serialization import event_to_dict from openhands.llm import bedrock from openhands.runtime.base import Runtime from openhands.server.auth import get_sid_from_token, sign_token +from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware from openhands.server.session import SessionManager load_dotenv() @@ -93,30 +92,13 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) app.add_middleware( - CORSMiddleware, - allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'], + LocalhostCORSMiddleware, allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) -class NoCacheMiddleware(BaseHTTPMiddleware): - """ - Middleware to disable caching for all routes by adding appropriate headers - """ - - async def dispatch(self, request, call_next): - response = await call_next(request) - if not request.url.path.startswith('/assets'): - response.headers['Cache-Control'] = ( - 'no-cache, no-store, must-revalidate, max-age=0' - ) - response.headers['Pragma'] = 'no-cache' - response.headers['Expires'] = '0' - return response - - app.add_middleware(NoCacheMiddleware) security_scheme = HTTPBearer() diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py new file mode 100644 index 0000000000..f09ac0788a --- /dev/null +++ b/openhands/server/middleware.py @@ -0,0 +1,43 @@ +from urllib.parse import urlparse + +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + + +class LocalhostCORSMiddleware(CORSMiddleware): + """ + Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, + while using standard CORS rules for other origins. + """ + + def __init__(self, app: ASGIApp, **kwargs) -> None: + super().__init__(app, **kwargs) + + async def is_allowed_origin(self, origin: str) -> bool: + if origin: + parsed = urlparse(origin) + hostname = parsed.hostname or '' + + # Allow any localhost/127.0.0.1 origin regardless of port + if hostname in ['localhost', '127.0.0.1']: + return True + + # For missing origin or other origins, use the parent class's logic + return await super().is_allowed_origin(origin) + + +class NoCacheMiddleware(BaseHTTPMiddleware): + """ + Middleware to disable caching for all routes by adding appropriate headers + """ + + async def dispatch(self, request, call_next): + response = await call_next(request) + if not request.url.path.startswith('/assets'): + response.headers['Cache-Control'] = ( + 'no-cache, no-store, must-revalidate, max-age=0' + ) + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + return response