mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 23:08:04 -05:00
Refactor CORS middleware and enhance localhost handling (#4624)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -29,12 +29,10 @@ from fastapi import (
|
||||
WebSocket,
|
||||
status,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import HTTPBearer
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.controller.agent import Agent
|
||||
@@ -57,6 +55,7 @@ from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import bedrock
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.auth import get_sid_from_token, sign_token
|
||||
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
|
||||
from openhands.server.session import SessionManager
|
||||
|
||||
load_dotenv()
|
||||
@@ -93,30 +92,13 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],
|
||||
LocalhostCORSMiddleware,
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
|
||||
class NoCacheMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to disable caching for all routes by adding appropriate headers
|
||||
"""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
response = await call_next(request)
|
||||
if not request.url.path.startswith('/assets'):
|
||||
response.headers['Cache-Control'] = (
|
||||
'no-cache, no-store, must-revalidate, max-age=0'
|
||||
)
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(NoCacheMiddleware)
|
||||
|
||||
security_scheme = HTTPBearer()
|
||||
|
||||
43
openhands/server/middleware.py
Normal file
43
openhands/server/middleware.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
|
||||
class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
"""
|
||||
Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
|
||||
while using standard CORS rules for other origins.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
super().__init__(app, **kwargs)
|
||||
|
||||
async def is_allowed_origin(self, origin: str) -> bool:
|
||||
if origin:
|
||||
parsed = urlparse(origin)
|
||||
hostname = parsed.hostname or ''
|
||||
|
||||
# Allow any localhost/127.0.0.1 origin regardless of port
|
||||
if hostname in ['localhost', '127.0.0.1']:
|
||||
return True
|
||||
|
||||
# For missing origin or other origins, use the parent class's logic
|
||||
return await super().is_allowed_origin(origin)
|
||||
|
||||
|
||||
class NoCacheMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to disable caching for all routes by adding appropriate headers
|
||||
"""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
response = await call_next(request)
|
||||
if not request.url.path.startswith('/assets'):
|
||||
response.headers['Cache-Control'] = (
|
||||
'no-cache, no-store, must-revalidate, max-age=0'
|
||||
)
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
Reference in New Issue
Block a user