Refactor Authentication (#8040)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
This commit is contained in:
tofarr
2025-04-24 18:49:41 -06:00
committed by GitHub
parent 9b1aaa53fe
commit c5245a622d
19 changed files with 931 additions and 851 deletions

View File

@@ -1,34 +0,0 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
"""Get GitHub token from request state. For backward compatibility."""
return getattr(request.state, 'provider_tokens', None)
def get_access_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'access_token', None)
def get_user_id(request: Request) -> str | None:
return getattr(request.state, 'user_id', None)
def get_github_token(request: Request) -> SecretStr | None:
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].token
return None
def get_github_user_id(request: Request) -> str | None:
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].user_id
return None

View File

@@ -20,6 +20,7 @@ class ServerConfig(ServerConfigInterface):
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
monitoring_listener_class: str = 'openhands.server.monitoring.MonitoringListener'
user_auth_class: str = 'openhands.server.user_auth.default_user_auth.DefaultUserAuth'
def verify_config(self):
if self.config_cls:

View File

@@ -9,7 +9,6 @@ from openhands.server.middleware import (
CacheControlMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
ProviderTokenMiddleware,
RateLimitMiddleware,
)
from openhands.server.static import SPAStaticFiles
@@ -32,6 +31,5 @@ base_app.add_middleware(
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
)
base_app.middleware('http')(AttachConversationMiddleware(base_app))
base_app.middleware('http')(ProviderTokenMiddleware(base_app))
app = socketio.ASGIApp(sio, other_asgi_app=base_app)

View File

@@ -12,8 +12,8 @@ from starlette.requests import Request as StarletteRequest
from starlette.types import ASGIApp
from openhands.server import shared
from openhands.server.auth import get_user_id
from openhands.server.types import SessionMiddlewareInterface
from openhands.server.user_auth import get_user_id
class LocalhostCORSMiddleware(CORSMiddleware):
@@ -147,9 +147,10 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
"""
Attach the user's session based on the provided authentication token.
"""
user_id = await get_user_id(request)
request.state.conversation = (
await shared.conversation_manager.attach_to_conversation(
request.state.sid, get_user_id(request)
request.state.sid, user_id
)
)
if not request.state.conversation:
@@ -183,27 +184,3 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
await self._detach_session(request)
return response
class ProviderTokenMiddleware(SessionMiddlewareInterface):
def __init__(self, app):
self.app = app
async def __call__(self, request: Request, call_next: Callable):
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, get_user_id(request)
)
settings = await settings_store.load()
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'provider_tokens', None) is None:
if (
settings
and settings.secrets_store
and settings.secrets_store.provider_tokens
):
request.state.provider_tokens = settings.secrets_store.provider_tokens
else:
request.state.provider_tokens = None
return await call_next(request)

View File

@@ -2,6 +2,7 @@ import os
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
@@ -21,7 +22,6 @@ from openhands.events.observation import (
FileReadObservation,
)
from openhands.runtime.base import Runtime
from openhands.server.auth import get_github_user_id, get_user_id
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.file_config import (
FILES_TO_IGNORE,
@@ -31,6 +31,8 @@ from openhands.server.shared import (
config,
conversation_manager,
)
from openhands.server.user_auth import get_user_id
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_status import ConversationStatus
@@ -187,10 +189,15 @@ def zip_current_workspace(request: Request):
@app.get('/git/changes')
async def git_changes(request: Request, conversation_id: str):
async def git_changes(
request: Request,
conversation_id: str,
user_id: str = Depends(get_user_id),
):
runtime: Runtime = request.state.conversation.runtime
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
config,
user_id,
)
cwd = await get_cwd(
@@ -223,11 +230,13 @@ async def git_changes(request: Request, conversation_id: str):
@app.get('/git/diff')
async def git_diff(request: Request, path: str, conversation_id: str):
async def git_diff(
request: Request,
path: str,
conversation_id: str,
conversation_store = Depends(get_conversation_store),
):
runtime: Runtime = request.state.conversation.runtime
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
cwd = await get_cwd(
conversation_store,

View File

@@ -15,8 +15,12 @@ from openhands.integrations.service_types import (
UnknownException,
User,
)
from openhands.server.auth import get_access_token, get_provider_tokens, get_user_id
from openhands.server.shared import server_config
from openhands.server.user_auth import (
get_access_token,
get_provider_tokens,
get_user_id,
)
app = APIRouter(prefix='/api/user')

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Body, Request, status
from fastapi import APIRouter, Body, Depends, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
@@ -15,11 +15,6 @@ from openhands.integrations.provider import (
)
from openhands.integrations.service_types import Repository
from openhands.runtime import get_runtime_cls
from openhands.server.auth import (
get_github_user_id,
get_provider_tokens,
get_user_id,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -33,6 +28,12 @@ from openhands.server.shared import (
file_store,
)
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
)
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
ConversationTrigger,
@@ -95,7 +96,7 @@ async def _create_new_conversation(
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@@ -152,14 +153,17 @@ async def _create_new_conversation(
@app.post('/conversations')
async def new_conversation(request: Request, data: InitSessionRequest):
async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
):
"""Initialize a new session or join an existing one.
After successful initialization, the client should connect to the WebSocket
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
provider_tokens = get_provider_tokens(request)
selected_repository = data.selected_repository
selected_branch = data.selected_branch
initial_user_msg = data.initial_user_msg
@@ -169,7 +173,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
get_user_id(request),
user_id,
provider_tokens,
selected_repository,
selected_branch,
@@ -204,13 +208,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
@app.get('/conversations')
async def search_conversations(
request: Request,
page_id: str | None = None,
limit: int = 20,
user_id: str | None = Depends(get_user_id),
conversation_store: ConversationStore = Depends(get_conversation_store),
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
# Filter out conversations older than max_age
@@ -228,7 +230,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_user_id(request), set(conversation_ids)
user_id, set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@@ -245,11 +247,9 @@ async def search_conversations(
@app.get('/conversations/{conversation_id}')
async def get_conversation(
conversation_id: str, request: Request
conversation_id: str,
conversation_store: ConversationStore = Depends(get_conversation_store),
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
@@ -340,11 +340,12 @@ async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
@app.patch('/conversations/{conversation_id}')
async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
conversation_id: str,
title: str = Body(embed=True),
user_id: str | None = Depends(get_user_id),
) -> bool:
user_id = get_user_id(request)
conversation_store = await ConversationStoreImpl.get_instance(
config, user_id, get_github_user_id(request)
config, user_id
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@@ -366,10 +367,10 @@ async def update_conversation(
@app.delete('/conversations/{conversation_id}')
async def delete_conversation(
conversation_id: str,
request: Request,
user_id: str | None = Depends(get_user_id),
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
config, user_id
)
try:
await conversation_store.get_metadata(conversation_id)

View File

@@ -1,11 +1,15 @@
from fastapi import APIRouter, Request, status
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderToken,
ProviderType,
SecretStore,
)
from openhands.integrations.utils import validate_provider_token
from openhands.server.auth import get_provider_tokens, get_user_id
from openhands.server.settings import (
GETSettingsCustomSecrets,
GETSettingsModel,
@@ -15,16 +19,24 @@ from openhands.server.settings import (
)
from openhands.server.shared import SettingsStoreImpl, config, server_config
from openhands.server.types import AppMode
from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_settings,
get_user_settings_store,
)
from openhands.storage.settings.settings_store import SettingsStore
app = APIRouter(prefix='/api')
@app.get('/settings', response_model=GETSettingsModel)
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
async def load_settings(
user_id: str | None = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
settings: Settings | None = Depends(get_user_settings),
) -> GETSettingsModel | JSONResponse:
try:
user_id = get_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings: Settings = await settings_store.load()
if not settings:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
@@ -36,7 +48,6 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
if bool(user_id):
provider_tokens_set[ProviderType.GITHUB.value] = True
provider_tokens = get_provider_tokens(request)
if provider_tokens:
all_provider_types = [provider.value for provider in ProviderType]
provider_tokens_types = [provider.value for provider in provider_tokens]
@@ -63,12 +74,9 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
@app.get('/secrets', response_model=GETSettingsCustomSecrets)
async def load_custom_secrets_names(
request: Request,
settings: Settings | None = Depends(get_user_settings),
) -> GETSettingsCustomSecrets | JSONResponse:
try:
user_id = get_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
if not settings:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
@@ -93,13 +101,11 @@ async def load_custom_secrets_names(
@app.post('/secrets', response_model=dict[str, str])
async def add_custom_secret(
request: Request, incoming_secrets: POSTSettingsCustomSecrets
incoming_secrets: POSTSettingsCustomSecrets,
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings: Settings = await settings_store.load()
existing_settings = await settings_store.load()
if existing_settings:
for (
secret_name,
@@ -121,7 +127,6 @@ async def add_custom_secret(
update={'secrets_store': updated_secret_store}
)
updated_settings = convert_to_settings(updated_settings)
await settings_store.store(updated_settings)
return JSONResponse(
@@ -137,11 +142,11 @@ async def add_custom_secret(
@app.delete('/secrets/{secret_id}')
async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse:
async def delete_custom_secret(
secret_id: str,
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings: Settings | None = await settings_store.load()
custom_secrets = {}
if existing_settings:
@@ -162,7 +167,6 @@ async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse
update={'secrets_store': updated_secret_store}
)
updated_settings = convert_to_settings(updated_settings)
await settings_store.store(updated_settings)
return JSONResponse(
@@ -178,12 +182,10 @@ async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse
@app.post('/unset-settings-tokens', response_model=dict[str, str])
async def unset_settings_tokens(request: Request) -> JSONResponse:
async def unset_settings_tokens(
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
if existing_settings:
settings = existing_settings.model_copy(
@@ -205,7 +207,7 @@ async def unset_settings_tokens(request: Request) -> JSONResponse:
@app.post('/reset-settings', response_model=dict[str, str])
async def reset_settings(request: Request) -> JSONResponse:
async def reset_settings() -> JSONResponse:
"""
Resets user settings. (Deprecated)
"""
@@ -218,7 +220,7 @@ async def reset_settings(request: Request) -> JSONResponse:
)
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
@@ -238,8 +240,9 @@ async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -
return ''
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
async def store_provider_tokens(
settings: POSTSettingsModel, settings_store: SettingsStore
):
existing_settings = await settings_store.load()
if existing_settings:
if settings.provider_tokens:
@@ -273,9 +276,8 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
async def store_llm_settings(
request: Request, settings: POSTSettingsModel
settings: POSTSettingsModel, settings_store: SettingsStore
) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
@@ -293,11 +295,11 @@ async def store_llm_settings(
@app.post('/settings', response_model=dict[str, str])
async def store_settings(
request: Request,
settings: POSTSettingsModel,
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
# Check provider tokens are valid
provider_err_msg = await check_provider_tokens(request, settings)
provider_err_msg = await check_provider_tokens(settings)
if provider_err_msg:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -305,14 +307,11 @@ async def store_settings(
)
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
settings = await store_llm_settings(request, settings)
settings = await store_llm_settings(settings, settings_store)
# Keep existing analytics consent if not provided
if settings.user_consents_to_analytics is None:
@@ -320,7 +319,7 @@ async def store_settings(
existing_settings.user_consents_to_analytics
)
settings = await store_provider_tokens(request, settings)
settings = await store_provider_tokens(settings, settings_store)
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:

View File

@@ -94,7 +94,10 @@ class Settings(BaseModel):
return {
'provider_tokens': secrets.provider_tokens_serializer(
secrets.provider_tokens, info
)
),
'custom_secrets': secrets.custom_secrets_serializer(
secrets.custom_secrets, info
),
}
@staticmethod

View File

@@ -0,0 +1,48 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.service_types import ProviderType
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import get_user_auth
from openhands.storage.settings.settings_store import SettingsStore
async def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
user_auth = await get_user_auth(request)
provider_tokens = await user_auth.get_provider_tokens()
return provider_tokens
async def get_access_token(request: Request) -> SecretStr | None:
user_auth = await get_user_auth(request)
access_token = await user_auth.get_access_token()
return access_token
async def get_user_id(request: Request) -> str | None:
user_auth = await get_user_auth(request)
user_id = await user_auth.get_user_id()
return user_id
async def get_github_user_id(request: Request) -> str | None:
provider_tokens = await get_provider_tokens(request)
if not provider_tokens:
return None
github_provider = provider_tokens.get(ProviderType.GITHUB)
if github_provider:
return github_provider.user_id
return None
async def get_user_settings(request: Request) -> Settings | None:
user_auth = await get_user_auth(request)
user_settings = await user_auth.get_user_settings()
return user_settings
async def get_user_settings_store(request: Request) -> SettingsStore | None:
user_auth = await get_user_auth(request)
user_settings_store = await user_auth.get_user_settings_store()
return user_settings_store

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server import shared
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.settings.settings_store import SettingsStore
@dataclass
class DefaultUserAuth(UserAuth):
"""Default user authentication mechanism"""
_settings: Settings | None = None
_settings_store: SettingsStore | None = None
async def get_user_id(self) -> str | None:
"""The default implementation does not support multi tenancy, so user_id is always None"""
return None
async def get_access_token(self) -> SecretStr | None:
"""The default implementation does not support multi tenancy, so access_token is always None"""
return None
async def get_user_settings_store(self):
settings_store = self._settings_store
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, user_id
)
self._settings_store = settings_store
return settings_store
async def get_user_settings(self) -> Settings | None:
settings = self._settings
if settings:
return settings
settings_store = await self.get_user_settings_store()
settings = await settings_store.load()
self._settings = settings
return settings
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
settings = await self.get_user_settings()
secrets_store = getattr(settings, 'secrets_store', None)
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
user_auth = DefaultUserAuth()
return user_auth

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server.settings import Settings
from openhands.server.shared import server_config
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
class UserAuth(ABC):
"""Extensible class encapsulating user Authentication"""
_settings: Settings | None
@abstractmethod
async def get_user_id(self) -> str | None:
"""Get the unique identifier for the current user"""
@abstractmethod
async def get_access_token(self) -> SecretStr | None:
"""Get the access token for the current user"""
@abstractmethod
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
"""Get the provider tokens for the current user."""
@abstractmethod
async def get_user_settings_store(self) -> SettingsStore | None:
"""Get the settings store for the current user."""
async def get_user_settings(self) -> Settings | None:
"""Get the user settings for the current user"""
settings = self._settings
if settings:
return settings
settings_store = await self.get_user_settings_store()
if settings_store is None:
return None
settings = await settings_store.load()
self._settings = settings
return settings
@classmethod
@abstractmethod
async def get_instance(cls, request: Request) -> UserAuth:
"""Get an instance of UserAuth from the request given"""
async def get_user_auth(request: Request) -> UserAuth:
user_auth = getattr(request.state, 'user_auth', None)
if user_auth:
return user_auth
impl_name = server_config.user_auth_class
impl = get_impl(UserAuth, impl_name)
user_auth = await impl.get_instance(request)
request.state.user_auth = user_auth
return user_auth

16
openhands/server/utils.py Normal file
View File

@@ -0,0 +1,16 @@
from fastapi import Request
from openhands.server.shared import ConversationStoreImpl, config
from openhands.server.user_auth import get_user_auth
from openhands.storage.conversation.conversation_store import ConversationStore
async def get_conversation_store(request: Request) -> ConversationStore | None:
conversation_store = getattr(request.state, 'conversation_store', None)
if conversation_store:
return conversation_store
user_auth = await get_user_auth(request)
user_id = await user_auth.get_user_id()
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
request.state.conversation_store = conversation_store
return conversation_store