Refactor: move middleware definition (#6552)

This commit is contained in:
Robert Brennan
2025-01-30 15:32:26 -05:00
committed by GitHub
parent 5dd4810f58
commit 27fdae6ecc
13 changed files with 130 additions and 145 deletions

View File

@@ -21,7 +21,7 @@ from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
from openhands.server.routes.settings import app as settings_router
from openhands.server.routes.trajectory import app as trajectory_router
from openhands.server.shared import conversation_manager, openhands_config
from openhands.server.shared import conversation_manager
@asynccontextmanager
@@ -36,7 +36,6 @@ app = FastAPI(
version=__version__,
lifespan=_lifespan,
)
openhands_config.attach_middleware(app)
@app.get('/health')

View File

@@ -1,79 +0,0 @@
import os
from fastapi import FastAPI, HTTPException
from openhands.core.logger import openhands_logger as logger
from openhands.server.middleware import (
AttachConversationMiddleware,
CacheControlMiddleware,
GitHubTokenMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
RateLimitMiddleware,
)
from openhands.server.types import AppMode, OpenhandsConfigInterface
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
class OpenhandsConfig(OpenhandsConfigInterface):
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
app_mode = AppMode.OSS
posthog_client_key = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
github_client_id = os.environ.get('GITHUB_APP_CLIENT_ID', '')
settings_store_class: str = (
'openhands.storage.settings.file_settings_store.FileSettingsStore'
)
conversation_store_class: str = (
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
def verify_config(self):
if self.config_cls:
raise ValueError('Unexpected config path provided')
def verify_github_repo_list(self, installation_id: int | None):
if self.app_mode == AppMode.OSS and installation_id:
raise HTTPException(
status_code=400,
detail='Unexpected installation ID',
)
def get_config(self):
config = {
'APP_MODE': self.app_mode,
'GITHUB_CLIENT_ID': self.github_client_id,
'POSTHOG_CLIENT_KEY': self.posthog_client_key,
}
return config
def attach_middleware(self, api: FastAPI) -> None:
SettingsStoreImpl = get_impl(SettingsStore, self.settings_store_class) # type: ignore
api.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
api.add_middleware(CacheControlMiddleware)
api.add_middleware(
RateLimitMiddleware,
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
)
api.middleware('http')(AttachConversationMiddleware(api))
api.middleware('http')(GitHubTokenMiddleware(api, SettingsStoreImpl)) # type: ignore
def load_openhands_config():
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
logger.info(f'Using config class {config_cls}')
openhands_config_cls = get_impl(OpenhandsConfig, config_cls)
openhands_config = openhands_config_cls()
openhands_config.verify_config()
return openhands_config

View File

@@ -0,0 +1,43 @@
import os
from openhands.core.logger import openhands_logger as logger
from openhands.server.types import AppMode, ServerConfigInterface
from openhands.utils.import_utils import get_impl
class ServerConfig(ServerConfigInterface):
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
app_mode = AppMode.OSS
posthog_client_key = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
github_client_id = os.environ.get('GITHUB_APP_CLIENT_ID', '')
settings_store_class: str = (
'openhands.storage.settings.file_settings_store.FileSettingsStore'
)
conversation_store_class: str = (
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
def verify_config(self):
if self.config_cls:
raise ValueError('Unexpected config path provided')
def get_config(self):
config = {
'APP_MODE': self.app_mode,
'GITHUB_CLIENT_ID': self.github_client_id,
'POSTHOG_CLIENT_KEY': self.posthog_client_key,
}
return config
def load_server_config():
config_cls = os.environ.get('OPENHANDS_CONFIG_CLS', None)
logger.info(f'Using config class {config_cls}')
server_config_cls = get_impl(ServerConfig, config_cls)
server_config = server_config_cls()
server_config.verify_config()
return server_config

View File

@@ -2,10 +2,34 @@ import socketio
from openhands.server.app import app as base_app
from openhands.server.listen_socket import sio
from openhands.server.middleware import (
AttachConversationMiddleware,
CacheControlMiddleware,
GitHubTokenMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
RateLimitMiddleware,
)
from openhands.server.shared import SettingsStoreImpl
from openhands.server.static import SPAStaticFiles
base_app.mount(
'/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist'
)
base_app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
base_app.add_middleware(CacheControlMiddleware)
base_app.add_middleware(
RateLimitMiddleware,
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
)
base_app.middleware('http')(AttachConversationMiddleware(base_app))
base_app.middleware('http')(GitHubTokenMiddleware(base_app, SettingsStoreImpl)) # type: ignore
app = socketio.ASGIApp(sio, other_asgi_app=base_app)

View File

@@ -14,8 +14,14 @@ from openhands.events.observation import (
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.shared import config, conversation_manager, openhands_config, sio
from openhands.server.shared import (
ConversationStoreImpl,
SettingsStoreImpl,
config,
conversation_manager,
server_config,
sio,
)
from openhands.server.types import AppMode
@@ -30,7 +36,7 @@ async def connect(connection_id: str, environ, auth):
raise ConnectionRefusedError('No conversation_id in query params')
user_id = None
if openhands_config.app_mode != AppMode.OSS:
if server_config.app_mode != AppMode.OSS:
cookies_str = environ.get('HTTP_COOKIE', '')
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
signed_token = cookies.get('github_auth', '')

View File

@@ -14,7 +14,6 @@ from starlette.types import ASGIApp
from openhands.server import shared
from openhands.server.auth import get_user_id
from openhands.server.types import SessionMiddlewareInterface
from openhands.storage.settings.settings_store import SettingsStore
class LocalhostCORSMiddleware(CORSMiddleware):
@@ -185,12 +184,11 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
class GitHubTokenMiddleware(SessionMiddlewareInterface):
def __init__(self, app, settings_store: SettingsStore):
def __init__(self, app):
self.app = app
self.settings_store_impl = settings_store
async def __call__(self, request: Request, call_next: Callable):
settings_store = await self.settings_store_impl.get_instance(
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, get_user_id(request)
)
settings = await settings_store.load()

View File

@@ -4,7 +4,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from openhands.server.auth import get_github_token
from openhands.server.shared import openhands_config
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api/github')
@@ -29,14 +28,10 @@ async def get_github_repositories(
installation_id: int | None = None,
github_token: str = Depends(require_github_token),
):
openhands_config.verify_github_repo_list(installation_id)
# Add query parameters
params: dict[str, str] = {
'page': str(page),
'per_page': str(per_page),
}
# Construct the GitHub API URL
if installation_id:
github_api_url = (
f'https://api.github.com/user/installations/{installation_id}/repositories'
@@ -45,10 +40,8 @@ async def get_github_repositories(
github_api_url = 'https://api.github.com/user/repos'
params['sort'] = sort
# Set the authorization header with the GitHub token
headers = generate_github_headers(github_token)
# Fetch repositories from GitHub
try:
async with httpx.AsyncClient() as client:
response = await client.get(github_api_url, headers=headers, params=params)
@@ -93,7 +86,9 @@ async def get_github_installation_ids(
headers = generate_github_headers(github_token)
try:
async with httpx.AsyncClient() as client:
response = await client.get('https://api.github.com/user/installations', headers=headers)
response = await client.get(
'https://api.github.com/user/installations', headers=headers
)
response.raise_for_status()
data = response.json()
ids = [installation['id'] for installation in data['installations']]

View File

@@ -11,9 +11,13 @@ from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, conversation_manager
from openhands.server.shared import (
ConversationStoreImpl,
SettingsStoreImpl,
config,
conversation_manager,
)
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.storage.data_models.conversation_info import ConversationInfo
from openhands.storage.data_models.conversation_info_result_set import (

View File

@@ -16,7 +16,7 @@ from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm import bedrock
from openhands.server.shared import config, openhands_config
from openhands.server.shared import config, server_config
app = APIRouter(prefix='/api/options')
@@ -112,4 +112,4 @@ async def get_config():
Get current config
"""
return openhands_config.get_config()
return server_config.get_config()

View File

@@ -5,20 +5,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_user_id
from openhands.server.services.github_service import GitHubService
from openhands.server.settings import Settings, SettingsWithTokenMeta
from openhands.server.shared import config, openhands_config
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.server.shared import SettingsStoreImpl, config
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.import_utils import get_impl
app = APIRouter(prefix='/api')
SettingsStoreImpl = get_impl(SettingsStore, openhands_config.settings_store_class) # type: ignore
ConversationStoreImpl = get_impl(
ConversationStore, # type: ignore
openhands_config.conversation_store_class,
)
@app.get('/settings')
async def load_settings(request: Request) -> SettingsWithTokenMeta | None:

View File

@@ -4,17 +4,19 @@ import socketio
from dotenv import load_dotenv
from openhands.core.config import load_app_config
from openhands.server.config.openhands_config import load_openhands_config
from openhands.server.config.server_config import load_server_config
from openhands.server.conversation_manager.conversation_manager import (
ConversationManager,
)
from openhands.storage import get_file_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
load_dotenv()
config = load_app_config()
openhands_config = load_openhands_config()
server_config = load_server_config()
file_store = get_file_store(config.file_store, config.file_store_path)
client_manager = None
@@ -32,6 +34,13 @@ sio = socketio.AsyncServer(
ConversationManagerImpl = get_impl(
ConversationManager, # type: ignore
openhands_config.conversation_manager_class,
server_config.conversation_manager_class,
)
conversation_manager = ConversationManagerImpl.get_instance(sio, config, file_store)
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore
ConversationStoreImpl = get_impl(
ConversationStore, # type: ignore
server_config.conversation_store_class,
)

View File

@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar, Protocol
from fastapi import FastAPI
class AppMode(Enum):
OSS = 'oss'
@@ -16,7 +14,7 @@ class SessionMiddlewareInterface(Protocol):
pass
class OpenhandsConfigInterface(ABC):
class ServerConfigInterface(ABC):
CONFIG_PATH: ClassVar[str | None]
APP_MODE: ClassVar[AppMode]
POSTHOG_CLIENT_KEY: ClassVar[str]
@@ -28,21 +26,11 @@ class OpenhandsConfigInterface(ABC):
"""Verify configuration settings."""
raise NotImplementedError
@abstractmethod
async def verify_github_repo_list(self, installation_id: int | None) -> None:
"""Verify that repo list is being called via user's profile or Github App installations."""
raise NotImplementedError
@abstractmethod
async def get_config(self) -> dict[str, str]:
"""Configure attributes for frontend"""
raise NotImplementedError
@abstractmethod
def attach_middleware(self, api: FastAPI) -> None:
"""Attach required middleware for the current environment"""
raise NotImplementedError
class MissingSettingsError(ValueError):
"""Raised when settings are missing or not found."""

View File

@@ -9,25 +9,6 @@ from openhands.server.app import app
from openhands.server.settings import Settings
@pytest.fixture
def test_client():
# Mock the middleware that adds github_token
class MockMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope['type'] == 'http':
scope['state'] = {'github_token': 'test-token'}
await self.app(scope, receive, send)
# Replace the middleware
app.middleware_stack = None # Clear existing middleware
app.add_middleware(MockMiddleware)
return TestClient(app)
@pytest.fixture
def mock_settings_store():
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
@@ -38,6 +19,27 @@ def mock_settings_store():
yield store_instance
@pytest.fixture
def test_client(mock_settings_store):
# Mock the middleware that adds github_token
class MockMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
settings = mock_settings_store.load.return_value
token = settings.github_token if settings else None
if scope['type'] == 'http':
scope['state'] = {'github_token': token}
await self.app(scope, receive, send)
# Replace the middleware
app.middleware_stack = None # Clear existing middleware
app.add_middleware(MockMiddleware)
return TestClient(app)
@pytest.fixture
def mock_github_service():
with patch('openhands.server.routes.settings.GitHubService') as mock:
@@ -185,6 +187,10 @@ async def test_settings_unset_github_token(
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
response = test_client.get('/api/settings')
assert response.status_code == 200
assert response.json()['github_token_is_set'] is True
settings_data['unset_github_token'] = True
# Make the POST request to store settings
@@ -194,6 +200,7 @@ async def test_settings_unset_github_token(
# Verify the settings were stored with the github_token unset
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.github_token is None
mock_settings_store.load.return_value = Settings(**stored_settings.dict())
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')