Compare commits

...

21 Commits

Author SHA1 Message Date
Tim O'Farrell 8f2768c387 Merge branch 'main' into remove-socketio-communications 2026-04-28 17:57:37 -06:00
Tim O'Farrell 85b959d883 chore: remove unused items from openhands.server package (#14202)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 17:57:23 -06:00
openhands dc7e7d6394 Fix tests: test Redis errors instead of None client
The _get_redis_client() now always returns a Redis client (singleton pattern),
so tests should verify error handling for RedisError exceptions rather than
testing for None return value.

Renamed tests:
- test_acquire_user_creation_lock_no_redis -> test_acquire_user_creation_lock_redis_error
- test_release_user_creation_lock_no_redis -> test_release_user_creation_lock_redis_error

Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 23:36:07 +00:00
openhands 42894d0c37 Address review feedback: remove dead code, improve exception handling, add thread safety
- Remove unreachable 'if not redis' checks in automation_event_service.py
  (get_redis_client_async() always returns a Redis instance)
- Use specific redis_exceptions.RedisError instead of broad Exception in user_store.py
- Add thread-safe double-checked locking pattern to Redis singleton in redis.py
- Update comment in shared.py to mention both sync and async Redis versions

Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 23:14:07 +00:00
openhands 73db8a45ce Merge remote-tracking branch 'origin/main' into remove-socketio-communications 2026-04-28 23:09:59 +00:00
Tim O'Farrell 1ee548b909 Remove unused items from openhands.core package (#14201)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 17:05:57 -06:00
tofarr 8a5fada01d Yet more lint fixes 2026-04-28 16:39:53 -06:00
openhands fd65f86af9 Fix enterprise tests: replace sio mocks with get_redis_client_async mocks
Tests were mocking the removed socketio 'sio' object to access Redis via
sio.manager.redis. Updated to mock get_redis_client_async directly, which
is the new way Redis clients are obtained after the socketio removal.

- test_automation_event_service.py: 36 failing tests fixed
- test_rate_limit_utils.py: all tests fixed

Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 22:32:34 +00:00
tofarr cfd2a6d817 Test fixes 2026-04-28 16:18:18 -06:00
tofarr 2d118c2bfd Lint fixes 2026-04-28 15:49:20 -06:00
tofarr b836c82ce4 Test fixes 2026-04-28 15:45:00 -06:00
tofarr 68e4afbf3b Lint fixes 2026-04-28 15:21:21 -06:00
tofarr 97607345e2 Merge branch 'main' into remove-socketio-communications 2026-04-28 15:15:23 -06:00
openhands 4b8152c6ab Reuse Redis clients via lazy singleton pattern; rename create_* to get_*
- Convert create_redis_client() and create_redis_client_async() to
  get_redis_client() and get_redis_client_async() with lazy initialization,
  so each returns a shared singleton instead of creating a new client per call.
- Update all callers across enterprise and openhands packages.
- Simplify user_store.py _get_redis_client and lock methods by removing
  unnecessary try/except around the client getter (it never raises).
- Fix import ordering in listen.py (move 'import os' to top).
- Remove unused 'import os' from shared.py (caught by ruff).

Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 21:10:56 +00:00
Tim O'Farrell 21aa52ce3b Move openhands.server.user_auth to openhands.app_server.user_auth (#14199)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 15:00:50 -06:00
Tim O'Farrell f9b629166c Merge branch 'main' into remove-socketio-communications 2026-04-28 14:53:23 -06:00
Tim O'Farrell 1e023ce56b Remove unused legacy V0 server modules (#14198)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 14:47:26 -06:00
openhands 7f5517cf3c Remove socketio from app server communications
- Remove socketio from legacy server (listen.py, shared.py, listen_socket.py)
- Update enterprise code to use async redis client from storage.redis
- Add create_redis_client_async() to redis.py for async operations
- Refactor all code using sio.manager.redis to use async redis client

SocketIO was only wrapping FastAPI without any actual socket events,
handlers, or emit calls - it was unnecessary dead code.
2026-04-28 20:45:10 +00:00
Tim O'Farrell 5b500d640a refactor: move openhands.integrations to openhands.app_server.integrations (#14195)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 14:21:14 -06:00
Tim O'Farrell c824b2dda5 refactor: move FileStore to openhands.app_server.file_store (#14178)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 13:10:31 -06:00
Rohit Malhotra c9b6f54e76 fix: correct GLOBAL_SKILLS_DIR path for skills settings page (#14194)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-28 19:08:04 +00:00
278 changed files with 1420 additions and 2726 deletions
+2 -2
View File
@@ -27,7 +27,7 @@ Before pushing any changes, you MUST ensure that any lint errors or simple test
* If you've made changes to the backend, you should run `pre-commit run --config ./dev_config/python/.pre-commit-config.yaml` (this will run on staged files).
* If you've made changes to the frontend, you should run `cd frontend && npm run lint:fix && npm run build ; cd ..`
* If you've made changes to the VSCode extension, you should run `cd openhands/integrations/vscode && npm run lint:fix && npm run compile ; cd ../../..`
* If you've made changes to the VSCode extension, you should run `cd openhands/app_server/integrations/vscode && npm run lint:fix && npm run compile ; cd ../../..`
The pre-commit hooks MUST pass successfully before pushing any changes to the repository. This is a mandatory requirement to maintain code quality and consistency.
@@ -150,7 +150,7 @@ Frontend:
VSCode Extension:
- Located in the `openhands/integrations/vscode` directory
- Located in the `openhands/app_server/integrations/vscode` directory
- Setup: Run `npm install` in the extension directory
- Linting:
- Run linting with fixes: `npm run lint:fix`
@@ -1,9 +1,11 @@
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from openhands.app_server.integrations.bitbucket.bitbucket_service import (
BitBucketService,
)
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.bitbucket.bitbucket_service import BitBucketService
from openhands.integrations.service_types import ProviderType
class SaaSBitBucketService(BitBucketService):
@@ -1,11 +1,11 @@
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.bitbucket_data_center.bitbucket_dc_service import (
from openhands.app_server.integrations.bitbucket_data_center.bitbucket_dc_service import (
BitbucketDCService,
)
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
class SaaSBitbucketDCService(BitbucketDCService):
@@ -19,12 +19,12 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from storage.openhands_pr import OpenhandsPR
from storage.openhands_pr_store import OpenhandsPRStore
from openhands.app_server.conversation_paths import get_conversation_dir
from openhands.app_server.file_store import get_file_store
from openhands.app_server.integrations.github.github_service import GithubServiceImpl
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.config import load_openhands_config
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.service_types import ProviderType
from openhands.storage import get_file_store
from openhands.storage.locations import get_conversation_dir
config = load_openhands_config()
file_store = get_file_store(config.file_store, config.file_store_path)
@@ -112,7 +112,7 @@ class GitHubDataCollector:
suffix = path.format(repo_id, number)
if conversation_id:
return f'{get_conversation_dir(conversation_id)}{suffix}'
return f'{get_conversation_dir(conversation_id)}/{suffix}'
return suffix
@@ -31,10 +31,10 @@ from server.auth.auth_error import ExpiredError
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from openhands.app_server.integrations.provider import ProviderToken, ProviderType
from openhands.app_server.integrations.service_types import AuthenticationError
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -4,9 +4,9 @@ from integrations.store_repo_utils import store_repositories_in_db
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from openhands.app_server.integrations.github.github_service import GitHubService
from openhands.app_server.integrations.service_types import ProviderType, Repository
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.types import AppMode
@@ -34,14 +34,14 @@ from openhands.app_server.app_conversation.app_conversation_models import (
ConversationTrigger,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.integrations.github.github_service import GithubServiceImpl
from openhands.app_server.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.app_server.integrations.service_types import Comment
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.integrations.service_types import Comment
from openhands.sdk import TextContent
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.async_utils import call_sync_from_async
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
@@ -25,10 +25,10 @@ from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from openhands.app_server.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.app_server.integrations.provider import ProviderToken, ProviderType
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -7,14 +7,14 @@ from server.auth.token_manager import TokenManager
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabService
from openhands.integrations.service_types import (
from openhands.app_server.integrations.gitlab.gitlab_service import GitLabService
from openhands.app_server.integrations.service_types import (
ProviderType,
RateLimitError,
Repository,
RequestMethod,
)
from openhands.core.logger import openhands_logger as logger
from openhands.server.types import AppMode
@@ -22,14 +22,14 @@ from openhands.app_server.app_conversation.app_conversation_models import (
ConversationTrigger,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.app_server.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.app_server.integrations.service_types import Comment
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.integrations.service_types import Comment
from openhands.sdk import TextContent
from openhands.server.user_auth.user_auth import UserAuth
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
CONFIDENTIAL_NOTE = 'confidential_note'
+1 -1
View File
@@ -42,13 +42,13 @@ from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
+1 -1
View File
@@ -7,7 +7,7 @@ from jinja2 import Environment
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
if TYPE_CHECKING:
from integrations.jira.jira_payload import JiraWebhookPayload
+2 -2
View File
@@ -38,12 +38,12 @@ from openhands.app_server.app_conversation.app_conversation_models import (
ConversationTrigger,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.integrations.provider import ProviderHandler, ProviderType
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler, ProviderType
from openhands.sdk import TextContent
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
@@ -29,16 +29,16 @@ from storage.jira_dc_integration_store import JiraDcIntegrationStore
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
from openhands.app_server.integrations.provider import ProviderHandler
from openhands.app_server.integrations.service_types import Repository
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
@@ -5,7 +5,7 @@ from jinja2 import Environment
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
class JiraDcViewInterface(ABC):
@@ -30,12 +30,12 @@ from openhands.app_server.app_conversation.app_conversation_models import (
ConversationTrigger,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.integrations.provider import ProviderHandler, ProviderType
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler, ProviderType
from openhands.sdk import TextContent
from openhands.server.user_auth.user_auth import UserAuth
integration_store = JiraDcIntegrationStore.get_instance()
+6 -3
View File
@@ -1,11 +1,14 @@
from uuid import UUID
from openhands.app_server.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
from openhands.app_server.integrations.service_types import ProviderType, UserGitInfo
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.user.user_models import UserInfo
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.service_types import ProviderType, UserGitInfo
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.sdk.secret import SecretSource, StaticSecret
from openhands.server.user_auth.user_auth import UserAuth
class ResolverUserContext(UserContext):
@@ -28,22 +28,23 @@ from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.web.async_client import AsyncWebClient
from sqlalchemy import select
from storage.database import a_session_maker
from storage.redis import get_redis_client_async
from storage.slack_user import SlackUser
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import (
from openhands.app_server.integrations.provider import ProviderHandler
from openhands.app_server.integrations.service_types import (
AuthenticationError,
ProviderTimeoutError,
Repository,
)
from openhands.server.shared import config, server_config, sio
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import config, server_config
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
authorize_url_generator = AuthorizeUrlGenerator(
client_id=SLACK_CLIENT_ID,
@@ -114,7 +115,7 @@ class SlackManager(Manager[SlackViewInterface]):
"""
key = f'{SLACK_USER_MSG_KEY_PREFIX}:{message_ts}:{thread_ts}'
try:
redis = sio.manager.redis
redis = get_redis_client_async()
await redis.set(key, user_msg, ex=SLACK_USER_MSG_EXPIRATION)
logger.info(
'slack_stored_user_msg',
@@ -157,7 +158,7 @@ class SlackManager(Manager[SlackViewInterface]):
"""
key = f'{SLACK_USER_MSG_KEY_PREFIX}:{message_ts}:{thread_ts}'
try:
redis = sio.manager.redis
redis = get_redis_client_async()
user_msg = await redis.get(key)
if user_msg:
# Redis returns bytes, decode to string
+1 -1
View File
@@ -5,7 +5,7 @@ from integrations.types import SummaryExtractionTracker
from jinja2 import Environment
from storage.slack_user import SlackUser
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
@dataclass
+2 -2
View File
@@ -30,13 +30,13 @@ from openhands.app_server.app_conversation.app_conversation_models import (
SendMessageRequest,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.integrations.provider import ProviderHandler
from openhands.app_server.sandbox.sandbox_models import SandboxStatus
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.sdk import TextContent
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.async_utils import GENERAL_TIMEOUT
# =================================================
+1 -1
View File
@@ -3,9 +3,9 @@ from storage.stored_repository import StoredRepository
from storage.user_repo_map import UserRepositoryMap
from storage.user_repo_map_store import UserRepositoryMapStore
from openhands.app_server.integrations.service_types import Repository
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import Repository
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
+2 -2
View File
@@ -9,8 +9,8 @@ from pydantic import BaseModel
if TYPE_CHECKING:
from integrations.models import Message
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.app_server.user_auth.user_auth import UserAuth
class GitLabResourceType(Enum):
+2 -2
View File
@@ -6,7 +6,7 @@ import re
from jinja2 import Environment, FileSystemLoader
from server.constants import WEB_HOST
from openhands.integrations.service_types import Repository
from openhands.app_server.integrations.service_types import Repository
# ---- DO NOT REMOVE ----
# WARNING: Langfuse depends on the WEB_HOST environment variable being set to track events.
@@ -65,7 +65,7 @@ def get_user_not_found_message(username: str | None = None) -> str:
OPENHANDS_RESOLVER_TEMPLATES_DIR = (
os.getenv('OPENHANDS_RESOLVER_TEMPLATES_DIR')
or 'openhands/integrations/templates/resolver/'
or 'openhands/app_server/integrations/templates/resolver/'
)
_jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
+1 -1
View File
@@ -7,8 +7,8 @@ from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from openhands.app_server.user_auth.user_auth import UserAuth
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth.user_auth import UserAuth
def is_budget_exceeded_error(error_message: str) -> bool:
+2 -3
View File
@@ -8,7 +8,6 @@ load_dotenv()
if not os.getenv('OPENHANDS_CONFIG_CLS'):
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
import socketio # noqa: E402
from fastapi import Request, status # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
from fastapi.responses import JSONResponse # noqa: E402
@@ -61,7 +60,6 @@ from server.verified_models.verified_model_router import ( # noqa: E402
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
from openhands.server.middleware import ( # noqa: E402
CacheControlMiddleware,
)
@@ -178,4 +176,5 @@ async def expired_exception_handler(request: Request, exc: ExpiredError):
return JSONResponse({'error': ExpiredError.__name__}, status.HTTP_401_UNAUTHORIZED)
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
# Note: socketio is no longer used for communication. The base FastAPI app is used directly.
app = base_app
+1 -1
View File
@@ -40,8 +40,8 @@ from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.role_store import RoleStore
from openhands.app_server.user_auth import get_user_auth, get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_auth, get_user_id
class Permission(str, Enum):
+1 -1
View File
@@ -1,6 +1,6 @@
import os
from openhands.integrations.gitlab.constants import GITLAB_HOST
from openhands.app_server.integrations.gitlab.constants import GITLAB_HOST
GITHUB_APP_CLIENT_ID = os.getenv('GITHUB_APP_CLIENT_ID', '').strip()
GITHUB_APP_CLIENT_SECRET = os.getenv('GITHUB_APP_CLIENT_SECRET', '').strip()
+1 -1
View File
@@ -2,8 +2,8 @@ from integrations.github.github_service import SaaSGitHubService
from pydantic import SecretStr
from server.auth.auth_utils import user_verifier
from openhands.app_server.integrations.github.github_types import GitHubUser
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import GitHubUser
def is_user_allowed(user_login: str):
+2 -2
View File
@@ -3,8 +3,8 @@ import asyncio
from pydantic import SecretStr
from sqlalchemy import select
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
from openhands.server.types import AppMode
@@ -55,7 +55,7 @@ def schedule_gitlab_repo_sync(
# Lazy import to avoid circular dependency:
# middleware -> gitlab_sync -> integrations.gitlab.gitlab_service
# -> openhands.integrations.gitlab.gitlab_service -> get_impl
# -> openhands.app_server.integrations.gitlab.gitlab_service -> get_impl
# -> integrations.gitlab.gitlab_service (circular)
from integrations.gitlab.gitlab_service import SaaSGitLabService
+5 -5
View File
@@ -35,15 +35,15 @@ from storage.user_authorization_store import UserAuthorizationStore
from storage.user_store import UserStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.app_server.settings.settings_models import Settings
from openhands.app_server.settings.settings_store import SettingsStore
from openhands.integrations.provider import (
from openhands.app_server.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderToken,
ProviderType,
)
from openhands.server.user_auth.user_auth import AuthType, UserAuth
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.app_server.settings.settings_models import Settings
from openhands.app_server.settings.settings_store import SettingsStore
from openhands.app_server.user_auth.user_auth import AuthType, UserAuth
token_manager = TokenManager()
+1 -1
View File
@@ -51,7 +51,7 @@ from storage.github_app_installation import GithubAppInstallation
from storage.offline_token_store import OfflineTokenStore
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
from openhands.server.types import SessionExpiredError
from openhands.utils.http_session import httpx_verify_option
+1 -1
View File
@@ -22,8 +22,8 @@ from server.auth.constants import (
)
from server.constants import DEPLOYMENT_MODE
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.config.utils import load_openhands_config
from openhands.integrations.service_types import ProviderType
from openhands.server.config.server_config import ServerConfig
from openhands.server.types import AppMode
+1 -1
View File
@@ -4,8 +4,8 @@ Email domain validation utilities for enterprise endpoints.
from fastapi import Depends, HTTPException, Request, status
from openhands.app_server.user_auth import get_user_auth, get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_auth, get_user_id
async def get_admin_user_id(
+1 -1
View File
@@ -15,9 +15,9 @@ from server.auth.saas_user_auth import SaasUserAuth, token_manager
from server.routes.auth import set_response_cookie
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite
from openhands.app_server.user_auth.user_auth import AuthType, UserAuth, get_user_auth
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import config
from openhands.server.user_auth.user_auth import AuthType, UserAuth, get_user_auth
class SetAuthCookieMiddleware:
+1 -1
View File
@@ -2,8 +2,8 @@
from pydantic import BaseModel
from openhands.app_server.integrations.service_types import ProviderType
from openhands.app_server.user.user_models import UserInfo
from openhands.integrations.service_types import ProviderType
class SaasUserInfo(UserInfo):
+2 -2
View File
@@ -12,9 +12,9 @@ from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.app_server.user_auth import get_user_auth, get_user_id
from openhands.app_server.user_auth.user_auth import AuthType
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_auth, get_user_id
from openhands.server.user_auth.user_auth import AuthType
# Helper functions for BYOR API key management
+21 -5
View File
@@ -3,6 +3,7 @@ import json
import uuid
import warnings
from datetime import datetime, timezone
from types import MappingProxyType
from typing import Annotated, Optional, cast
from urllib.parse import quote, urlencode
from uuid import UUID as parse_uuid
@@ -46,13 +47,16 @@ from storage.database import a_session_maker
from storage.user import User
from storage.user_store import UserStore
from openhands.app_server.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
ProviderToken,
)
from openhands.app_server.integrations.service_types import ProviderType, TokenResponse
from openhands.app_server.user_auth import get_access_token
from openhands.app_server.user_auth.user_auth import get_user_auth
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import ProviderType, TokenResponse
from openhands.server.services.conversation_service import create_provider_tokens_object
from openhands.server.shared import config
from openhands.server.user_auth import get_access_token
from openhands.server.user_auth.user_auth import get_user_auth
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -63,6 +67,18 @@ oauth_router = APIRouter(prefix='/oauth')
token_manager = TokenManager()
def create_provider_tokens_object(
providers_set: list[ProviderType],
) -> PROVIDER_TOKEN_TYPE:
"""Create provider tokens object for the given providers."""
provider_information: dict[ProviderType, ProviderToken] = {}
for provider in providers_set:
provider_information[provider] = ProviderToken(token=None, user_id=None)
return MappingProxyType(provider_information)
def set_response_cookie(
request: Request,
response: Response,
+1 -1
View File
@@ -21,7 +21,7 @@ from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.server.user_auth import get_user_id
from openhands.app_server.user_auth import get_user_id
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing', tags=['Billing'])
+2 -2
View File
@@ -13,9 +13,9 @@ from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from server.utils.url_utils import get_web_url
from storage.user_store import UserStore
from openhands.app_server.user_auth import get_user_id
from openhands.app_server.user_auth.user_auth import get_user_auth
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Email validation regex pattern
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
@@ -15,8 +15,8 @@ from server.auth.constants import (
from server.auth.token_manager import TokenManager
from server.services.automation_event_service import AutomationEventService
from openhands.app_server.integrations.provider import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderType
# Environment variable to disable GitHub webhooks
GITHUB_WEBHOOKS_ENABLED = os.environ.get('GITHUB_WEBHOOKS_ENABLED', '1') in (
@@ -18,11 +18,11 @@ from pydantic import BaseModel
from server.auth.token_manager import TokenManager
from storage.gitlab_webhook import GitlabWebhook
from storage.gitlab_webhook_store import GitlabWebhookStore
from storage.redis import get_redis_client_async
from openhands.app_server.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.app_server.user_auth import get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.server.shared import sio
from openhands.server.user_auth import get_user_id
gitlab_integration_router = APIRouter(prefix='/integration')
webhook_store = GitlabWebhookStore()
@@ -103,7 +103,7 @@ async def gitlab_events(
dedup_hash = hashlib.sha256(dedup_json.encode()).hexdigest()
dedup_key = f'gitlab_msg: {dedup_hash}'
redis = sio.manager.redis
redis = get_redis_client_async()
created = await redis.set(dedup_key, 1, nx=True, ex=60)
if not created:
logger.info('gitlab_is_duplicate')
+3 -3
View File
@@ -18,10 +18,10 @@ from server.auth.constants import JIRA_CLIENT_ID, JIRA_CLIENT_SECRET
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from storage.jira_workspace import JiraWorkspace
from storage.redis import create_redis_client
from storage.redis import get_redis_client
from openhands.app_server.user_auth.user_auth import get_user_auth
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth.user_auth import get_user_auth
# Environment variable to disable Jira webhooks
JIRA_WEBHOOKS_ENABLED = os.environ.get('JIRA_WEBHOOKS_ENABLED', '0') in (
@@ -123,7 +123,7 @@ class JiraValidateWorkspaceResponse(BaseModel):
jira_integration_router = APIRouter(prefix='/integration/jira')
token_manager = TokenManager()
jira_manager = JiraManager(token_manager)
redis_client = create_redis_client()
redis_client = get_redis_client()
async def verify_jira_signature(body: bytes, signature: str, payload: dict):
@@ -26,10 +26,10 @@ from server.auth.constants import (
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from server.constants import WEB_HOST
from storage.redis import create_redis_client
from storage.redis import get_redis_client
from openhands.app_server.user_auth.user_auth import get_user_auth
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth.user_auth import get_user_auth
# Environment variable to disable Jira DC webhooks
JIRA_DC_WEBHOOKS_ENABLED = os.environ.get('JIRA_DC_WEBHOOKS_ENABLED', '0') in (
@@ -129,7 +129,7 @@ class JiraDcValidateWorkspaceResponse(BaseModel):
jira_dc_integration_router = APIRouter(prefix='/integration/jira-dc')
token_manager = TokenManager()
jira_dc_manager = JiraDcManager(token_manager)
redis_client = create_redis_client()
redis_client = get_redis_client()
async def _handle_workspace_link_creation(
@@ -35,12 +35,16 @@ from slack_sdk.signature import SignatureVerifier
from slack_sdk.web.async_client import AsyncWebClient
from sqlalchemy import delete
from storage.database import a_session_maker
from storage.redis import get_redis_client_async
from storage.slack_team_store import SlackTeamStore
from storage.slack_user import SlackUser
from storage.user_store import UserStore
from openhands.integrations.service_types import ProviderTimeoutError, ProviderType
from openhands.server.shared import config, sio
from openhands.app_server.integrations.service_types import (
ProviderTimeoutError,
ProviderType,
)
from openhands.server.shared import config
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
slack_router = APIRouter(prefix='/slack')
@@ -324,7 +328,7 @@ async def on_event(request: Request, background_tasks: BackgroundTasks):
team_id = payload['team_id']
# Sometimes slack sends duplicates, so we need to make sure this is not a duplicate.
redis = sio.manager.redis
redis = get_redis_client_async()
key = f'slack_msg:{client_msg_id}'
created = await redis.set(key, 1, nx=True, ex=60)
if not created:
+1 -1
View File
@@ -10,8 +10,8 @@ from server.utils.url_utils import get_web_url
from storage.api_key_store import ApiKeyStore
from storage.device_code_store import DeviceCodeStore
from openhands.app_server.user_auth import get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# ---------------------------------------------------------------------------
# Constants
+1 -1
View File
@@ -22,8 +22,8 @@ from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from openhands.app_server.user_auth import get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
+1 -1
View File
@@ -50,8 +50,8 @@ from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.user_store import UserStore
from openhands.app_server.user_auth import get_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# Initialize API router
org_router = APIRouter(prefix='/api/organizations', tags=['Orgs'])
+2 -2
View File
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, status
from sqlalchemy.sql import text
from storage.database import a_session_maker
from storage.redis import create_redis_client
from storage.redis import get_redis_client
from openhands.core.logger import openhands_logger as logger
@@ -23,7 +23,7 @@ async def is_ready():
# Check Redis connection
try:
redis_client = create_redis_client()
redis_client = get_redis_client()
redis_client.ping()
except Exception as e:
logger.error(f'Redis check failed: {str(e)}')
+2 -2
View File
@@ -17,12 +17,12 @@ from openhands.app_server.config import (
depends_user_context,
resolve_provider_llm_base_url,
)
from openhands.app_server.integrations.provider import ProviderHandler
from openhands.app_server.integrations.service_types import ProviderType
from openhands.app_server.sandbox.session_auth import validate_session_key_ownership
from openhands.app_server.user.auth_user_context import AuthUserContext
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.utils.dependencies import get_dependencies
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import ProviderType
_logger = logging.getLogger(__name__)
@@ -34,10 +34,10 @@ from server.auth.constants import (
AUTOMATION_WEBHOOK_SECRET,
)
from server.auth.token_manager import TokenManager
from storage.redis import get_redis_client_async
from openhands.app_server.integrations.provider import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderType
from openhands.server.shared import sio
# Cache TTL constants
ORG_CLAIM_CACHE_TTL_SECONDS = 3600 # 1 hour for org claims (rarely change)
@@ -382,16 +382,7 @@ class AutomationEventService:
Monitor logs for 'Redis unavailable' warnings to detect degradation.
"""
try:
redis = getattr(sio.manager, 'redis', None)
if not redis:
# Log at warning level - this is a significant degradation that
# will cause DB load. Monitor these logs for alerting.
logger.warning(
'[AutomationEventService] Redis unavailable for cache read, '
'falling back to direct DB queries (this will increase DB load)'
)
return None
redis = get_redis_client_async()
cached = await redis.get(cache_key)
if cached is None:
return None
@@ -415,11 +406,7 @@ class AutomationEventService:
Fails silently if Redis is unavailable (graceful degradation).
"""
try:
redis = getattr(sio.manager, 'redis', None)
if not redis:
# Silent failure - read path already logs the warning
return
redis = get_redis_client_async()
await redis.setex(cache_key, ttl_seconds, value)
except Exception as e:
# Log at warning level for visibility
@@ -1,7 +1,7 @@
from datetime import datetime
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
# from openhands.app_server.integrations.service_types import ProviderType
# from openhands.sdk.llm import MetricsSnapshot
# from openhands.app_server.app_conversation.app_conversation_models import ConversationTrigger
# For now, use Any to avoid import issues
@@ -30,8 +30,8 @@ from openhands.agent_server.utils import utc_now
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.integrations.provider import ProviderType
from openhands.app_server.services.injector import InjectorState
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot, TokenUsage
logger = logging.getLogger(__name__)
+2 -2
View File
@@ -1,7 +1,7 @@
from fastapi import HTTPException, Request, status
from storage.redis import get_redis_client_async
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import sio
# Rate limiting constants
RATE_LIMIT_USER_SECONDS = 120 # 2 minutes per user_id
@@ -32,7 +32,7 @@ async def check_rate_limit_by_user_id(
HTTPException: If rate limit is exceeded (429 status code)
"""
try:
redis = sio.manager.redis
redis = get_redis_client_async()
if not redis:
# If Redis is unavailable, log warning and allow request (fail open)
logger.warning('Redis unavailable for rate limiting, allowing request')
+1 -1
View File
@@ -10,8 +10,8 @@ from sqlalchemy.exc import OperationalError
from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
# Time buffer (in seconds) before actual expiration to consider token expired
# This ensures tokens are refreshed before they actually expire. The
+1 -1
View File
@@ -6,8 +6,8 @@ from sqlalchemy import and_, desc, select
from storage.database import a_session_maker
from storage.openhands_pr import OpenhandsPR
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
class OpenhandsPRStore:
@@ -13,8 +13,8 @@ from sqlalchemy import and_, delete, select, update
from storage.database import a_session_maker
from storage.proactive_convos import ProactiveConversation
from openhands.app_server.integrations.service_types import ProviderType
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
@dataclass
+62 -9
View File
@@ -1,23 +1,76 @@
import os
import threading
import redis
from redis import Redis
from redis import asyncio as aioredis
from redis import exceptions as redis_exceptions
# Redis configuration
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', '')
REDIS_DB = int(os.environ.get('REDIS_DB', '0'))
REDIS_SOCKET_TIMEOUT = 2
_redis_client: Redis | None = None
_redis_client_async: aioredis.Redis | None = None
_redis_lock = threading.Lock()
def create_redis_client():
return redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB,
socket_timeout=2,
)
def _get_redis_kwargs():
"""Return common kwargs for Redis client creation."""
return {
'host': REDIS_HOST,
'port': REDIS_PORT,
'password': REDIS_PASSWORD,
'db': REDIS_DB,
'socket_timeout': REDIS_SOCKET_TIMEOUT,
}
def get_redis_client() -> Redis:
"""Get a shared synchronous Redis client, lazily initialized.
Thread-safe with double-checked locking pattern.
Returns:
A Redis client for synchronous operations.
"""
global _redis_client
if _redis_client is None:
with _redis_lock:
if _redis_client is None:
_redis_client = Redis(**_get_redis_kwargs())
return _redis_client
def get_redis_client_async() -> aioredis.Redis:
"""Get a shared asynchronous Redis client, lazily initialized.
Note: This function is synchronous but returns an async client.
Thread-safe initialization is handled via a threading lock since
asyncio.Lock cannot be used in a sync context.
Returns:
An aioredis client for asynchronous operations.
"""
global _redis_client_async
if _redis_client_async is None:
with _redis_lock:
if _redis_client_async is None:
_redis_client_async = aioredis.Redis(**_get_redis_kwargs())
return _redis_client_async
def get_redis_authed_url():
return f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}'
__all__ = [
'Redis',
'aioredis',
'get_redis_client',
'get_redis_client_async',
'get_redis_authed_url',
'redis_exceptions',
]
+23 -19
View File
@@ -109,10 +109,10 @@ class UserStore:
@staticmethod
def _get_redis_client():
"""Get the Redis client from the Socket.IO manager."""
from openhands.server.shared import sio
"""Get the shared async Redis client from enterprise storage."""
from storage.redis import get_redis_client_async
return getattr(sio.manager, 'redis', None)
return get_redis_client_async()
@staticmethod
async def _acquire_user_creation_lock(user_id: str) -> bool:
@@ -121,19 +121,21 @@ class UserStore:
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
Returns False if another process holds the lock.
"""
from storage.redis import redis_exceptions
redis_client = UserStore._get_redis_client()
if redis_client is None:
try:
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
lock_acquired = await redis_client.set(
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
)
return bool(lock_acquired)
except redis_exceptions.RedisError:
logger.warning(
'user_store:_acquire_user_creation_lock:no_redis_client',
'user_store:_acquire_user_creation_lock:redis_error',
extra={'user_id': user_id},
)
return True # Proceed without locking if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
lock_acquired = await redis_client.set(
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
)
return bool(lock_acquired)
return True # Proceed without locking on error
@staticmethod
async def _release_user_creation_lock(user_id: str) -> bool:
@@ -142,17 +144,19 @@ class UserStore:
Returns True if the lock was released or if Redis is unavailable.
Returns False if the lock could not be released.
"""
from storage.redis import redis_exceptions
redis_client = UserStore._get_redis_client()
if redis_client is None:
try:
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
deleted = await redis_client.delete(user_key)
return bool(deleted)
except redis_exceptions.RedisError:
logger.warning(
'user_store:_release_user_creation_lock:no_redis_client',
'user_store:_release_user_creation_lock:redis_error',
extra={'user_id': user_id},
)
return True # Nothing to release if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
deleted = await redis_client.delete(user_key)
return bool(deleted)
return True # Proceed without locking on error
@staticmethod
async def migrate_user(
+1 -1
View File
@@ -15,8 +15,8 @@ from storage.database import a_session_maker
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
from openhands.app_server.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
if TYPE_CHECKING:
from integrations.gitlab.gitlab_service import SaaSGitLabService
@@ -34,7 +34,7 @@ class TestSaaSBitbucketDCServiceInit:
def test_refresh_flag_is_true(self):
# self.refresh = True is required so the base class BitbucketDCService
# retries the request with a refreshed token on 401 responses.
# See openhands/integrations/bitbucket_data_center/service/base.py,
# See openhands/app_server/integrations/bitbucket_data_center/service/base.py,
# which checks `if self.refresh` before attempting the retry.
service = SaaSBitbucketDCService()
assert service.refresh is True
@@ -24,7 +24,10 @@ def jinja_env() -> Environment:
repo_root = Path(__file__).resolve().parents[5]
return Environment(
loader=FileSystemLoader(
str(repo_root / 'openhands/integrations/templates/resolver/github')
str(
repo_root
/ 'openhands/app_server/integrations/templates/resolver/github'
)
)
)
@@ -18,8 +18,8 @@ from storage.jira_conversation import JiraConversation
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.integrations.service_types import ProviderType, Repository
from openhands.app_server.user_auth.user_auth import UserAuth
@pytest.fixture
@@ -16,8 +16,8 @@ from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.integrations.service_types import ProviderType, Repository
from openhands.app_server.user_auth.user_auth import UserAuth
@pytest.fixture
@@ -17,7 +17,7 @@ from integrations.jira_dc.jira_dc_view import (
)
from integrations.models import Message, SourceType
from openhands.integrations.service_types import ProviderType, Repository
from openhands.app_server.integrations.service_types import ProviderType, Repository
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -17,7 +17,7 @@ from storage.slack_conversation import SlackConversation
from storage.slack_user import SlackUser
from openhands.app_server.sandbox.sandbox_models import SandboxStatus
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
# ---------------------------------------------------------------------------
# Fixtures
@@ -10,11 +10,11 @@ import pytest
from pydantic import SecretStr
from enterprise.integrations.resolver_context import ResolverUserContext
from openhands.app_server.secrets.secrets_models import Secrets
# Import the real classes we want to test
from openhands.integrations.provider import CustomSecret, ProviderToken
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.provider import CustomSecret, ProviderToken
from openhands.app_server.integrations.service_types import ProviderType
from openhands.app_server.secrets.secrets_models import Secrets
# Import the SDK types we need for testing
from openhands.sdk.secret import SecretSource, StaticSecret
@@ -344,7 +344,7 @@ async def test_get_provider_handler_creates_handler_with_correct_params(
handler = await resolver_context._get_provider_handler()
# Assert
from openhands.integrations.provider import ProviderHandler
from openhands.app_server.integrations.provider import ProviderHandler
assert isinstance(handler, ProviderHandler)
assert handler.provider_tokens == provider_tokens
@@ -19,7 +19,7 @@ from server.routes.api_keys import (
)
from storage.lite_llm_manager import LiteLlmManager
from openhands.server.user_auth.user_auth import AuthType
from openhands.app_server.user_auth.user_auth import AuthType
class TestVerifyByorKeyInLitellm:
@@ -22,7 +22,7 @@ from server.routes.orgs import (
from sqlalchemy.exc import IntegrityError
from storage.org_git_claim import OrgGitClaim
from openhands.server.user_auth import get_user_id
from openhands.app_server.user_auth import get_user_id
TEST_USER_ID = str(uuid.uuid4())
@@ -44,8 +44,8 @@ from server.routes.orgs import (
)
from storage.org import Org
from openhands.app_server.user_auth import get_user_id
from openhands.sdk.settings import AgentSettings, ConversationSettings
from openhands.server.user_auth import get_user_id
# Test user ID constant (must be a valid UUID string)
TEST_USER_ID = str(uuid.uuid4())
@@ -16,7 +16,7 @@ from server.routes.user_app_settings_models import (
UserNotFoundError,
)
from openhands.server.user_auth import get_user_id
from openhands.app_server.user_auth import get_user_id
TEST_USER_ID = str(uuid.uuid4())
@@ -144,7 +144,7 @@ class TestGetOrgInfoFromContext:
from server.routes.users_v1 import _get_org_info_from_context
from openhands.app_server.user.auth_user_context import AuthUserContext
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
# Create AuthUserContext with a non-SaasUserAuth
mock_user_auth = MagicMock(spec=UserAuth)
@@ -15,7 +15,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
REDIS_PATCH = 'server.services.automation_event_service.get_redis_client_async'
# Default patches for constants
CONSTANT_PATCHES = {
@@ -91,10 +93,8 @@ def github_user_payload():
def create_service(mock_token_manager):
"""Helper to create a service with mocked sio and constants."""
with patch('server.services.automation_event_service.sio'), patch.dict(
'os.environ', {}, clear=False
):
"""Helper to create a service with mocked constants."""
with patch.dict('os.environ', {}, clear=False):
for key, value in CONSTANT_PATCHES.items():
patch(key, value).start()
@@ -123,9 +123,7 @@ class TestResolveGitOrg:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=mock_org_git_claim.org_id,
), patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
), patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_git_org(ProviderType.GITHUB, 'test-org')
@@ -147,11 +145,7 @@ class TestResolveGitOrg:
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
) as mock_resolver, patch(
'server.services.automation_event_service.sio'
) as mock_sio:
mock_sio.manager.redis = mock_redis
) as mock_resolver, patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_git_org(ProviderType.GITHUB, 'test-org')
@@ -174,9 +168,7 @@ class TestResolveGitOrg:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None,
), patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
), patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_git_org(
ProviderType.GITHUB, 'unclaimed-org'
@@ -202,11 +194,7 @@ class TestResolveGitOrg:
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
) as mock_resolver, patch(
'server.services.automation_event_service.sio'
) as mock_sio:
mock_sio.manager.redis = mock_redis
) as mock_resolver, patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_git_org(
ProviderType.GITHUB, 'unclaimed-org'
@@ -232,9 +220,7 @@ class TestResolveGitOrg:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=mock_org_git_claim.org_id,
), patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
), patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
# Call for GitHub
@@ -264,9 +250,7 @@ class TestResolvePersonalOrg:
mock_redis.get = AsyncMock(return_value=None) # Cache miss
mock_redis.setex = AsyncMock()
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_personal_org(ProviderType.GITHUB, 12345)
@@ -284,9 +268,7 @@ class TestResolvePersonalOrg:
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=keycloak_id.encode())
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._resolve_personal_org(ProviderType.GITHUB, 12345)
@@ -324,9 +306,7 @@ class TestResolvePersonalOrg:
mock_redis.get = AsyncMock(return_value=None)
mock_redis.setex = AsyncMock()
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
# Call for GitHub
@@ -359,15 +339,11 @@ class TestForwardEvent:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=mock_org_git_claim.org_id,
), patch(
'server.services.automation_event_service.sio'
) as mock_sio, patch.object(
), patch(REDIS_PATCH, return_value=mock_redis), patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_event(
provider=ProviderType.GITHUB,
@@ -411,15 +387,11 @@ class TestForwardEvent:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None, # No org claim for personal repo
), patch(
'server.services.automation_event_service.sio'
) as mock_sio, patch.object(
), patch(REDIS_PATCH, return_value=mock_redis), patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_event(
provider=ProviderType.GITHUB,
@@ -450,7 +422,7 @@ class TestForwardEvent:
'sender': {'id': 12345, 'login': 'testuser'},
}
with patch('server.services.automation_event_service.sio'), patch(
with patch(
'server.services.automation_event_service.logger'
) as mock_logger, patch.object(
AutomationEventService,
@@ -487,15 +459,13 @@ class TestForwardEvent:
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None,
), patch('server.services.automation_event_service.sio') as mock_sio, patch(
), patch(REDIS_PATCH, return_value=mock_redis), patch(
'server.services.automation_event_service.logger'
) as mock_logger, patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_event(
provider=ProviderType.GITHUB,
@@ -608,7 +578,7 @@ class TestSendToAutomationService:
with patch(
'server.services.automation_event_service.AUTOMATION_SERVICE_URL',
'https://automation.example.com',
), patch('server.services.automation_event_service.sio'), patch(
), patch(
'server.services.automation_event_service.aiohttp.ClientSession',
return_value=mock_session_context,
):
@@ -653,7 +623,7 @@ class TestSendToAutomationService:
with patch(
'server.services.automation_event_service.AUTOMATION_SERVICE_URL',
'https://automation.example.com',
), patch('server.services.automation_event_service.sio'), patch(
), patch(
'server.services.automation_event_service.aiohttp.ClientSession',
return_value=mock_session_context,
):
@@ -678,9 +648,7 @@ class TestSendToAutomationService:
with patch(
'server.services.automation_event_service.AUTOMATION_SERVICE_URL', None
), patch('server.services.automation_event_service.sio'), patch(
'server.services.automation_event_service.logger'
) as mock_logger:
), patch('server.services.automation_event_service.logger') as mock_logger:
service = create_service(mock_token_manager)
await service._send_to_automation_service(
ProviderType.GITHUB, org_id, payload
@@ -702,7 +670,7 @@ class TestSignPayload:
with patch(
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
'test-shared-secret',
), patch('server.services.automation_event_service.sio'):
):
service = create_service(mock_token_manager)
payload_bytes = b'{"test": "data"}'
@@ -734,7 +702,7 @@ class TestSignPayload:
with patch(
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
shared_secret,
), patch('server.services.automation_event_service.sio'):
):
service = create_service(mock_token_manager)
signature = service._sign_payload(payload_bytes)
@@ -754,9 +722,7 @@ class TestCacheHelpers:
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=b'cached-value')
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
@@ -772,9 +738,7 @@ class TestCacheHelpers:
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
@@ -787,9 +751,7 @@ class TestCacheHelpers:
WHEN: _get_cached_value is called
THEN: None is returned (graceful degradation)
"""
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = None
with patch(REDIS_PATCH, return_value=None):
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
@@ -805,9 +767,7 @@ class TestCacheHelpers:
mock_redis = AsyncMock()
mock_redis.setex = AsyncMock()
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
service = create_service(mock_token_manager)
await service._set_cached_value('test-key', 'test-value', 3600)
@@ -820,9 +780,7 @@ class TestCacheHelpers:
WHEN: _set_cached_value is called
THEN: No error is raised (silent failure)
"""
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = None
with patch(REDIS_PATCH, return_value=None):
service = create_service(mock_token_manager)
# Should not raise
await service._set_cached_value('test-key', 'test-value', 3600)
@@ -8,6 +8,8 @@ from server.utils.rate_limit_utils import (
check_rate_limit_by_user_id,
)
REDIS_PATCH = 'server.utils.rate_limit_utils.get_redis_client_async'
@pytest.fixture
def mock_request():
@@ -34,11 +36,9 @@ async def test_rate_limit_by_user_id_first_request_succeeds(mock_request, mock_r
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch(REDIS_PATCH, return_value=mock_redis),
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
@@ -63,11 +63,9 @@ async def test_rate_limit_by_user_id_second_request_within_window_fails(
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch(REDIS_PATCH, return_value=mock_redis),
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
@@ -87,11 +85,9 @@ async def test_rate_limit_by_ip_when_user_id_is_none(mock_request, mock_redis):
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch(REDIS_PATCH, return_value=mock_redis),
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
@@ -116,11 +112,7 @@ async def test_rate_limit_by_ip_second_request_within_window_fails(
key_prefix = 'email_resend'
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
):
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
@@ -139,11 +131,9 @@ async def test_rate_limit_redis_unavailable_fails_open(mock_request):
user_id = 'test_user_id'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch(REDIS_PATCH, return_value=None),
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = None # Redis unavailable
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
@@ -164,11 +154,9 @@ async def test_rate_limit_redis_exception_fails_open(mock_request, mock_redis):
mock_redis.set = AsyncMock(side_effect=Exception('Redis connection error'))
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch(REDIS_PATCH, return_value=mock_redis),
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
@@ -186,9 +174,7 @@ async def test_rate_limit_custom_key_prefix(mock_request, mock_redis):
user_id = 'test_user_id'
key_prefix = 'password_reset'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
@@ -209,9 +195,7 @@ async def test_rate_limit_custom_rate_limit_seconds(mock_request, mock_redis):
custom_user_seconds = 60
custom_ip_seconds = 180
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
# Act
await check_rate_limit_by_user_id(
request=mock_request,
@@ -234,9 +218,7 @@ async def test_rate_limit_ip_with_unknown_client(mock_request, mock_redis):
key_prefix = 'email_resend'
mock_request.client = None # No client information
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
@@ -258,9 +240,7 @@ async def test_rate_limit_different_users_have_separate_limits(
user_id_1 = 'user_1'
user_id_2 = 'user_2'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
with patch(REDIS_PATCH, return_value=mock_redis):
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id_1
@@ -15,7 +15,7 @@ from storage.auth_token_store import (
from storage.auth_tokens import AuthTokens
from storage.base import Base
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
@pytest.fixture
@@ -24,8 +24,8 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
ConversationTrigger,
)
from openhands.app_server.integrations.service_types import ProviderType
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.integrations.service_types import ProviderType
# Test UUIDs
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
@@ -13,7 +13,7 @@ from server.auth.auth_error import (
from server.auth.saas_user_auth import SaasUserAuth
from server.middleware import SetAuthCookieMiddleware
from openhands.server.user_auth.user_auth import AuthType
from openhands.app_server.user_auth.user_auth import AuthType
@pytest.fixture
+2 -2
View File
@@ -19,7 +19,7 @@ from server.routes.auth import (
set_response_cookie,
)
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
def create_mock_user_authorizer(success: bool = True, error_detail: str | None = None):
@@ -799,7 +799,7 @@ async def test_logout_without_refresh_token():
with patch('server.routes.auth.token_manager') as mock_token_manager:
with patch(
'openhands.server.user_auth.default_user_auth.DefaultUserAuth.get_instance'
'openhands.app_server.user_auth.default_user_auth.DefaultUserAuth.get_instance'
) as mock_get_instance:
mock_get_instance.side_effect = AuthError()
result = await logout(mock_request)
+12 -12
View File
@@ -15,9 +15,9 @@ from server.routes.integration.gitlab import gitlab_events
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.verify_gitlab_signature')
@patch('server.routes.integration.gitlab.gitlab_manager')
@patch('server.routes.integration.gitlab.sio')
@patch('server.routes.integration.gitlab.get_redis_client_async')
async def test_gitlab_events_deduplication_with_object_id(
mock_sio, mock_gitlab_manager, mock_verify_signature
mock_get_redis_client_async, mock_gitlab_manager, mock_verify_signature
):
"""Test that duplicate GitLab events are deduplicated using object_attributes.id."""
# Setup mocks
@@ -26,7 +26,7 @@ async def test_gitlab_events_deduplication_with_object_id(
# Mock Redis
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# First request - Redis returns True (key was set)
mock_redis.set.return_value = True
@@ -90,9 +90,9 @@ async def test_gitlab_events_deduplication_with_object_id(
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.verify_gitlab_signature')
@patch('server.routes.integration.gitlab.gitlab_manager')
@patch('server.routes.integration.gitlab.sio')
@patch('server.routes.integration.gitlab.get_redis_client_async')
async def test_gitlab_events_deduplication_without_object_id(
mock_sio, mock_gitlab_manager, mock_verify_signature
mock_get_redis_client_async, mock_gitlab_manager, mock_verify_signature
):
"""Test that GitLab events without object_attributes.id are deduplicated using hash of payload."""
# Setup mocks
@@ -101,7 +101,7 @@ async def test_gitlab_events_deduplication_without_object_id(
# Mock Redis
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# First request - Redis returns True (key was set)
mock_redis.set.return_value = True
@@ -170,9 +170,9 @@ async def test_gitlab_events_deduplication_without_object_id(
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.verify_gitlab_signature')
@patch('server.routes.integration.gitlab.gitlab_manager')
@patch('server.routes.integration.gitlab.sio')
@patch('server.routes.integration.gitlab.get_redis_client_async')
async def test_gitlab_events_different_payloads_not_deduplicated(
mock_sio, mock_gitlab_manager, mock_verify_signature
mock_get_redis_client_async, mock_gitlab_manager, mock_verify_signature
):
"""Test that different GitLab events are not deduplicated."""
# Setup mocks
@@ -181,7 +181,7 @@ async def test_gitlab_events_different_payloads_not_deduplicated(
# Mock Redis
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
mock_redis.set.return_value = True # Always return True for this test
# First payload with ID 123
@@ -240,9 +240,9 @@ async def test_gitlab_events_different_payloads_not_deduplicated(
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.verify_gitlab_signature')
@patch('server.routes.integration.gitlab.gitlab_manager')
@patch('server.routes.integration.gitlab.sio')
@patch('server.routes.integration.gitlab.get_redis_client_async')
async def test_gitlab_events_multiple_identical_payloads_deduplicated(
mock_sio, mock_gitlab_manager, mock_verify_signature
mock_get_redis_client_async, mock_gitlab_manager, mock_verify_signature
):
"""Test that multiple identical GitLab events are properly deduplicated."""
# Setup mocks
@@ -251,7 +251,7 @@ async def test_gitlab_events_multiple_identical_payloads_deduplicated(
# Mock Redis
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Create a payload with object_attributes.id
payload = {
@@ -68,7 +68,7 @@ class TestAcceptInvitationPostEndpoint:
def auth_app(self):
"""Create a FastAPI app with dependency overrides for authenticated tests."""
from openhands.server.user_auth import get_user_id
from openhands.app_server.user_auth import get_user_id
app = FastAPI()
app.include_router(accept_router)
@@ -200,7 +200,7 @@ class TestCreateInvitationBatchEndpoint:
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
from openhands.app_server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
@@ -8,9 +8,9 @@ from pydantic import SecretStr
from storage.saas_secrets_store import SaasSecretsStore
from storage.stored_custom_secrets import StoredCustomSecrets
from openhands.app_server.integrations.provider import CustomSecret
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.integrations.provider import CustomSecret
@pytest.fixture
+1 -1
View File
@@ -22,8 +22,8 @@ from server.auth.saas_user_auth import (
from storage.api_key_store import ApiKeyValidationResult
from storage.user_authorization import UserAuthorizationType
from openhands.app_server.integrations.provider import ProviderToken, ProviderType
from openhands.app_server.secrets.secrets_models import Secrets
from openhands.integrations.provider import ProviderToken, ProviderType
@pytest.fixture
@@ -21,9 +21,9 @@ from openhands.app_server.app_conversation.app_conversation_models import (
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.integrations.provider import ProviderType
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot, TokenUsage
+41 -32
View File
@@ -11,12 +11,12 @@ from integrations.slack.slack_manager import (
from integrations.slack.slack_view import SlackNewConversationView
from storage.slack_user import SlackUser
from openhands.integrations.service_types import (
from openhands.app_server.integrations.service_types import (
ProviderTimeoutError,
ProviderType,
Repository,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.app_server.user_auth.user_auth import UserAuth
@pytest.fixture
@@ -89,21 +89,21 @@ def test_infer_repo_from_message(message, expected):
class TestRepoVerificationHandling:
"""Test repo verification handling for Slack integration."""
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
@patch('integrations.slack.slack_manager.ProviderHandler')
@patch.object(SlackManager, 'send_message', new_callable=AsyncMock)
async def test_timeout_during_verification_shows_selector(
self,
mock_send_message,
mock_provider_handler_class,
mock_sio,
mock_get_redis_client_async,
slack_manager,
slack_new_conversation_view,
):
"""Test that when repo verification times out, selector is shown."""
# Setup Redis mock
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Setup: Modify message to include exactly one repo reference to trigger verification
slack_new_conversation_view.user_msg = 'Help me with OpenHands/OpenHands repo'
@@ -132,12 +132,12 @@ class TestRepoVerificationHandling:
assert isinstance(selector_message, dict)
assert selector_message.get('text') == 'Choose a Repository:'
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
@patch.object(SlackManager, 'send_message', new_callable=AsyncMock)
async def test_no_repo_mentioned_shows_button_and_dropdown(
self,
mock_send_message,
mock_sio,
mock_get_redis_client_async,
slack_manager,
slack_new_conversation_view,
):
@@ -149,7 +149,7 @@ class TestRepoVerificationHandling:
"""
# Setup Redis mock
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Setup: user message without any repo mention
slack_new_conversation_view.user_msg = 'Hello, can you help me?'
@@ -189,10 +189,10 @@ class TestRepoVerificationHandling:
assert elements[1].get('action_id').startswith('repository_select:')
@pytest.mark.asyncio
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_no_repository_button_click_processes_correctly(
self,
mock_sio,
mock_get_redis_client_async,
slack_manager,
):
"""Test that clicking 'No Repository' button correctly processes the interaction.
@@ -202,7 +202,7 @@ class TestRepoVerificationHandling:
"""
# Setup: Mock Redis to return a stored user message
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
stored_msg = json.dumps({'text': 'Hello, help me with code', 'user': 'U123'})
mock_redis.get = AsyncMock(return_value=stored_msg)
@@ -236,14 +236,14 @@ class TestRepoVerificationHandling:
assert call_args.message['message_ts'] == '1234567890.123456'
assert call_args.message['thread_ts'] is None
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
@patch('integrations.slack.slack_manager.ProviderHandler')
@patch.object(SlackManager, 'send_message', new_callable=AsyncMock)
async def test_verified_repo_starts_job(
self,
mock_send_message,
mock_provider_handler_class,
mock_sio,
mock_get_redis_client_async,
slack_manager,
slack_new_conversation_view,
):
@@ -251,7 +251,7 @@ class TestRepoVerificationHandling:
# Setup Redis mock
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Setup: Modify message to include exactly one repo reference
slack_new_conversation_view.user_msg = 'Help me with OpenHands/OpenHands repo'
@@ -532,13 +532,18 @@ class TestUserMsgStorage:
],
ids=['with_thread', 'without_thread', 'different_timestamps'],
)
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_store_user_msg_for_form(
self, mock_sio, slack_manager, message_ts, thread_ts, user_msg
self,
mock_get_redis_client_async,
slack_manager,
message_ts,
thread_ts,
user_msg,
):
"""Test storing user message in Redis with various timestamp combinations."""
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Should not raise an exception on success
await slack_manager._store_user_msg_for_form(message_ts, thread_ts, user_msg)
@@ -557,16 +562,16 @@ class TestUserMsgStorage:
],
ids=['connection_error', 'timeout_error', 'generic_exception'],
)
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_store_user_msg_for_form_redis_failure(
self, mock_sio, slack_manager, exception_type, exception_msg
self, mock_get_redis_client_async, slack_manager, exception_type, exception_msg
):
"""Test that Redis failures during store raise SlackError."""
from integrations.slack.slack_errors import SlackError, SlackErrorCode
mock_redis = AsyncMock()
mock_redis.set.side_effect = exception_type(exception_msg)
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
message_ts = '1234567890.123456'
thread_ts = '1234567890.111111'
@@ -591,14 +596,18 @@ class TestUserMsgStorage:
],
ids=['bytes_response', 'string_response'],
)
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_retrieve_user_msg_for_form(
self, mock_sio, slack_manager, redis_return_value, expected_result
self,
mock_get_redis_client_async,
slack_manager,
redis_return_value,
expected_result,
):
"""Test retrieving user message from Redis with various response types."""
mock_redis = AsyncMock()
mock_redis.get.return_value = redis_return_value
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
message_ts = '1234567890.123456'
thread_ts = '1234567890.111111'
@@ -609,16 +618,16 @@ class TestUserMsgStorage:
mock_redis.get.assert_called_once_with(expected_key)
assert result == expected_result
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_retrieve_user_msg_for_form_key_not_found(
self, mock_sio, slack_manager
self, mock_get_redis_client_async, slack_manager
):
"""Test that missing key raises SlackError with SESSION_EXPIRED."""
from integrations.slack.slack_errors import SlackError, SlackErrorCode
mock_redis = AsyncMock()
mock_redis.get.return_value = None
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
message_ts = '1234567890.123456'
thread_ts = '1234567890.111111'
@@ -637,16 +646,16 @@ class TestUserMsgStorage:
],
ids=['connection_error', 'timeout_error'],
)
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
async def test_retrieve_user_msg_for_form_redis_failure(
self, mock_sio, slack_manager, exception_type, exception_msg
self, mock_get_redis_client_async, slack_manager, exception_type, exception_msg
):
"""Test that Redis failures during retrieve raise SlackError."""
from integrations.slack.slack_errors import SlackError, SlackErrorCode
mock_redis = AsyncMock()
mock_redis.get.side_effect = exception_type(exception_msg)
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
message_ts = '1234567890.123456'
thread_ts = '1234567890.111111'
@@ -661,18 +670,18 @@ class TestUserMsgStorage:
class TestIsJobRequestedWithUserMsgStorage:
"""Test that is_job_requested properly stores user message for form flow."""
@patch('integrations.slack.slack_manager.sio')
@patch('integrations.slack.slack_manager.get_redis_client_async')
@patch.object(SlackManager, 'send_message', new_callable=AsyncMock)
async def test_stores_user_msg_when_showing_repo_selector(
self,
mock_send_message,
mock_sio,
mock_get_redis_client_async,
slack_manager,
slack_new_conversation_view,
):
"""Test that user_msg is stored in Redis when repo selector is shown."""
mock_redis = AsyncMock()
mock_sio.manager.redis = mock_redis
mock_get_redis_client_async.return_value = mock_redis
# Setup: user message without any repo mention (no repo inferred)
slack_new_conversation_view.user_msg = 'Hello, can you help me?'
@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from server.auth.token_manager import TokenManager, create_encryption_utility
from openhands.integrations.service_types import ProviderType
from openhands.app_server.integrations.service_types import ProviderType
@pytest.fixture
@@ -12,9 +12,9 @@ import pytest
from fastapi import HTTPException, status
from pydantic import SecretStr
from openhands.app_server.integrations.provider import ProviderToken
from openhands.app_server.integrations.service_types import ProviderType
from openhands.app_server.user.user_context import UserContext
from openhands.integrations.provider import ProviderToken
from openhands.integrations.service_types import ProviderType
def _make_user_context(provider_tokens, user_id: str = 'user-1') -> UserContext:
@@ -97,7 +97,7 @@ async def test_returns_organizations_for_supported_provider(
)
with patch(
'openhands.integrations.provider.ProviderHandler.get_service'
'openhands.app_server.integrations.provider.ProviderHandler.get_service'
) as mock_get_service:
mock_service = mock_get_service.return_value
setattr(mock_service, service_method, AsyncMock(return_value=service_return))
+16 -6
View File
@@ -1020,9 +1020,14 @@ def test_create_user_settings_from_entities_with_org_fallback():
@pytest.mark.asyncio
async def test_acquire_user_creation_lock_no_redis():
"""Test that _acquire_user_creation_lock returns True when Redis is unavailable."""
with patch.object(UserStore, '_get_redis_client', return_value=None):
async def test_acquire_user_creation_lock_redis_error():
"""Test that _acquire_user_creation_lock returns True when Redis has an error."""
from redis import exceptions as redis_exceptions
mock_redis = AsyncMock()
mock_redis.set.side_effect = redis_exceptions.RedisError('Connection refused')
with patch.object(UserStore, '_get_redis_client', return_value=mock_redis):
result = await UserStore._acquire_user_creation_lock('test-user-id')
assert result is True
@@ -1054,9 +1059,14 @@ async def test_acquire_user_creation_lock_not_acquired():
@pytest.mark.asyncio
async def test_release_user_creation_lock_no_redis():
"""Test that _release_user_creation_lock returns True when Redis is unavailable."""
with patch.object(UserStore, '_get_redis_client', return_value=None):
async def test_release_user_creation_lock_redis_error():
"""Test that _release_user_creation_lock returns True when Redis has an error."""
from redis import exceptions as redis_exceptions
mock_redis = AsyncMock()
mock_redis.delete.side_effect = redis_exceptions.RedisError('Connection refused')
with patch.object(UserStore, '_get_redis_client', return_value=mock_redis):
result = await UserStore._release_user_creation_lock('test-user-id')
assert result is True
@@ -15,11 +15,11 @@ from openhands.agent_server.utils import OpenHandsUUID, utc_now
from openhands.app_server.event_callback.event_callback_models import (
EventCallbackProcessor,
)
from openhands.app_server.integrations.service_types import ProviderType, SuggestedTask
from openhands.app_server.sandbox.sandbox_models import SandboxStatus
# Import from new location and re-export for backward compatibility
from openhands.app_server.settings.settings_models import SandboxGroupingStrategy
from openhands.integrations.service_types import ProviderType, SuggestedTask
from openhands.sdk.conversation import ConversationExecutionStatus
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.plugin import PluginSource
@@ -70,6 +70,8 @@ from openhands.app_server.event_callback.event_callback_service import (
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.app_server.integrations.service_types import SuggestedTask
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
)
@@ -92,8 +94,6 @@ from openhands.app_server.utils.llm_metadata import (
get_llm_metadata,
should_set_litellm_extra_body,
)
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.integrations.service_types import SuggestedTask
from openhands.sdk import Agent, AgentContext, LocalWorkspace
from openhands.sdk.agent.acp_agent import ACPAgent
from openhands.sdk.hooks import HookConfig
@@ -15,10 +15,10 @@ import logging
import httpx
from pydantic import BaseModel
from openhands.app_server.integrations.provider import ProviderType
from openhands.app_server.integrations.service_types import AuthenticationError
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.integrations.provider import ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.sdk.context.skills import KeywordTrigger, Skill, TaskTrigger
_logger = logging.getLogger(__name__)
@@ -48,13 +48,13 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationSortOrder,
ConversationTrigger,
)
from openhands.app_server.integrations.provider import ProviderType
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.utils.sql_utils import (
Base,
create_json_type_decorator,
)
from openhands.integrations.provider import ProviderType
from openhands.sdk import ConversationStats
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.sdk.llm import MetricsSnapshot, TokenUsage
@@ -0,0 +1,73 @@
"""Conversation path helpers for consistent path construction.
This module provides helper functions for constructing conversation-related
storage paths. Use these helpers instead of hardcoding path patterns to ensure
consistency across the codebase.
"""
from pathlib import Path
from uuid import UUID
# The base directory name for v1 conversation storage
V1_CONVERSATIONS_DIR = 'v1_conversations'
def get_conversation_dir(conversation_id: UUID | str) -> str:
"""Get the conversation directory path segment.
Args:
conversation_id: The conversation ID (UUID or hex string)
Returns:
Path segment like 'v1_conversations/{conversation_id_hex}'
Example:
>>> get_conversation_dir(UUID('12345678-1234-5678-1234-567812345678'))
'v1_conversations/12345678123456781234567812345678'
>>> get_conversation_dir('12345678123456781234567812345678')
'v1_conversations/12345678123456781234567812345678'
"""
if isinstance(conversation_id, UUID):
conversation_id_hex = conversation_id.hex
else:
# Already a hex string
conversation_id_hex = conversation_id
return f'{V1_CONVERSATIONS_DIR}/{conversation_id_hex}'
def get_conversation_path(
conversation_id: UUID | str,
user_id: str | None = None,
prefix: Path | str | None = None,
) -> Path:
"""Get the full conversation path.
Args:
conversation_id: The conversation ID (UUID or hex string)
user_id: Optional user ID to include in path
prefix: Optional prefix path
Returns:
Full path like '{prefix}/{user_id}/v1_conversations/{conversation_id_hex}'
Example:
>>> get_conversation_path(UUID('...'), user_id='user123', prefix=Path('/data'))
Path('/data/user123/v1_conversations/...')
"""
if isinstance(conversation_id, UUID):
conversation_id_hex = conversation_id.hex
else:
conversation_id_hex = conversation_id
parts: list[str] = []
if prefix:
parts.append(str(prefix))
if user_id:
parts.append(user_id)
parts.append(V1_CONVERSATIONS_DIR)
parts.append(conversation_id_hex)
return Path(*parts) if parts else Path(V1_CONVERSATIONS_DIR) / conversation_id_hex
@@ -12,6 +12,7 @@ from openhands.app_server.app_conversation.app_conversation_info_service import
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.conversation_paths import V1_CONVERSATIONS_DIR
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
@@ -60,7 +61,7 @@ class EventServiceBase(EventService, ABC):
conversation_info = await task
if conversation_info and conversation_info.created_by_user_id:
path /= conversation_info.created_by_user_id
path = path / 'v1_conversations' / conversation_id.hex
path = path / V1_CONVERSATIONS_DIR / conversation_id.hex
return path
async def get_event(self, conversation_id: UUID, event_id: UUID) -> Event | None:
@@ -34,6 +34,7 @@ from openhands.app_server.event_callback.event_callback_models import EventCallb
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.integrations.provider import ProviderType
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.services.jwt_service import JwtService
@@ -43,14 +44,13 @@ from openhands.app_server.user.specifiy_user_context import (
USER_CONTEXT_ATTR,
SpecifyUserContext,
)
from openhands.integrations.provider import ProviderType
from openhands.app_server.user_auth.default_user_auth import DefaultUserAuth
from openhands.app_server.user_auth.user_auth import (
get_for_user as get_user_auth_for_user,
)
from openhands.sdk import ConversationExecutionStatus, Event
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.server.types import AppMode
from openhands.server.user_auth.default_user_auth import DefaultUserAuth
from openhands.server.user_auth.user_auth import (
get_for_user as get_user_auth_for_user,
)
router = APIRouter(prefix='/webhooks', tags=['Webhooks'])
event_service_dependency = depends_event_service()
+77
View File
@@ -0,0 +1,77 @@
# OpenHands FileStore Module
The file store module provides different storage backends for file operations in OpenHands. This module implements a common interface (`FileStore`) that allows for interchangeable storage backends.
All FileStore implementations use `DiscriminatedUnionMixin` for automatic serialization/deserialization with a `kind` discriminator field based on the class name.
## Usage
```python
from openhands.app_server.file_store import get_file_store, LocalFileStore
# Using the factory function
store = get_file_store("local", "/tmp/file_store")
# Or instantiate directly
store = LocalFileStore(root="/tmp/file_store")
# Write, read, list, and delete operations
store.write("example.txt", "Hello, world!")
content = store.read("example.txt")
files = store.list("/")
store.delete("example.txt")
```
## Available Storage Backends
### 1. Local File Storage (`LocalFileStore`)
Local file storage saves files to the local filesystem.
**Parameters:**
- `root`: The root directory for file storage (supports `~` expansion)
### 2. In-Memory Storage (`InMemoryFileStore`)
In-memory storage keeps files in memory, useful for testing or temporary storage.
**Parameters:**
- `files`: Optional dictionary of initial files (default: empty)
### 3. Amazon S3 Storage (`S3FileStore`)
S3 storage uses Amazon S3 or compatible services for file storage.
**Parameters:**
- `bucket`: The S3 bucket name (falls back to `AWS_S3_BUCKET` environment variable if empty)
**Environment Variables:**
- `AWS_ACCESS_KEY_ID`: Your AWS access key
- `AWS_SECRET_ACCESS_KEY`: Your AWS secret key
- `AWS_S3_BUCKET`: Default bucket name (used if `bucket` parameter is empty)
- `AWS_S3_ENDPOINT`: Optional custom endpoint for S3-compatible services
- `AWS_S3_SECURE`: Whether to use HTTPS (default: "true")
### 4. Google Cloud Storage (`GoogleCloudFileStore`)
Google Cloud Storage uses GCS buckets for file storage.
**Parameters:**
- `bucket_name`: The GCS bucket name (falls back to `GOOGLE_CLOUD_BUCKET_NAME` environment variable if empty)
**Environment Variables:**
- `GOOGLE_CLOUD_BUCKET_NAME`: Default bucket name (used if `bucket_name` parameter is empty)
- `GOOGLE_APPLICATION_CREDENTIALS`: Path to Google Cloud credentials JSON file
## Configuration
To configure the file store in OpenHands, use the following configuration options:
```toml
[core]
# File store type: "local", "memory", "s3", "google_cloud"
file_store = "local"
# Path/bucket for file store (interpretation depends on file_store type)
file_store_path = "/tmp/file_store"
```
@@ -0,0 +1,21 @@
from openhands.app_server.file_store.files import FileStore
from openhands.app_server.file_store.google_cloud import GoogleCloudFileStore
from openhands.app_server.file_store.local import LocalFileStore
from openhands.app_server.file_store.memory import InMemoryFileStore
from openhands.app_server.file_store.s3 import S3FileStore
def get_file_store(
file_store_type: str,
file_store_path: str | None = None,
) -> FileStore:
if file_store_type == 'local':
if file_store_path is None:
raise ValueError('file_store_path is required for local file store')
return LocalFileStore(root=file_store_path)
elif file_store_type == 's3':
return S3FileStore(bucket_name=file_store_path or '')
elif file_store_type == 'google_cloud':
return GoogleCloudFileStore(bucket_name=file_store_path or '')
else:
return InMemoryFileStore()
+30
View File
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from pydantic import ConfigDict
from openhands.sdk.utils.models import DiscriminatedUnionMixin
class FileStore(DiscriminatedUnionMixin, ABC):
"""Base class for file storage implementations.
Uses DiscriminatedUnionMixin for automatic `kind` field based on class name.
"""
model_config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
@abstractmethod
def write(self, path: str, contents: str | bytes) -> None:
pass
@abstractmethod
def read(self, path: str) -> str:
pass
@abstractmethod
def list(self, path: str) -> list[str]:
pass
@abstractmethod
def delete(self, path: str) -> None:
pass
@@ -5,21 +5,44 @@ from google.cloud import storage
from google.cloud.storage.blob import Blob
from google.cloud.storage.bucket import Bucket
from google.cloud.storage.client import Client
from pydantic import Field, PrivateAttr
from openhands.storage.files import FileStore
from openhands.app_server.file_store.files import FileStore
class GoogleCloudFileStore(FileStore):
def __init__(self, bucket_name: str | None = None) -> None:
"""Create a new FileStore.
"""Google Cloud Storage file store.
If GOOGLE_APPLICATION_CREDENTIALS is defined in the environment it will be used
for authentication. Otherwise access will be anonymous.
"""
if bucket_name is None:
bucket_name = os.environ['GOOGLE_CLOUD_BUCKET_NAME']
self.storage_client: Client = storage.Client()
self.bucket: Bucket = self.storage_client.bucket(bucket_name)
If GOOGLE_APPLICATION_CREDENTIALS is defined in the environment it will be used
for authentication. Otherwise access will be anonymous.
The storage client and bucket are initialized lazily on first access.
"""
bucket_name: str = Field(default='')
_storage_client: Client | None = PrivateAttr(default=None)
_bucket: Bucket | None = PrivateAttr(default=None)
def _get_bucket_name(self) -> str:
"""Get bucket name, falling back to environment variable if not set."""
if self.bucket_name:
return self.bucket_name
return os.environ['GOOGLE_CLOUD_BUCKET_NAME']
@property
def storage_client(self) -> Client:
"""Get the storage client, initializing lazily on first access."""
if self._storage_client is None:
self._storage_client = storage.Client()
return self._storage_client
@property
def bucket(self) -> Bucket:
"""Get the bucket, initializing lazily on first access."""
if self._bucket is None:
self._bucket = self.storage_client.bucket(self._get_bucket_name())
return self._bucket
def write(self, path: str, contents: str | bytes) -> None:
blob: Blob = self.bucket.blob(path)
@@ -2,18 +2,21 @@ import os
import shutil
import threading
from pydantic import model_validator
from openhands.app_server.file_store.files import FileStore
from openhands.core.logger import openhands_logger as logger
from openhands.storage.files import FileStore
class LocalFileStore(FileStore):
root: str
def __init__(self, root: str):
if root.startswith('~'):
root = os.path.expanduser(root)
self.root = root
@model_validator(mode='after')
def _setup_root(self) -> 'LocalFileStore':
if self.root.startswith('~'):
self.root = os.path.expanduser(self.root)
os.makedirs(self.root, exist_ok=True)
return self
def get_full_path(self, path: str) -> str:
if path.startswith('/'):
@@ -1,16 +1,13 @@
import os
from pydantic import Field
from openhands.app_server.file_store.files import FileStore
from openhands.core.logger import openhands_logger as logger
from openhands.storage.files import FileStore
class InMemoryFileStore(FileStore):
files: dict[str, str]
def __init__(self, files: dict[str, str] | None = None) -> None:
self.files = {}
if files is not None:
self.files = files
files: dict[str, str] = Field(default_factory=dict)
def write(self, path: str, contents: str | bytes) -> None:
if isinstance(contents, bytes):
@@ -3,8 +3,9 @@ from typing import Any, TypedDict
import boto3
import botocore
from pydantic import Field, PrivateAttr
from openhands.storage.files import FileStore
from openhands.app_server.file_store.files import FileStore
class S3ObjectDict(TypedDict):
@@ -20,45 +21,64 @@ class ListObjectsV2OutputDict(TypedDict):
class S3FileStore(FileStore):
def __init__(self, bucket_name: str | None) -> None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
secure = os.getenv('AWS_S3_SECURE', 'true').lower() == 'true'
endpoint = self._ensure_url_scheme(secure, os.getenv('AWS_S3_ENDPOINT'))
if bucket_name is None:
bucket_name = os.environ['AWS_S3_BUCKET']
self.bucket: str = bucket_name
self.client: Any = boto3.client(
's3',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
endpoint_url=endpoint,
use_ssl=secure,
)
"""S3-compatible file store.
The S3 client is initialized lazily on first access.
"""
bucket_name: str = Field(default='')
_client: Any = PrivateAttr(default=None)
_resolved_bucket: str | None = PrivateAttr(default=None)
def _get_bucket_name(self) -> str:
"""Get bucket name, falling back to environment variable if not set."""
if self._resolved_bucket is None:
self._resolved_bucket = self.bucket_name or os.environ['AWS_S3_BUCKET']
return self._resolved_bucket
@property
def client(self) -> Any:
"""Get the S3 client, initializing lazily on first access."""
if self._client is None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
secure = os.getenv('AWS_S3_SECURE', 'true').lower() == 'true'
endpoint = self._ensure_url_scheme(secure, os.getenv('AWS_S3_ENDPOINT'))
self._client = boto3.client(
's3',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
endpoint_url=endpoint,
use_ssl=secure,
)
return self._client
def write(self, path: str, contents: str | bytes) -> None:
try:
as_bytes = (
contents.encode('utf-8') if isinstance(contents, str) else contents
)
self.client.put_object(Bucket=self.bucket, Key=path, Body=as_bytes)
self.client.put_object(
Bucket=self._get_bucket_name(), Key=path, Body=as_bytes
)
except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == 'AccessDenied':
raise FileNotFoundError(
f"Error: Access denied to bucket '{self.bucket}'."
f"Error: Access denied to bucket '{self._get_bucket_name()}'."
)
elif e.response['Error']['Code'] == 'NoSuchBucket':
raise FileNotFoundError(
f"Error: The bucket '{self.bucket}' does not exist."
f"Error: The bucket '{self._get_bucket_name()}' does not exist."
)
raise FileNotFoundError(
f"Error: Failed to write to bucket '{self.bucket}' at path {path}: {e}"
f"Error: Failed to write to bucket '{self._get_bucket_name()}' at path {path}: {e}"
)
def read(self, path: str) -> str:
try:
response: GetObjectOutputDict = self.client.get_object(
Bucket=self.bucket, Key=path
Bucket=self._get_bucket_name(), Key=path
)
with response['Body'] as stream:
return str(stream.read().decode('utf-8'))
@@ -66,19 +86,19 @@ class S3FileStore(FileStore):
# Catch all S3-related errors
if e.response['Error']['Code'] == 'NoSuchBucket':
raise FileNotFoundError(
f"Error: The bucket '{self.bucket}' does not exist."
f"Error: The bucket '{self._get_bucket_name()}' does not exist."
)
elif e.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError(
f"Error: The object key '{path}' does not exist in bucket '{self.bucket}'."
f"Error: The object key '{path}' does not exist in bucket '{self._get_bucket_name()}'."
)
else:
raise FileNotFoundError(
f"Error: Failed to read from bucket '{self.bucket}' at path {path}: {e}"
f"Error: Failed to read from bucket '{self._get_bucket_name()}' at path {path}: {e}"
)
except Exception as e:
raise FileNotFoundError(
f"Error: Failed to read from bucket '{self.bucket}' at path {path}: {e}"
f"Error: Failed to read from bucket '{self._get_bucket_name()}' at path {path}: {e}"
)
def list(self, path: str) -> list[str]:
@@ -96,7 +116,7 @@ class S3FileStore(FileStore):
results: set[str] = set()
prefix_len = len(path)
response: ListObjectsV2OutputDict = self.client.list_objects_v2(
Bucket=self.bucket, Prefix=path
Bucket=self._get_bucket_name(), Prefix=path
)
contents = response.get('Contents')
if not contents:
@@ -123,34 +143,36 @@ class S3FileStore(FileStore):
# Try to delete any child resources (Assume the path is a directory)
response = self.client.list_objects_v2(
Bucket=self.bucket, Prefix=f'{path}/'
Bucket=self._get_bucket_name(), Prefix=f'{path}/'
)
for content in response.get('Contents') or []:
self.client.delete_object(Bucket=self.bucket, Key=content['Key'])
self.client.delete_object(
Bucket=self._get_bucket_name(), Key=content['Key']
)
# Next try to delete item as a file
self.client.delete_object(Bucket=self.bucket, Key=path)
self.client.delete_object(Bucket=self._get_bucket_name(), Key=path)
except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == 'NoSuchBucket':
raise FileNotFoundError(
f"Error: The bucket '{self.bucket}' does not exist."
f"Error: The bucket '{self._get_bucket_name()}' does not exist."
)
elif e.response['Error']['Code'] == 'AccessDenied':
raise FileNotFoundError(
f"Error: Access denied to bucket '{self.bucket}'."
f"Error: Access denied to bucket '{self._get_bucket_name()}'."
)
elif e.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError(
f"Error: The object key '{path}' does not exist in bucket '{self.bucket}'."
f"Error: The object key '{path}' does not exist in bucket '{self._get_bucket_name()}'."
)
else:
raise FileNotFoundError(
f"Error: Failed to delete key '{path}' from bucket '{self.bucket}': {e}"
f"Error: Failed to delete key '{path}' from bucket '{self._get_bucket_name()}': {e}"
)
except Exception as e:
raise FileNotFoundError(
f"Error: Failed to delete key '{path}' from bucket '{self.bucket}: {e}"
f"Error: Failed to delete key '{path}' from bucket '{self._get_bucket_name()}: {e}"
)
def _ensure_url_scheme(self, secure: bool, url: str | None) -> str | None:

Some files were not shown because too many files have changed in this diff Show More