Compare commits

...

2 Commits

Author SHA1 Message Date
openhands
f5f2846ef5 Move file-related endpoints to api_files.py for better modularity 2024-11-09 16:37:51 +00:00
openhands
004d10b116 Refactor listen.py into smaller modules for better maintainability 2024-11-09 16:33:01 +00:00
5 changed files with 202 additions and 458 deletions

View File

@@ -0,0 +1,49 @@
from fastapi import APIRouter, Request, HTTPException, status
from fastapi.responses import JSONResponse
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_sync_from_async
import os
router = APIRouter()
FILES_TO_IGNORE = [
'.git/',
'.DS_Store',
'node_modules/',
'__pycache__/',
]
@router.get('/list-files')
async def list_files(request: Request, path: str | None = None):
"""List files in the specified path.
This function retrieves a list of files from the agent's runtime file store,
excluding certain system and hidden files/directories.
To list files:
```sh
curl http://localhost:3000/api/list-files
```
Args:
request (Request): The incoming request object.
path (str, optional): The path to list files from. Defaults to None.
Returns:
list: A list of file names in the specified path.
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.conversation.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.conversation.runtime
file_list = await call_sync_from_async(runtime.list_files, path)
if path:
file_list = [os.path.join(path, f) for f in file_list]
return [f for f in file_list if not any(ignored in f for ignored in FILES_TO_IGNORE)]

View File

@@ -0,0 +1,56 @@
from fastapi import APIRouter
from openhands.llm import bedrock
from openhands.core.config import LLMConfig
from openhands.controller.agent import Agent
from openhands.security.options import SecurityAnalyzers
from openhands.core.logger import openhands_logger as logger
import requests
router = APIRouter()
@router.get('/models')
async def get_litellm_models(config) -> list[str]:
litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
litellm_model_list
)
llm_config: LLMConfig = config.get_llm_config()
bedrock_model_list = []
if (
llm_config.aws_region_name
and llm_config.aws_access_key_id
and llm_config.aws_secret_access_key
):
bedrock_model_list = bedrock.list_foundation_models(
llm_config.aws_region_name,
llm_config.aws_access_key_id,
llm_config.aws_secret_access_key,
)
model_list = litellm_model_list_without_bedrock + bedrock_model_list
for llm_config in config.llms.values():
ollama_base_url = llm_config.ollama_base_url
if llm_config.model.startswith('ollama'):
if not ollama_base_url:
ollama_base_url = llm_config.base_url
if ollama_base_url:
ollama_url = ollama_base_url.strip('/') + '/api/tags'
try:
ollama_models_list = requests.get(ollama_url, timeout=3).json()[
'models'
]
for model in ollama_models_list:
model_list.append('ollama/' + model['name'])
break
except requests.exceptions.RequestException as e:
logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
return list(sorted(set(model_list)))
@router.get('/agents')
async def get_agents():
agents = sorted(Agent.list_agents())
return agents
@router.get('/security-analyzers')
async def get_security_analyzers():
return sorted(SecurityAnalyzers.keys())

View File

@@ -0,0 +1,16 @@
import os
import warnings
from dotenv import load_dotenv
from openhands.core.config import load_app_config
from openhands.storage import get_file_store
from openhands.server.session import SessionManager
load_dotenv()
config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm

View File

@@ -1,473 +1,26 @@
import asyncio
import os
import re
import tempfile
import time
import uuid
import warnings
from contextlib import asynccontextmanager
import jwt
import requests
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from openhands.security.options import SecurityAnalyzers
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.github import (
GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET,
UserVerifier,
authenticate_github_user,
)
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
FastAPI,
HTTPException,
Request,
UploadFile,
WebSocket,
status,
)
from fastapi.responses import FileResponse, JSONResponse
from fastapi import FastAPI
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from openhands.server.config import config, session_manager
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware, attach_session
from openhands.server.websocket import websocket_endpoint
from openhands.server.api_options import router as api_options_router
from openhands.server.api_files import router as api_files_router
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, load_app_config
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
ChangeAgentStateAction,
FileReadAction,
FileWriteAction,
NullAction,
)
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
FileReadObservation,
FileWriteObservation,
NullObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth.auth import get_sid_from_token, sign_token
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
from openhands.server.session import SessionManager
load_dotenv()
config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
@asynccontextmanager
async def lifespan(app: FastAPI):
global session_manager
async with session_manager:
yield
app = FastAPI(lifespan=lifespan)
app = FastAPI()
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
app.add_middleware(NoCacheMiddleware)
app.middleware('http')(attach_session)
app.websocket('/ws')(websocket_endpoint)
app.include_router(api_options_router, prefix='/api/options')
app.include_router(api_files_router, prefix='/api')
security_scheme = HTTPBearer()
def load_file_upload_config() -> tuple[int, bool, list[str]]:
"""Load file upload configuration from the config object.
This function retrieves the file upload settings from the global config object.
It handles the following settings:
- Maximum file size for uploads
- Whether to restrict file types
- List of allowed file extensions
It also performs sanity checks on the values to ensure they are valid and safe.
Returns:
tuple: A tuple containing:
- max_file_size_mb (int): Maximum file size in MB. 0 means no limit.
- restrict_file_types (bool): Whether file type restrictions are enabled.
- allowed_extensions (set): Set of allowed file extensions.
"""
# Retrieve values from config
max_file_size_mb = config.file_uploads_max_file_size_mb
restrict_file_types = config.file_uploads_restrict_file_types
allowed_extensions = config.file_uploads_allowed_extensions
# Sanity check for max_file_size_mb
if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0:
logger.warning(
f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).'
)
max_file_size_mb = 0
# Sanity check for allowed_extensions
if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions:
logger.warning(
f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].'
)
allowed_extensions = ['.*']
else:
# Ensure all extensions start with a dot and are lowercase
allowed_extensions = [
ext.lower() if ext.startswith('.') else f'.{ext.lower()}'
for ext in allowed_extensions
]
# If restrictions are disabled, allow all
if not restrict_file_types:
allowed_extensions = ['.*']
logger.debug(
f'File upload config: max_size={max_file_size_mb}MB, '
f'restrict_types={restrict_file_types}, '
f'allowed_extensions={allowed_extensions}'
)
return max_file_size_mb, restrict_file_types, allowed_extensions
# Load configuration
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
def is_extension_allowed(filename):
"""Check if the file extension is allowed based on the current configuration.
This function supports wildcards and files without extensions.
The check is case-insensitive for extensions.
Args:
filename (str): The name of the file to check.
Returns:
bool: True if the file extension is allowed, False otherwise.
"""
if not RESTRICT_FILE_TYPES:
return True
file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase
return (
'.*' in ALLOWED_EXTENSIONS
or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS)
or (file_ext == '' and '.' in ALLOWED_EXTENSIONS)
)
@app.middleware('http')
async def attach_session(request: Request, call_next):
"""Middleware to attach session information to the request.
This middleware checks for the Authorization header, validates the token,
and attaches the corresponding session to the request state.
Args:
request (Request): The incoming request object.
call_next (Callable): The next middleware or route handler in the chain.
Returns:
Response: The response from the next middleware or route handler.
"""
non_authed_paths = [
'/api/options/',
'/api/github/callback',
'/api/authenticate',
]
if any(
request.url.path.startswith(path) for path in non_authed_paths
) or not request.url.path.startswith('/api/'):
response = await call_next(request)
return response
# Bypass authentication for OPTIONS requests (preflight)
if request.method == 'OPTIONS':
response = await call_next(request)
return response
user_verifier = UserVerifier()
if user_verifier.is_active():
signed_token = request.cookies.get('github_auth')
if not signed_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authenticated'},
)
try:
jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
except Exception as e:
logger.warning(f'Invalid token: {e}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Invalid token'},
)
if not request.headers.get('Authorization'):
logger.warning('Missing Authorization header')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Missing Authorization header'},
)
auth_token = request.headers.get('Authorization')
if 'Bearer' in auth_token:
auth_token = auth_token.split('Bearer')[1].strip()
request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
if request.state.sid == '':
logger.warning('Invalid token')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Invalid token'},
)
request.state.conversation = await session_manager.attach_to_conversation(
request.state.sid
)
if request.state.conversation is None:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
)
try:
response = await call_next(request)
finally:
await session_manager.detach_from_conversation(request.state.conversation)
return response
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
- Initialize the agent:
session management, and event streaming.
```json
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
Args:
```
websocket (WebSocket): The WebSocket connection object.
- Start a new development task:
```json
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
```
- Send a message:
```json
{"action": "message", "args": {"content": "Hello, how are you?", "images_urls": ["base64_url1", "base64_url2"]}}
```
- Write contents to a file:
```json
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
```
- Read the contents of a file:
```json
{"action": "read", "args": {"path": "./greetings.txt"}}
```
- Run a command:
```json
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
```
- Run an IPython command:
```json
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
```
- Open a web page:
```json
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
```
- Add a task to the root_task:
```json
{"action": "add_task", "args": {"task": "Implement feature X"}}
```
- Update a task in the root_task:
```json
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
```
- Change the agent's state:
```json
{"action": "change_agent_state", "args": {"state": "paused"}}
```
- Finish the task:
```json
{"action": "finish", "args": {}}
```
"""
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
await session.loop_recv()
@app.get('/api/options/models')
async def get_litellm_models() -> list[str]:
"""
Get all models supported by LiteLLM.
This function combines models from litellm and Bedrock, removing any
error-prone Bedrock models.
To get the models:
```sh
curl http://localhost:3000/api/litellm-models
```
Returns:
list: A sorted list of unique model names.
"""
litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
litellm_model_list
)
# TODO: for bedrock, this is using the default config
llm_config: LLMConfig = config.get_llm_config()
bedrock_model_list = []
if (
llm_config.aws_region_name
and llm_config.aws_access_key_id
and llm_config.aws_secret_access_key
):
bedrock_model_list = bedrock.list_foundation_models(
llm_config.aws_region_name,
llm_config.aws_access_key_id,
llm_config.aws_secret_access_key,
)
model_list = litellm_model_list_without_bedrock + bedrock_model_list
for llm_config in config.llms.values():
ollama_base_url = llm_config.ollama_base_url
if llm_config.model.startswith('ollama'):
if not ollama_base_url:
ollama_base_url = llm_config.base_url
if ollama_base_url:
ollama_url = ollama_base_url.strip('/') + '/api/tags'
try:
ollama_models_list = requests.get(ollama_url, timeout=3).json()[
'models'
]
for model in ollama_models_list:
model_list.append('ollama/' + model['name'])
break
except requests.exceptions.RequestException as e:
logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
return list(sorted(set(model_list)))
@app.get('/api/options/agents')
async def get_agents():
"""Get all agents supported by LiteLLM.
To get the agents:
```sh
curl http://localhost:3000/api/agents
```
Returns:
list: A sorted list of agent names.
"""
agents = sorted(Agent.list_agents())
return agents
@app.get('/api/options/security-analyzers')
async def get_security_analyzers():
"""Get all supported security analyzers.
To get the security analyzers:
```sh
curl http://localhost:3000/api/security-analyzers
```
Returns:
list: A sorted list of security analyzer names.
"""
return sorted(SecurityAnalyzers.keys())
FILES_TO_IGNORE = [
'.git/',
'.DS_Store',
'node_modules/',
'__pycache__/',
]
@app.get('/api/list-files')
async def list_files(request: Request, path: str | None = None):
"""List files in the specified path.

View File

@@ -0,0 +1,70 @@
import asyncio
import uuid
from fastapi import WebSocket, status
from openhands.server.auth.auth import get_sid_from_token, sign_token
from openhands.server.github import authenticate_github_user
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.events.serialization import event_to_dict
from openhands.events.action import (
ChangeAgentStateAction,
NullAction,
)
from openhands.events.observation import (
AgentStateChangedObservation,
NullObservation,
)
from openhands.core.logger import openhands_logger as logger
async def websocket_endpoint(websocket: WebSocket, session_manager, config):
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
await session.loop_recv()