Compare commits

...

6 Commits

Author SHA1 Message Date
openhands
eed35ba34b Move rate limiter to middleware.py and use in-memory store
- Move rate limiter classes to middleware.py
- Use in-memory store instead of Redis
- Keep same rate limits:
  * Default: 2 req/sec
  * Static files: 10 req/sec
  * WebSocket and authenticate endpoints: 1 req/5s
2024-11-12 21:32:03 +00:00
openhands
9ec47bee73 Move rate limiter to middleware.py and use in-memory store
- Move rate limiter classes to middleware.py
- Use in-memory store instead of Redis
- Keep same rate limits:
  * Default: 2 req/sec
  * Static files: 10 req/sec
  * WebSocket and authenticate endpoints: 1 req/5s
2024-11-12 21:27:12 +00:00
openhands
17eedeeba9 Switch to in-memory rate limiter implementation
- Remove Redis dependency
- Implement custom in-memory rate limiter with thread-safe storage
- Keep same rate limits:
  * Default: 2 req/sec
  * Static files: 10 req/sec
  * WebSocket and authenticate endpoints: 1 req/5s
2024-11-12 21:22:32 +00:00
openhands
253e19c66c Add type ignore for redis import 2024-11-12 21:19:34 +00:00
openhands
f6742c5af7 Apply code formatting changes 2024-11-12 21:19:09 +00:00
openhands
ce9963db01 Add rate limiting to FastAPI server
- Default rate limit: 2 req/sec
- Static files: 10 req/sec
- WebSocket endpoint: 1 req/5s
- Authenticate endpoint: 1 req/5s

Uses fastapi-limiter with Redis backend for rate limiting implementation.
2024-11-12 21:13:19 +00:00
2 changed files with 165 additions and 426 deletions

View File

@@ -1,70 +1,36 @@
import asyncio
import os
import re
import tempfile
import time
import uuid
import warnings
import jwt
import requests
from dotenv import load_dotenv
from fastapi import (
Depends,
FastAPI,
Request,
WebSocket,
status,
)
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from openhands.security.options import SecurityAnalyzers
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.github import (
GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET,
UserVerifier,
authenticate_github_user,
)
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
FastAPI,
HTTPException,
Request,
UploadFile,
WebSocket,
status,
from openhands.security.options import SecurityAnalyzers
from openhands.server.github import (
UserVerifier,
authenticate_github_user,
)
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, load_app_config
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
ChangeAgentStateAction,
FileReadAction,
FileWriteAction,
NullAction,
)
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
FileReadObservation,
FileWriteObservation,
NullObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth.auth import get_sid_from_token, sign_token
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
from openhands.server.session import SessionManager
from openhands.server.middleware import RateLimiter
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
load_dotenv()
@@ -73,7 +39,9 @@ file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
app = FastAPI()
# Application instance with an app-wide dependency intended as the default
# rate limit.
# NOTE(review): the dependency callable below only constructs a RateLimiter and
# returns it; nothing here awaits that instance with (scope, receive, send), so
# this statement alone does not appear to enforce 2 req/sec — confirm and wire
# the limiter in explicitly if enforcement is expected.
app = FastAPI(
dependencies=[Depends(lambda: RateLimiter(times=2, seconds=1))]
) # Default 2 req/sec
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
@@ -81,7 +49,6 @@ app.add_middleware(
allow_headers=['*'],
)
app.add_middleware(NoCacheMiddleware)
security_scheme = HTTPBearer()
@@ -251,6 +218,7 @@ async def attach_session(request: Request, call_next):
@app.websocket('/ws')
@RateLimiter(times=1, seconds=5) # 1 request per 5 seconds
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
@@ -494,393 +462,31 @@ async def list_files(request: Request, path: str | None = None):
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
async def filter_for_gitignore(file_list, base_path):
gitignore_path = os.path.join(base_path, '.gitignore')
try:
read_action = FileReadAction(gitignore_path)
observation = await call_sync_from_async(runtime.run_action, read_action)
spec = PathSpec.from_lines(
GitWildMatchPattern, observation.content.splitlines()
)
except Exception as e:
logger.warning(e)
return file_list
file_list = [entry for entry in file_list if not spec.match_file(entry)]
return file_list
file_list = await filter_for_gitignore(file_list, '')
# Create a PathSpec object to match gitignore patterns
gitignore_path = os.path.join(runtime.root_dir, '.gitignore')
if os.path.exists(gitignore_path):
with open(gitignore_path, 'r') as f:
gitignore = f.read()
spec = PathSpec.from_lines(GitWildMatchPattern, gitignore.splitlines())
file_list = [f for f in file_list if not spec.match_file(f)]
return file_list
@app.get('/api/select-file')
async def select_file(file: str, request: Request):
    """Return the content of one file read through the conversation runtime.

    Example:
    ```sh
    curl http://localhost:3000/api/select-file?file=<file_path>
    ```

    Args:
        file (str): Path of the file to read. It is joined onto the runtime's
            ``workspace_mount_path_in_sandbox``.
        request (Request): The incoming request; the runtime is taken from
            ``request.state.conversation``.

    Returns:
        dict | JSONResponse | None: ``{'code': <content>}`` when the read
        succeeds, a 500 JSON error when the runtime reports an
        ``ErrorObservation``, and ``None`` for any other observation type.
    """
    runtime: Runtime = request.state.conversation.runtime
    workspace_root = runtime.config.workspace_mount_path_in_sandbox
    full_path = os.path.join(workspace_root, file)

    # The runtime call is synchronous, so it is run off the event loop.
    result = await call_sync_from_async(runtime.run_action, FileReadAction(full_path))

    if isinstance(result, FileReadObservation):
        return {'code': result.content}

    if isinstance(result, ErrorObservation):
        logger.error(f'Error opening file {full_path}: {result}', exc_info=False)
        error_body = {'error': f'Error opening file: {result}'}
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=error_body,
        )

    # Any other observation type falls through without a payload.
    return None
def sanitize_filename(filename):
    """Reduce a client-supplied filename to a safe single path component.

    Steps:
      1. Treat ``None`` as an empty name.
      2. Treat ``\\`` as a path separator too, so Windows-style client paths
         are reduced to their last component instead of being glued together.
      3. Keep only the final path component.
      4. Remove every character that is not a word character, ``-``, ``_``
         or ``.``.
      5. Reject names made only of dots (``.``, ``..``), which are directory
         references rather than file names.
      6. Cap the result at 255 characters, preserving the extension when it
         fits.

    Args:
        filename: The untrusted filename (``str`` or ``None``).

    Returns:
        str: The sanitized name, or ``''`` when nothing usable remains.
    """
    max_length = 255

    if not filename:
        return ''

    # Remove any directory components, whichever separator the client used.
    filename = os.path.basename(filename.replace('\\', '/'))

    # Remove any non-alphanumeric characters except for .-_
    filename = re.sub(r'[^\w\-_\.]', '', filename)

    # '.' and '..' survive the filter above but would point at a directory.
    if not filename.strip('.'):
        return ''

    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        # Clamp at zero so an over-long extension cannot produce a negative
        # slice, then hard-cap in case the extension alone exceeds the limit.
        keep = max(0, max_length - len(ext))
        filename = (name[:keep] + ext)[:max_length]

    return filename
@app.post('/api/upload-files')
async def upload_file(request: Request, files: list[UploadFile]):
    """Upload a list of files to the workspace.

    To upload files:
    ```sh
    curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/upload-files
    ```

    A file is skipped (and reported) rather than failing the request when it
    exceeds ``MAX_FILE_SIZE_MB`` (only checked when that limit is positive) or
    when ``is_extension_allowed`` rejects its sanitized name.

    Args:
        request (Request): The incoming request object; the runtime is taken
            from ``request.state.conversation``.
        files (list[UploadFile]): A list of files to be uploaded.

    Returns:
        JSONResponse: 200 with ``uploaded_files`` / ``skipped_files``; 400 when
        files were skipped and none was uploaded; 500 with empty lists when an
        unexpected error occurs.
    """
    try:
        uploaded_files = []
        skipped_files = []

        for file in files:
            # The client may omit the filename; sanitize an empty string then.
            safe_filename = sanitize_filename(file.filename or '')
            file_contents = await file.read()

            if (
                MAX_FILE_SIZE_MB > 0
                and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
            ):
                skipped_files.append(
                    {
                        'name': safe_filename,
                        'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
                    }
                )
                continue

            if not is_extension_allowed(safe_filename):
                skipped_files.append(
                    {'name': safe_filename, 'reason': 'File type not allowed'}
                )
                continue

            # Stage the upload in a temporary directory, then copy it into the
            # runtime. The directory is removed once the copy has finished.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file_path = os.path.join(tmp_dir, safe_filename)
                with open(tmp_file_path, 'wb') as tmp_file:
                    tmp_file.write(file_contents)

                runtime: Runtime = request.state.conversation.runtime
                # copy_to is a blocking call; run it off the event loop like
                # the other runtime calls in this module.
                await call_sync_from_async(
                    runtime.copy_to,
                    tmp_file_path,
                    runtime.config.workspace_mount_path_in_sandbox,
                )
            uploaded_files.append(safe_filename)

        response_content = {
            'message': 'File upload process completed',
            'uploaded_files': uploaded_files,
            'skipped_files': skipped_files,
        }

        if not uploaded_files and skipped_files:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    **response_content,
                    'error': 'No files were uploaded successfully',
                },
            )

        return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)

    except Exception as e:
        logger.error(f'Error during file upload: {e}', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                'error': f'Error during file upload: {str(e)}',
                'uploaded_files': [],
                'skipped_files': [],
            },
        )
@app.post('/api/submit-feedback')
async def submit_feedback(request: Request):
    """Store user feedback together with the conversation trajectory.

    The JSON request body may carry ``email``, ``version``, ``permissions``
    and ``polarity``. The trajectory is collected from the conversation's
    event stream with hidden events filtered out.

    Args:
        request (Request): The incoming request object.

    Returns:
        JSONResponse: 200 with whatever ``store_feedback`` returns, or 500 with
        ``{'error': 'Failed to submit feedback'}`` when storing fails.
    """
    payload = await request.json()

    events = AsyncEventStreamWrapper(
        request.state.conversation.event_stream, filter_hidden=True
    )
    trajectory = [event_to_dict(event) async for event in events]

    # NOTE(review): `feedback` is filled from the 'polarity' key, same as
    # `polarity`. This may be a compatibility shim or a copy-paste slip —
    # confirm against FeedbackDataModel before changing it.
    feedback = FeedbackDataModel(
        email=payload.get('email', ''),
        version=payload.get('version', ''),
        permissions=payload.get('permissions', 'private'),
        polarity=payload.get('polarity', ''),
        feedback=payload.get('polarity', ''),
        trajectory=trajectory,
    )

    try:
        # store_feedback is synchronous; run it off the event loop.
        stored = await call_sync_from_async(store_feedback, feedback)
    except Exception as e:
        logger.error(f'Error submitting feedback: {e}')
        return JSONResponse(
            status_code=500, content={'error': 'Failed to submit feedback'}
        )
    return JSONResponse(status_code=200, content=stored)
@app.get('/api/defaults')
async def appconfig_defaults():
    """Return the application's default configuration.

    Example:
    ```sh
    curl http://localhost:3000/api/defaults
    ```

    Returns:
        The value of ``config.defaults_dict``.
    """
    defaults = config.defaults_dict
    return defaults
@app.post('/api/save-file')
async def save_file(request: Request):
    """Write a file into the conversation runtime's workspace.

    The JSON body must contain ``filePath`` (joined onto the runtime's
    ``workspace_mount_path_in_sandbox``) and ``content``. This handler does
    not check the agent's state before writing.

    Args:
        request (Request): The incoming FastAPI request object.

    Returns:
        JSONResponse: 200 when the runtime reports a ``FileWriteObservation``;
        500 with an error payload when it reports an ``ErrorObservation`` or
        any other observation.

    Raises:
        HTTPException:
            - 400 if ``filePath`` is missing/empty or ``content`` is absent.
            - 500 if an unexpected error occurs during the save operation.
    """
    try:
        # Extract file path and content from the request
        data = await request.json()
        file_path = data.get('filePath')
        content = data.get('content')

        # An empty string is valid content, so only a missing value is rejected.
        if not file_path or content is None:
            raise HTTPException(status_code=400, detail='Missing filePath or content')

        # Save the file to the agent's runtime file store
        runtime: Runtime = request.state.conversation.runtime
        file_path = os.path.join(
            runtime.config.workspace_mount_path_in_sandbox, file_path
        )
        write_action = FileWriteAction(file_path, content)
        observation = await call_sync_from_async(runtime.run_action, write_action)

        if isinstance(observation, FileWriteObservation):
            return JSONResponse(
                status_code=200, content={'message': 'File saved successfully'}
            )
        elif isinstance(observation, ErrorObservation):
            return JSONResponse(
                status_code=500,
                content={'error': f'Failed to save file: {observation}'},
            )
        else:
            return JSONResponse(
                status_code=500,
                content={'error': f'Unexpected observation: {observation}'},
            )
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client as
        # raised; without this clause the generic handler below would turn
        # them into a 500.
        raise
    except Exception as e:
        # Log the error and return a 500 response
        logger.error(f'Error saving file: {e}', exc_info=True)
        raise HTTPException(status_code=500, detail=f'Error saving file: {e}') from e
@app.route('/api/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def security_api(request: Request):
    """Forward any ``/api/security/...`` request to the security analyzer.

    Args:
        request (Request): The incoming FastAPI request object.

    Returns:
        Any: Whatever the analyzer's ``handle_api_request`` returns.

    Raises:
        HTTPException: 404 when the conversation has no security analyzer.
    """
    analyzer = request.state.conversation.security_analyzer
    if analyzer:
        return await analyzer.handle_api_request(request)
    raise HTTPException(status_code=404, detail='Security analyzer not initialized')
@app.get('/api/zip-directory')
async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
    """Send the runtime workspace to the client as ``workspace.zip``.

    Args:
        request (Request): The incoming request; the runtime is taken from
            ``request.state.conversation``.
        background_tasks (BackgroundTasks): Used to delete the archive after
            the response has been sent.

    Returns:
        FileResponse: The zipped workspace.

    Raises:
        HTTPException: 500 when anything fails while producing the archive.
    """
    try:
        logger.debug('Zipping workspace')
        runtime: Runtime = request.state.conversation.runtime
        workspace_path = runtime.config.workspace_mount_path_in_sandbox
        # copy_from is synchronous; run it off the event loop.
        archive = await call_sync_from_async(runtime.copy_from, workspace_path)

        # Unlink only after the response is sent, so the file still exists
        # while it is being streamed.
        background_tasks.add_task(archive.unlink)
        return FileResponse(
            path=archive,
            filename='workspace.zip',
            media_type='application/x-zip-compressed',
        )
    except Exception as e:
        logger.error(f'Error zipping workspace: {e}', exc_info=True)
        raise HTTPException(
            status_code=500,
            detail='Failed to zip workspace',
        )
# Request body model for the GitHub OAuth callback endpoint.
class AuthCode(BaseModel):
    # Authorization code; github_callback forwards it to GitHub as the 'code'
    # field of the token-exchange request.
    code: str
@app.post('/api/github/callback')
def github_callback(auth_code: AuthCode):
    """Exchange a GitHub OAuth authorization code for an access token.

    Posts the client id, client secret and the received code to GitHub's
    token endpoint and relays the resulting access token.

    Args:
        auth_code (AuthCode): Body carrying the authorization code.

    Returns:
        JSONResponse: 200 with ``{'access_token': ...}`` on success; 400 when
        GitHub answers with a non-200 status or without an access token; 502
        when GitHub cannot be reached, times out, or returns a non-JSON body.
    """
    # Prepare data for the token exchange request
    data = {
        'client_id': GITHUB_CLIENT_ID,
        'client_secret': GITHUB_CLIENT_SECRET,
        'code': auth_code.code,
    }

    logger.debug('Exchanging code for GitHub token')

    headers = {'Accept': 'application/json'}
    try:
        # Without a timeout a stalled connection would hold this worker
        # indefinitely.
        response = requests.post(
            'https://github.com/login/oauth/access_token',
            data=data,
            headers=headers,
            timeout=10,
        )
    except requests.RequestException as e:
        logger.error(f'Failed to exchange code for token: {e}')
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={'error': 'Failed to exchange code for token'},
        )

    if response.status_code != 200:
        logger.error(f'Failed to exchange code for token: {response.text}')
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'error': 'Failed to exchange code for token'},
        )

    try:
        token_response = response.json()
    except ValueError:
        # A 200 with a non-JSON body is an upstream problem, not a crash.
        logger.error(f'Failed to exchange code for token: {response.text}')
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={'error': 'Failed to exchange code for token'},
        )

    if 'access_token' not in token_response:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'error': 'No access token in response'},
        )

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={'access_token': token_response['access_token']},
    )
@app.post('/api/authenticate')
@RateLimiter(times=1, seconds=5) # 1 request per 5 seconds
async def authenticate(request: Request):
token = request.headers.get('X-GitHub-Token')
if not await authenticate_github_user(token):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authorized via GitHub waitlist'},
content={'error': 'Invalid token'},
)
# Create a signed JWT token with 1-hour expiration
cookie_data = {
'github_token': token,
'exp': int(time.time()) + 3600, # 1 hour expiration
}
signed_token = sign_token(cookie_data, config.jwt_secret)
signed_token = sign_token({'token': token}, config.jwt_secret)
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
)
# Set secure cookie with signed token
response.set_cookie(
key='github_auth',
value=signed_token,
@@ -893,6 +499,12 @@ async def authenticate(request: Request):
class SPAStaticFiles(StaticFiles):
"""Static files handler with rate limiting."""
def __init__(self, *args, **kwargs):
    """Initialize the static-files app and attach its own rate limiter.

    All positional and keyword arguments are passed through unchanged to the
    parent class initializer.
    """
    super().__init__(*args, **kwargs)
    # Limiter configured for this handler; it is invoked from __call__.
    self.limiter = RateLimiter(times=10, seconds=1) # 10 requests per second
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
@@ -900,5 +512,11 @@ class SPAStaticFiles(StaticFiles):
# FIXME: just making this HTTPException doesn't work for some reason
return await super().get_response('index.html', scope)
async def __call__(self, scope, receive, send) -> None:
    """Serve a static file, unless the rate limiter rejects the request.

    For HTTP scopes the limiter runs first. The limiter answers a rejected
    request itself by sending a response, so this method watches the messages
    it sends: once a response has been started, the request is finished and
    must not also be passed to the static-files app, which would try to send
    a second response on the same connection.

    Args:
        scope: The ASGI scope dictionary.
        receive: The ASGI receive callable.
        send: The ASGI send callable.
    """
    if scope['type'] == 'http':
        limited = False

        async def tracking_send(message):
            # Forward everything, remembering whether a response was started.
            nonlocal limited
            if message.get('type') == 'http.response.start':
                limited = True
            await send(message)

        # Apply rate limiting
        await self.limiter(scope, receive, tracking_send)
        if limited:
            return
    return await super().__call__(scope, receive, send)
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')

View File

@@ -1,3 +1,7 @@
import asyncio
import collections
import time
from typing import Callable, Dict, Optional
from urllib.parse import urlparse
from fastapi.middleware.cors import CORSMiddleware
@@ -41,3 +45,120 @@ class NoCacheMiddleware(BaseHTTPMiddleware):
response.headers['Pragma'] = 'no-cache'
response.headers['Expires'] = '0'
return response
class InMemoryStore:
    """In-memory timestamp store for sliding-window rate limiting.

    Each key maps to a deque of ``time.time()`` stamps, oldest first. Access
    is serialized with an ``asyncio.Lock``: that protects coroutines running
    on one event loop, but it is not a thread lock, so an instance must not be
    shared between threads.

    To keep memory bounded, a key is dropped as soon as its deque is empty,
    and every ``sweep_interval`` calls to :meth:`expire` all keys that have
    been expired at least once are pruned with their last known window. Keys
    that were only ever passed to :meth:`incr` have no known window and are
    not swept.
    """

    def __init__(self, sweep_interval: int = 1024):
        """Create an empty store.

        Args:
            sweep_interval (int, optional): Number of ``expire`` calls between
                two full sweeps. Values below 1 are treated as 1.
        """
        self.storage: Dict[str, collections.deque] = {}
        self._lock = asyncio.Lock()
        # Window length last passed to expire() for each key; the sweep needs
        # it to decide which stamps of idle keys are stale.
        self._windows: Dict[str, int] = {}
        self._sweep_interval = max(1, sweep_interval)
        self._expire_calls = 0

    def _prune(self, key: str, cutoff: float) -> None:
        """Drop stamps older than *cutoff* and forget the key once empty.

        The caller must hold the lock.
        """
        stamps = self.storage.get(key)
        if stamps is None:
            return
        while stamps and stamps[0] < cutoff:
            stamps.popleft()
        if not stamps:
            del self.storage[key]

    def _sweep(self, now: float, keep: str) -> None:
        """Prune every key with a known window except *keep*.

        *keep* is the key currently being expired: its window is retained even
        when it has no stamps yet, because an ``incr`` usually follows.
        The caller must hold the lock.
        """
        for known, window in list(self._windows.items()):
            if known == keep:
                continue
            self._prune(known, now - window)
            if known not in self.storage:
                del self._windows[known]

    async def incr(self, key: str) -> int:
        """Record one hit for a key and return the new count.

        Args:
            key (str): The key to increment.

        Returns:
            int: The number of stamps stored for the key after the append.
        """
        async with self._lock:
            stamps = self.storage.setdefault(key, collections.deque())
            stamps.append(time.time())
            return len(stamps)

    async def expire(self, key: str, seconds: int) -> None:
        """Remove stamps of a key that are older than the window.

        Args:
            key (str): The key to check.
            seconds (int): The window length in seconds.
        """
        async with self._lock:
            now = time.time()
            self._windows[key] = seconds
            self._prune(key, now - seconds)

            self._expire_calls += 1
            if self._expire_calls >= self._sweep_interval:
                self._expire_calls = 0
                self._sweep(now, keep=key)

    async def get(self, key: str) -> Optional[int]:
        """Get the current count for a key.

        Args:
            key (str): The key to get.

        Returns:
            Optional[int]: The current count, or None if the key is not stored
            (never seen, or all of its stamps have expired).
        """
        async with self._lock:
            stamps = self.storage.get(key)
            if stamps is None:
                return None
            return len(stamps)
class RateLimiter:
    """Sliding-window rate limiter for ASGI requests.

    Requests are counted per client IP and path. Before each check, stamps
    older than the window are expired, so at most ``times`` requests are
    accepted in any ``seconds``-long window.
    """

    def __init__(
        self,
        times: int = 1,
        seconds: int = 1,
        backend: Optional['InMemoryStore'] = None,
    ):
        """Initialize the rate limiter.

        Args:
            times (int, optional): Number of requests allowed. Defaults to 1.
            seconds (int, optional): Time window in seconds. Defaults to 1.
            backend (optional): Store providing async ``expire``, ``get`` and
                ``incr``. Defaults to the module-level ``store``.
        """
        self.times = times
        self.seconds = seconds
        # The module-level `store` is defined after this class, so it is
        # looked up here, at construction time, not at class definition.
        self.store = backend if backend is not None else store

    def _get_key(self, scope: dict) -> str:
        """Generate a unique key for rate limiting based on client IP and path.

        Args:
            scope (dict): The ASGI scope dictionary.

        Returns:
            str: A unique key combining client IP and path.
        """
        # ASGI allows `client` to be missing or None; both fall back to
        # localhost instead of failing on the subscript.
        client = scope.get('client') or ('127.0.0.1', 0)
        path = scope.get('path', '')
        return f'rate_limit:{client[0]}:{path}'

    async def __call__(self, scope: dict, receive: Callable, send: Callable) -> bool:
        """Apply rate limiting to the request.

        When the limit is exceeded, the rejection is sent through ``send``:
        a 429 plain-text response for HTTP scopes, or a ``websocket.close``
        with code 1008 (policy violation) for WebSocket scopes. The caller
        must then stop processing the request.

        Args:
            scope (dict): The ASGI scope dictionary.
            receive (Callable): The ASGI receive function.
            send (Callable): The ASGI send function.

        Returns:
            bool: True when the request is allowed (and has been counted),
            False when it was rejected and answered.
        """
        key = self._get_key(scope)
        await self.store.expire(key, self.seconds)
        requests = await self.store.get(key) or 0

        if requests < self.times:
            await self.store.incr(key)
            return True

        if scope.get('type') == 'websocket':
            # HTTP response messages are not valid on a WebSocket scope.
            await send({'type': 'websocket.close', 'code': 1008})
            return False

        await send(
            {
                'type': 'http.response.start',
                'status': 429,
                'headers': [(b'content-type', b'text/plain')],
            }
        )
        await send(
            {
                'type': 'http.response.body',
                'body': b'Too many requests',
            }
        )
        return False
store = InMemoryStore()