mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Refactor: move middleware definition (#6552)
This commit is contained in:
@@ -21,7 +21,7 @@ from openhands.server.routes.public import app as public_api_router
|
||||
from openhands.server.routes.security import app as security_api_router
|
||||
from openhands.server.routes.settings import app as settings_router
|
||||
from openhands.server.routes.trajectory import app as trajectory_router
|
||||
from openhands.server.shared import conversation_manager, openhands_config
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -36,7 +36,6 @@ app = FastAPI(
|
||||
version=__version__,
|
||||
lifespan=_lifespan,
|
||||
)
|
||||
openhands_config.attach_middleware(app)
|
||||
|
||||
|
||||
@app.get('/health')
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.middleware import (
|
||||
AttachConversationMiddleware,
|
||||
CacheControlMiddleware,
|
||||
GitHubTokenMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
RateLimitMiddleware,
|
||||
)
|
||||
from openhands.server.types import AppMode, OpenhandsConfigInterface
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class OpenhandsConfig(OpenhandsConfigInterface):
|
||||
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
|
||||
app_mode = AppMode.OSS
|
||||
posthog_client_key = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
|
||||
github_client_id = os.environ.get('GITHUB_APP_CLIENT_ID', '')
|
||||
settings_store_class: str = (
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore'
|
||||
)
|
||||
conversation_store_class: str = (
|
||||
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
raise ValueError('Unexpected config path provided')
|
||||
|
||||
def verify_github_repo_list(self, installation_id: int | None):
|
||||
if self.app_mode == AppMode.OSS and installation_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail='Unexpected installation ID',
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'APP_MODE': self.app_mode,
|
||||
'GITHUB_CLIENT_ID': self.github_client_id,
|
||||
'POSTHOG_CLIENT_KEY': self.posthog_client_key,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def attach_middleware(self, api: FastAPI) -> None:
|
||||
SettingsStoreImpl = get_impl(SettingsStore, self.settings_store_class) # type: ignore
|
||||
|
||||
api.add_middleware(
|
||||
LocalhostCORSMiddleware,
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
api.add_middleware(CacheControlMiddleware)
|
||||
api.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
|
||||
)
|
||||
api.middleware('http')(AttachConversationMiddleware(api))
|
||||
api.middleware('http')(GitHubTokenMiddleware(api, SettingsStoreImpl)) # type: ignore
|
||||
|
||||
|
||||
def load_openhands_config():
|
||||
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
|
||||
logger.info(f'Using config class {config_cls}')
|
||||
|
||||
openhands_config_cls = get_impl(OpenhandsConfig, config_cls)
|
||||
openhands_config = openhands_config_cls()
|
||||
openhands_config.verify_config()
|
||||
|
||||
return openhands_config
|
||||
43
openhands/server/config/server_config.py
Normal file
43
openhands/server/config/server_config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.types import AppMode, ServerConfigInterface
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class ServerConfig(ServerConfigInterface):
|
||||
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
|
||||
app_mode = AppMode.OSS
|
||||
posthog_client_key = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
|
||||
github_client_id = os.environ.get('GITHUB_APP_CLIENT_ID', '')
|
||||
settings_store_class: str = (
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore'
|
||||
)
|
||||
conversation_store_class: str = (
|
||||
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
raise ValueError('Unexpected config path provided')
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'APP_MODE': self.app_mode,
|
||||
'GITHUB_CLIENT_ID': self.github_client_id,
|
||||
'POSTHOG_CLIENT_KEY': self.posthog_client_key,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def load_server_config():
|
||||
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
|
||||
logger.info(f'Using config class {config_cls}')
|
||||
|
||||
server_config_cls = get_impl(ServerConfig, config_cls)
|
||||
server_config = server_config_cls()
|
||||
server_config.verify_config()
|
||||
|
||||
return server_config
|
||||
@@ -2,10 +2,34 @@ import socketio
|
||||
|
||||
from openhands.server.app import app as base_app
|
||||
from openhands.server.listen_socket import sio
|
||||
from openhands.server.middleware import (
|
||||
AttachConversationMiddleware,
|
||||
CacheControlMiddleware,
|
||||
GitHubTokenMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
RateLimitMiddleware,
|
||||
)
|
||||
from openhands.server.shared import SettingsStoreImpl
|
||||
from openhands.server.static import SPAStaticFiles
|
||||
|
||||
base_app.mount(
|
||||
'/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist'
|
||||
)
|
||||
|
||||
base_app.add_middleware(
|
||||
LocalhostCORSMiddleware,
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
base_app.add_middleware(CacheControlMiddleware)
|
||||
base_app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
|
||||
)
|
||||
base_app.middleware('http')(AttachConversationMiddleware(base_app))
|
||||
base_app.middleware('http')(GitHubTokenMiddleware(base_app, SettingsStoreImpl)) # type: ignore
|
||||
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
|
||||
|
||||
@@ -14,8 +14,14 @@ from openhands.events.observation import (
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.shared import config, conversation_manager, openhands_config, sio
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
server_config,
|
||||
sio,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
@@ -30,7 +36,7 @@ async def connect(connection_id: str, environ, auth):
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
user_id = None
|
||||
if openhands_config.app_mode != AppMode.OSS:
|
||||
if server_config.app_mode != AppMode.OSS:
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
|
||||
signed_token = cookies.get('github_auth', '')
|
||||
|
||||
@@ -14,7 +14,6 @@ from starlette.types import ASGIApp
|
||||
from openhands.server import shared
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.types import SessionMiddlewareInterface
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
@@ -185,12 +184,11 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
|
||||
|
||||
class GitHubTokenMiddleware(SessionMiddlewareInterface):
|
||||
def __init__(self, app, settings_store: SettingsStore):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.settings_store_impl = settings_store
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
settings_store = await self.settings_store_impl.get_instance(
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, get_user_id(request)
|
||||
)
|
||||
settings = await settings_store.load()
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.server.auth import get_github_token
|
||||
from openhands.server.shared import openhands_config
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
@@ -29,14 +28,10 @@ async def get_github_repositories(
|
||||
installation_id: int | None = None,
|
||||
github_token: str = Depends(require_github_token),
|
||||
):
|
||||
openhands_config.verify_github_repo_list(installation_id)
|
||||
|
||||
# Add query parameters
|
||||
params: dict[str, str] = {
|
||||
'page': str(page),
|
||||
'per_page': str(per_page),
|
||||
}
|
||||
# Construct the GitHub API URL
|
||||
if installation_id:
|
||||
github_api_url = (
|
||||
f'https://api.github.com/user/installations/{installation_id}/repositories'
|
||||
@@ -45,10 +40,8 @@ async def get_github_repositories(
|
||||
github_api_url = 'https://api.github.com/user/repos'
|
||||
params['sort'] = sort
|
||||
|
||||
# Set the authorization header with the GitHub token
|
||||
headers = generate_github_headers(github_token)
|
||||
|
||||
# Fetch repositories from GitHub
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(github_api_url, headers=headers, params=params)
|
||||
@@ -93,7 +86,9 @@ async def get_github_installation_ids(
|
||||
headers = generate_github_headers(github_token)
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get('https://api.github.com/user/installations', headers=headers)
|
||||
response = await client.get(
|
||||
'https://api.github.com/user/installations', headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
ids = [installation['id'] for installation in data['installations']]
|
||||
|
||||
@@ -11,9 +11,13 @@ from openhands.events.action.message import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import config, conversation_manager
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.conversation_info import ConversationInfo
|
||||
from openhands.storage.data_models.conversation_info_result_set import (
|
||||
|
||||
@@ -16,7 +16,7 @@ from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.llm import bedrock
|
||||
from openhands.server.shared import config, openhands_config
|
||||
from openhands.server.shared import config, server_config
|
||||
|
||||
app = APIRouter(prefix='/api/options')
|
||||
|
||||
@@ -112,4 +112,4 @@ async def get_config():
|
||||
Get current config
|
||||
"""
|
||||
|
||||
return openhands_config.get_config()
|
||||
return server_config.get_config()
|
||||
|
||||
@@ -5,20 +5,11 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.services.github_service import GitHubService
|
||||
from openhands.server.settings import Settings, SettingsWithTokenMeta
|
||||
from openhands.server.shared import config, openhands_config
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.server.shared import SettingsStoreImpl, config
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
SettingsStoreImpl = get_impl(SettingsStore, openhands_config.settings_store_class) # type: ignore
|
||||
ConversationStoreImpl = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
openhands_config.conversation_store_class,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/settings')
|
||||
async def load_settings(request: Request) -> SettingsWithTokenMeta | None:
|
||||
|
||||
@@ -4,17 +4,19 @@ import socketio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from openhands.core.config import load_app_config
|
||||
from openhands.server.config.openhands_config import load_openhands_config
|
||||
from openhands.server.config.server_config import load_server_config
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
load_dotenv()
|
||||
|
||||
config = load_app_config()
|
||||
openhands_config = load_openhands_config()
|
||||
server_config = load_server_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
|
||||
client_manager = None
|
||||
@@ -32,6 +34,13 @@ sio = socketio.AsyncServer(
|
||||
|
||||
ConversationManagerImpl = get_impl(
|
||||
ConversationManager, # type: ignore
|
||||
openhands_config.conversation_manager_class,
|
||||
server_config.conversation_manager_class,
|
||||
)
|
||||
conversation_manager = ConversationManagerImpl.get_instance(sio, config, file_store)
|
||||
|
||||
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore
|
||||
|
||||
ConversationStoreImpl = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
server_config.conversation_store_class,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Protocol
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
OSS = 'oss'
|
||||
@@ -16,7 +14,7 @@ class SessionMiddlewareInterface(Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class OpenhandsConfigInterface(ABC):
|
||||
class ServerConfigInterface(ABC):
|
||||
CONFIG_PATH: ClassVar[str | None]
|
||||
APP_MODE: ClassVar[AppMode]
|
||||
POSTHOG_CLIENT_KEY: ClassVar[str]
|
||||
@@ -28,21 +26,11 @@ class OpenhandsConfigInterface(ABC):
|
||||
"""Verify configuration settings."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def verify_github_repo_list(self, installation_id: int | None) -> None:
|
||||
"""Verify that repo list is being called via user's profile or Github App installations."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_config(self) -> dict[str, str]:
|
||||
"""Configure attributes for frontend"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def attach_middleware(self, api: FastAPI) -> None:
|
||||
"""Attach required middleware for the current environment"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MissingSettingsError(ValueError):
|
||||
"""Raised when settings are missing or not found."""
|
||||
|
||||
@@ -9,25 +9,6 @@ from openhands.server.app import app
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client():
|
||||
# Mock the middleware that adds github_token
|
||||
class MockMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope['type'] == 'http':
|
||||
scope['state'] = {'github_token': 'test-token'}
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# Replace the middleware
|
||||
app.middleware_stack = None # Clear existing middleware
|
||||
app.add_middleware(MockMiddleware)
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_store():
|
||||
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
|
||||
@@ -38,6 +19,27 @@ def mock_settings_store():
|
||||
yield store_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(mock_settings_store):
|
||||
# Mock the middleware that adds github_token
|
||||
class MockMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
settings = mock_settings_store.load.return_value
|
||||
token = settings.github_token if settings else None
|
||||
if scope['type'] == 'http':
|
||||
scope['state'] = {'github_token': token}
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# Replace the middleware
|
||||
app.middleware_stack = None # Clear existing middleware
|
||||
app.add_middleware(MockMiddleware)
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_service():
|
||||
with patch('openhands.server.routes.settings.GitHubService') as mock:
|
||||
@@ -185,6 +187,10 @@ async def test_settings_unset_github_token(
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['github_token_is_set'] is True
|
||||
|
||||
settings_data['unset_github_token'] = True
|
||||
|
||||
# Make the POST request to store settings
|
||||
@@ -194,6 +200,7 @@ async def test_settings_unset_github_token(
|
||||
# Verify the settings were stored with the github_token unset
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.github_token is None
|
||||
mock_settings_store.load.return_value = Settings(**stored_settings.dict())
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
|
||||
Reference in New Issue
Block a user