mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Refactor Authentication (#8040)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
This commit is contained in:
@@ -1,34 +0,0 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
|
||||
|
||||
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get GitHub token from request state. For backward compatibility."""
|
||||
return getattr(request.state, 'provider_tokens', None)
|
||||
|
||||
|
||||
def get_access_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'access_token', None)
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'user_id', None)
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].user_id
|
||||
|
||||
return None
|
||||
@@ -20,6 +20,7 @@ class ServerConfig(ServerConfigInterface):
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
monitoring_listener_class: str = 'openhands.server.monitoring.MonitoringListener'
|
||||
user_auth_class: str = 'openhands.server.user_auth.default_user_auth.DefaultUserAuth'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
|
||||
@@ -9,7 +9,6 @@ from openhands.server.middleware import (
|
||||
CacheControlMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
ProviderTokenMiddleware,
|
||||
RateLimitMiddleware,
|
||||
)
|
||||
from openhands.server.static import SPAStaticFiles
|
||||
@@ -32,6 +31,5 @@ base_app.add_middleware(
|
||||
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
|
||||
)
|
||||
base_app.middleware('http')(AttachConversationMiddleware(base_app))
|
||||
base_app.middleware('http')(ProviderTokenMiddleware(base_app))
|
||||
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
|
||||
|
||||
@@ -12,8 +12,8 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from openhands.server import shared
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.types import SessionMiddlewareInterface
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
@@ -147,9 +147,10 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
"""
|
||||
Attach the user's session based on the provided authentication token.
|
||||
"""
|
||||
user_id = await get_user_id(request)
|
||||
request.state.conversation = (
|
||||
await shared.conversation_manager.attach_to_conversation(
|
||||
request.state.sid, get_user_id(request)
|
||||
request.state.sid, user_id
|
||||
)
|
||||
)
|
||||
if not request.state.conversation:
|
||||
@@ -183,27 +184,3 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
await self._detach_session(request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ProviderTokenMiddleware(SessionMiddlewareInterface):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, get_user_id(request)
|
||||
)
|
||||
settings = await settings_store.load()
|
||||
|
||||
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
|
||||
if getattr(request.state, 'provider_tokens', None) is None:
|
||||
if (
|
||||
settings
|
||||
and settings.secrets_store
|
||||
and settings.secrets_store.provider_tokens
|
||||
):
|
||||
request.state.provider_tokens = settings.secrets_store.provider_tokens
|
||||
else:
|
||||
request.state.provider_tokens = None
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
@@ -21,7 +22,6 @@ from openhands.events.observation import (
|
||||
FileReadObservation,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.auth import get_github_user_id, get_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.file_config import (
|
||||
FILES_TO_IGNORE,
|
||||
@@ -31,6 +31,8 @@ from openhands.server.shared import (
|
||||
config,
|
||||
conversation_manager,
|
||||
)
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
@@ -187,10 +189,15 @@ def zip_current_workspace(request: Request):
|
||||
|
||||
|
||||
@app.get('/git/changes')
|
||||
async def git_changes(request: Request, conversation_id: str):
|
||||
async def git_changes(
|
||||
request: Request,
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config,
|
||||
user_id,
|
||||
)
|
||||
|
||||
cwd = await get_cwd(
|
||||
@@ -223,11 +230,13 @@ async def git_changes(request: Request, conversation_id: str):
|
||||
|
||||
|
||||
@app.get('/git/diff')
|
||||
async def git_diff(request: Request, path: str, conversation_id: str):
|
||||
async def git_diff(
|
||||
request: Request,
|
||||
path: str,
|
||||
conversation_id: str,
|
||||
conversation_store = Depends(get_conversation_store),
|
||||
):
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
|
||||
cwd = await get_cwd(
|
||||
conversation_store,
|
||||
|
||||
@@ -15,8 +15,12 @@ from openhands.integrations.service_types import (
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_provider_tokens, get_user_id
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.user_auth import (
|
||||
get_access_token,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
|
||||
app = APIRouter(prefix='/api/user')
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Body, Request, status
|
||||
from fastapi import APIRouter, Body, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -15,11 +15,6 @@ from openhands.integrations.provider import (
|
||||
)
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import (
|
||||
get_github_user_id,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -33,6 +28,12 @@ from openhands.server.shared import (
|
||||
file_store,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@@ -95,7 +96,7 @@ async def _create_new_conversation(
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@@ -152,14 +153,17 @@ async def _create_new_conversation(
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
async def new_conversation(
|
||||
data: InitSessionRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
):
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
selected_repository = data.selected_repository
|
||||
selected_branch = data.selected_branch
|
||||
initial_user_msg = data.initial_user_msg
|
||||
@@ -169,7 +173,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
get_user_id(request),
|
||||
user_id,
|
||||
provider_tokens,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
@@ -204,13 +208,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
|
||||
@app.get('/conversations')
|
||||
async def search_conversations(
|
||||
request: Request,
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
|
||||
# Filter out conversations older than max_age
|
||||
@@ -228,7 +230,7 @@ async def search_conversations(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
get_user_id(request), set(conversation_ids)
|
||||
user_id, set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
@@ -245,11 +247,9 @@ async def search_conversations(
|
||||
|
||||
@app.get('/conversations/{conversation_id}')
|
||||
async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
conversation_id: str,
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationInfo | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
@@ -340,11 +340,12 @@ async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
conversation_id: str,
|
||||
title: str = Body(embed=True),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> bool:
|
||||
user_id = get_user_id(request)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id, get_github_user_id(request)
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
@@ -366,10 +367,10 @@ async def update_conversation(
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config, user_id
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from fastapi import APIRouter, Request, status
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderToken,
|
||||
ProviderType,
|
||||
SecretStore,
|
||||
)
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.auth import get_provider_tokens, get_user_id
|
||||
from openhands.server.settings import (
|
||||
GETSettingsCustomSecrets,
|
||||
GETSettingsModel,
|
||||
@@ -15,16 +19,24 @@ from openhands.server.settings import (
|
||||
)
|
||||
from openhands.server.shared import SettingsStoreImpl, config, server_config
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_settings,
|
||||
get_user_settings_store,
|
||||
)
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
async def load_settings(
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
settings: Settings | None = Depends(get_user_settings),
|
||||
) -> GETSettingsModel | JSONResponse:
|
||||
try:
|
||||
user_id = get_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings: Settings = await settings_store.load()
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -36,7 +48,6 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
if bool(user_id):
|
||||
provider_tokens_set[ProviderType.GITHUB.value] = True
|
||||
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens:
|
||||
all_provider_types = [provider.value for provider in ProviderType]
|
||||
provider_tokens_types = [provider.value for provider in provider_tokens]
|
||||
@@ -63,12 +74,9 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
|
||||
@app.get('/secrets', response_model=GETSettingsCustomSecrets)
|
||||
async def load_custom_secrets_names(
|
||||
request: Request,
|
||||
settings: Settings | None = Depends(get_user_settings),
|
||||
) -> GETSettingsCustomSecrets | JSONResponse:
|
||||
try:
|
||||
user_id = get_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -93,13 +101,11 @@ async def load_custom_secrets_names(
|
||||
|
||||
@app.post('/secrets', response_model=dict[str, str])
|
||||
async def add_custom_secret(
|
||||
request: Request, incoming_secrets: POSTSettingsCustomSecrets
|
||||
incoming_secrets: POSTSettingsCustomSecrets,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings: Settings = await settings_store.load()
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
for (
|
||||
secret_name,
|
||||
@@ -121,7 +127,6 @@ async def add_custom_secret(
|
||||
update={'secrets_store': updated_secret_store}
|
||||
)
|
||||
|
||||
updated_settings = convert_to_settings(updated_settings)
|
||||
await settings_store.store(updated_settings)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -137,11 +142,11 @@ async def add_custom_secret(
|
||||
|
||||
|
||||
@app.delete('/secrets/{secret_id}')
|
||||
async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse:
|
||||
async def delete_custom_secret(
|
||||
secret_id: str,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings: Settings | None = await settings_store.load()
|
||||
custom_secrets = {}
|
||||
if existing_settings:
|
||||
@@ -162,7 +167,6 @@ async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse
|
||||
update={'secrets_store': updated_secret_store}
|
||||
)
|
||||
|
||||
updated_settings = convert_to_settings(updated_settings)
|
||||
await settings_store.store(updated_settings)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -178,12 +182,10 @@ async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse
|
||||
|
||||
|
||||
@app.post('/unset-settings-tokens', response_model=dict[str, str])
|
||||
async def unset_settings_tokens(request: Request) -> JSONResponse:
|
||||
async def unset_settings_tokens(
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
settings = existing_settings.model_copy(
|
||||
@@ -205,7 +207,7 @@ async def unset_settings_tokens(request: Request) -> JSONResponse:
|
||||
|
||||
|
||||
@app.post('/reset-settings', response_model=dict[str, str])
|
||||
async def reset_settings(request: Request) -> JSONResponse:
|
||||
async def reset_settings() -> JSONResponse:
|
||||
"""
|
||||
Resets user settings. (Deprecated)
|
||||
"""
|
||||
@@ -218,7 +220,7 @@ async def reset_settings(request: Request) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
|
||||
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
@@ -238,8 +240,9 @@ async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -
|
||||
return ''
|
||||
|
||||
|
||||
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
async def store_provider_tokens(
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
):
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
@@ -273,9 +276,8 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
|
||||
|
||||
async def store_llm_settings(
|
||||
request: Request, settings: POSTSettingsModel
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
@@ -293,11 +295,11 @@ async def store_llm_settings(
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
provider_err_msg = await check_provider_tokens(request, settings)
|
||||
provider_err_msg = await check_provider_tokens(settings)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -305,14 +307,11 @@ async def store_settings(
|
||||
)
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
settings = await store_llm_settings(request, settings)
|
||||
settings = await store_llm_settings(settings, settings_store)
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
@@ -320,7 +319,7 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
settings = await store_provider_tokens(request, settings)
|
||||
settings = await store_provider_tokens(settings, settings_store)
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@@ -94,7 +94,10 @@ class Settings(BaseModel):
|
||||
return {
|
||||
'provider_tokens': secrets.provider_tokens_serializer(
|
||||
secrets.provider_tokens, info
|
||||
)
|
||||
),
|
||||
'custom_secrets': secrets.custom_secrets_serializer(
|
||||
secrets.custom_secrets, info
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
48
openhands/server/user_auth/__init__.py
Normal file
48
openhands/server/user_auth/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
async def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
return provider_tokens
|
||||
|
||||
|
||||
async def get_access_token(request: Request) -> SecretStr | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
access_token = await user_auth.get_access_token()
|
||||
return access_token
|
||||
|
||||
|
||||
async def get_user_id(request: Request) -> str | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_id = await user_auth.get_user_id()
|
||||
return user_id
|
||||
|
||||
|
||||
async def get_github_user_id(request: Request) -> str | None:
|
||||
provider_tokens = await get_provider_tokens(request)
|
||||
if not provider_tokens:
|
||||
return None
|
||||
github_provider = provider_tokens.get(ProviderType.GITHUB)
|
||||
if github_provider:
|
||||
return github_provider.user_id
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_settings(request: Request) -> Settings | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_settings = await user_auth.get_user_settings()
|
||||
return user_settings
|
||||
|
||||
|
||||
async def get_user_settings_store(request: Request) -> SettingsStore | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_settings_store = await user_auth.get_user_settings_store()
|
||||
return user_settings_store
|
||||
57
openhands/server/user_auth/default_user_auth.py
Normal file
57
openhands/server/user_auth/default_user_auth.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server import shared
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultUserAuth(UserAuth):
|
||||
"""Default user authentication mechanism"""
|
||||
|
||||
_settings: Settings | None = None
|
||||
_settings_store: SettingsStore | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
"""The default implementation does not support multi tenancy, so user_id is always None"""
|
||||
return None
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
"""The default implementation does not support multi tenancy, so access_token is always None"""
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self):
|
||||
settings_store = self._settings_store
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, user_id
|
||||
)
|
||||
self._settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
async def get_user_settings(self) -> Settings | None:
|
||||
settings = self._settings
|
||||
if settings:
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
settings = await self.get_user_settings()
|
||||
secrets_store = getattr(settings, 'secrets_store', None)
|
||||
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
|
||||
return provider_tokens
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
user_auth = DefaultUserAuth()
|
||||
return user_auth
|
||||
63
openhands/server/user_auth/user_auth.py
Normal file
63
openhands/server/user_auth/user_auth.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class UserAuth(ABC):
|
||||
"""Extensible class encapsulating user Authentication"""
|
||||
|
||||
_settings: Settings | None
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_id(self) -> str | None:
|
||||
"""Get the unique identifier for the current user"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
"""Get the access token for the current user"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get the provider tokens for the current user."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
"""Get the settings store for the current user."""
|
||||
|
||||
async def get_user_settings(self) -> Settings | None:
|
||||
"""Get the user settings for the current user"""
|
||||
settings = self._settings
|
||||
if settings:
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
if settings_store is None:
|
||||
return None
|
||||
settings = await settings_store.load()
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
"""Get an instance of UserAuth from the request given"""
|
||||
|
||||
|
||||
async def get_user_auth(request: Request) -> UserAuth:
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
if user_auth:
|
||||
return user_auth
|
||||
impl_name = server_config.user_auth_class
|
||||
impl = get_impl(UserAuth, impl_name)
|
||||
user_auth = await impl.get_instance(request)
|
||||
request.state.user_auth = user_auth
|
||||
return user_auth
|
||||
16
openhands/server/utils.py
Normal file
16
openhands/server/utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from fastapi import Request
|
||||
|
||||
from openhands.server.shared import ConversationStoreImpl, config
|
||||
from openhands.server.user_auth import get_user_auth
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
|
||||
|
||||
async def get_conversation_store(request: Request) -> ConversationStore | None:
|
||||
conversation_store = getattr(request.state, 'conversation_store', None)
|
||||
if conversation_store:
|
||||
return conversation_store
|
||||
user_auth = await get_user_auth(request)
|
||||
user_id = await user_auth.get_user_id()
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
request.state.conversation_store = conversation_store
|
||||
return conversation_store
|
||||
Reference in New Issue
Block a user