mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
[Refactor]: Add LLMRegistry for llm services (#9589)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Graham Neubig <neubig@gmail.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
import socketio
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
@@ -136,6 +137,16 @@ class ConversationManager(ABC):
|
||||
) -> list[AgentLoopInfo]:
|
||||
"""Get the AgentLoopInfo for conversations."""
|
||||
|
||||
@abstractmethod
async def request_llm_completion(
    self,
    sid: str,
    service_id: str,
    llm_config: LLMConfig,
    messages: list[dict[str, str]],
) -> str:
    """Run an extraneous (one-off) LLM completion for a conversation.

    Args:
        sid: ID of the conversation the completion is requested for.
        service_id: Identifier of the LLM service issuing the request.
        llm_config: LLM configuration to use for the completion.
        messages: Chat messages to send to the LLM.

    Returns:
        The completion result as a string.
    """
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_instance(
|
||||
|
||||
@@ -15,13 +15,13 @@ from docker.models.containers import Container
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.nested_event_store import NestedEventStore
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
@@ -42,6 +42,7 @@ from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_dir
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.utils import create_registry_and_convo_stats
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def request_llm_completion(
    self,
    sid: str,
    service_id: str,
    llm_config: LLMConfig,
    messages: list[dict[str, str]],
) -> str:
    """Reject extraneous LLM completion requests.

    This operation is not supported by this manager: clients should
    connect directly to the nested server instead.

    Raises:
        ValueError: Always, with the message ``'unsupported_operation'``.
    """
    raise ValueError('unsupported_operation')
|
||||
|
||||
async def send_event_to_conversation(self, sid, data):
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
@@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
# This session is created here only because it is the easiest way to get a runtime, which
|
||||
# is the easiest way to create the needed docker container
|
||||
|
||||
# Run experiment manager variant test before creating session
|
||||
config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test(
|
||||
user_id, sid, self.config
|
||||
)
|
||||
|
||||
llm_registry, convo_stats, config = create_registry_and_convo_stats(
|
||||
config, sid, user_id, settings
|
||||
)
|
||||
|
||||
session = Session(
|
||||
sid=sid,
|
||||
llm_registry=llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
file_store=self.file_store,
|
||||
config=config,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
llm_registry.retry_listner = session._notify_on_llm_retry
|
||||
agent_cls = settings.agent or config.default_agent
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
llm = LLM(
|
||||
config=config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=session._notify_on_llm_retry,
|
||||
)
|
||||
llm = session._create_llm(agent_cls)
|
||||
agent_config = config.get_agent_config(agent_cls)
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
agent = Agent.get_cls(agent_cls)(agent_config, llm_registry)
|
||||
|
||||
config = config.model_copy(deep=True)
|
||||
env_vars = config.sandbox.runtime_startup_env_vars
|
||||
@@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
main_module='openhands.server',
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
# Hack - disable setting initial env.
|
||||
|
||||
@@ -6,12 +6,14 @@ from typing import Callable, Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
@@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import (
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
from openhands.utils.utils import create_registry_and_convo_stats
|
||||
|
||||
from .conversation_manager import ConversationManager
|
||||
|
||||
@@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
await self.close_session(oldest_conversation_id)
|
||||
|
||||
config = self.config.model_copy(deep=True)
|
||||
|
||||
llm_registry, convo_stats, config = create_registry_and_convo_stats(
|
||||
self.config, sid, user_id, settings
|
||||
)
|
||||
session = Session(
|
||||
sid=sid,
|
||||
file_store=self.file_store,
|
||||
config=config,
|
||||
llm_registry=llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, sid, settings),
|
||||
self._create_conversation_update_callback(
|
||||
user_id, sid, settings, session.llm_registry
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager):
|
||||
raise RuntimeError(f'no_conversation:{sid}')
|
||||
await session.dispatch(data)
|
||||
|
||||
async def request_llm_completion(
    self,
    sid: str,
    service_id: str,
    llm_config: LLMConfig,
    messages: list[dict[str, str]],
) -> str:
    """Run an extraneous LLM completion through a local session's registry.

    Args:
        sid: ID of a conversation whose agent loop is running locally.
        service_id: Identifier of the LLM service issuing the request.
        llm_config: LLM configuration to use for the completion.
        messages: Chat messages to send to the LLM.

    Returns:
        Whatever the session's LLM registry returns for the completion.

    Raises:
        RuntimeError: If no local agent loop exists for ``sid``.
    """
    session = self._local_agent_loops_by_sid.get(sid)
    if not session:
        raise RuntimeError(f'no_conversation:{sid}')

    # NOTE(review): the registry call is made synchronously from this
    # coroutine; if it performs network I/O it will block the event loop -
    # confirm whether it should be off-loaded to a worker thread.
    return session.llm_registry.request_extraneous_completion(
        service_id, llm_config, messages
    )
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
|
||||
logger.info(
|
||||
@@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
llm_registry,
|
||||
event,
|
||||
)
|
||||
|
||||
@@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
event=None,
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
@@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.title == default_title
|
||||
): # attempt to autogenerate if default title is in use
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, self.file_store, settings
|
||||
conversation_id, user_id, self.file_store, settings, llm_registry
|
||||
)
|
||||
if title and not title.isspace():
|
||||
conversation.title = title
|
||||
|
||||
0
openhands/server/conversation_manager/utils.py
Normal file
0
openhands/server/conversation_manager/utils.py
Normal file
@@ -33,7 +33,6 @@ from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
@@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import (
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import (
|
||||
ConversationManagerImpl,
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
@@ -364,7 +364,7 @@ async def get_prompt(
|
||||
)
|
||||
|
||||
prompt_template = generate_prompt_template(stringified_events)
|
||||
prompt = generate_prompt(llm_config, prompt_template)
|
||||
prompt = generate_prompt(llm_config, prompt_template, conversation_id)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
@@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str:
|
||||
return template.render(events=events)
|
||||
|
||||
|
||||
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
|
||||
llm = LLM(llm_config)
|
||||
def generate_prompt(
|
||||
llm_config: LLMConfig, prompt_template: str, conversation_id: str
|
||||
) -> str:
|
||||
messages = [
|
||||
{
|
||||
'role': 'system',
|
||||
@@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
|
||||
},
|
||||
]
|
||||
|
||||
response = llm.completion(messages=messages)
|
||||
raw_prompt = response['choices'][0]['message']['content'].strip()
|
||||
raw_prompt = ConversationManagerImpl.request_llm_completion(
|
||||
'remember_prompt', conversation_id, llm_config, messages
|
||||
)
|
||||
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
|
||||
|
||||
if prompt:
|
||||
|
||||
@@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
|
||||
async def create_new_conversation(
|
||||
async def initialize_conversation(
    user_id: str | None,
    conversation_id: str | None,
    selected_repository: str | None,
    selected_branch: str | None,
    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
    git_provider: ProviderType | None = None,
) -> ConversationMetadata | None:
    """Create metadata for a new conversation, or load it for an existing one.

    Args:
        user_id: Owner of the conversation, if any.
        conversation_id: ID to use; a random hex UUID is generated when None.
        selected_repository: Repository recorded in newly created metadata.
        selected_branch: Branch recorded in newly created metadata.
        conversation_trigger: Trigger recorded in newly created metadata.
        git_provider: Git provider recorded in newly created metadata.

    Returns:
        The newly saved metadata when the conversation did not exist yet,
        the stored metadata when it did, or None when the stored metadata
        could not be loaded.
    """
    if conversation_id is None:
        conversation_id = uuid.uuid4().hex

    conversation_store = await ConversationStoreImpl.get_instance(config, user_id)

    if not await conversation_store.exists(conversation_id):
        logger.info(
            f'New conversation ID: {conversation_id}',
            extra={'user_id': user_id, 'session_id': conversation_id},
        )

        conversation_title = get_default_conversation_title(conversation_id)

        logger.info(f'Saving metadata for conversation {conversation_id}')
        convo_metadata = ConversationMetadata(
            trigger=conversation_trigger,
            conversation_id=conversation_id,
            title=conversation_title,
            user_id=user_id,
            selected_repository=selected_repository,
            selected_branch=selected_branch,
            git_provider=git_provider,
        )

        await conversation_store.save_metadata(convo_metadata)
        return convo_metadata

    try:
        return await conversation_store.get_metadata(conversation_id)
    except Exception:
        # Best effort: callers treat None as "could not initialize", but the
        # underlying cause is logged instead of being silently discarded.
        logger.warning(
            f'Failed to load metadata for conversation {conversation_id}',
            exc_info=True,
            extra={'user_id': user_id, 'session_id': conversation_id},
        )

    return None
|
||||
|
||||
|
||||
async def start_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
replay_json: str | None,
|
||||
conversation_instructions: str | None = None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_conversation_id: bool = False,
|
||||
git_provider: ProviderType | None = None,
|
||||
conversation_id: str | None = None,
|
||||
conversation_id: str,
|
||||
convo_metadata: ConversationMetadata,
|
||||
conversation_instructions: str | None,
|
||||
mcp_config: MCPConfig | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
@@ -52,7 +92,7 @@ async def create_new_conversation(
|
||||
extra={
|
||||
'signal': 'create_conversation',
|
||||
'user_id': user_id,
|
||||
'trigger': conversation_trigger.value,
|
||||
'trigger': convo_metadata.trigger,
|
||||
},
|
||||
)
|
||||
logger.info('Loading settings')
|
||||
@@ -79,53 +119,25 @@ async def create_new_conversation(
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['selected_repository'] = convo_metadata.selected_repository
|
||||
session_init_args['custom_secrets'] = custom_secrets
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
session_init_args['git_provider'] = git_provider
|
||||
session_init_args['selected_branch'] = convo_metadata.selected_branch
|
||||
session_init_args['git_provider'] = convo_metadata.git_provider
|
||||
session_init_args['conversation_instructions'] = conversation_instructions
|
||||
if mcp_config:
|
||||
session_init_args['mcp_config'] = mcp_config
|
||||
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('ServerConversation store loaded')
|
||||
|
||||
# For nested runtimes, we allow a single conversation id, passed in on container creation
|
||||
if conversation_id is None:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
|
||||
if not await conversation_store.exists(conversation_id):
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_init_data
|
||||
)
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
trigger=conversation_trigger,
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
git_provider=git_provider,
|
||||
llm_model=conversation_init_data.llm_model,
|
||||
)
|
||||
)
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_init_data
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Starting agent loop for conversation {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
initial_message_action = None
|
||||
if initial_user_msg or image_urls:
|
||||
initial_message_action = MessageAction(
|
||||
@@ -133,9 +145,6 @@ async def create_new_conversation(
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
|
||||
if attach_conversation_id:
|
||||
logger.warning('Attaching conversation ID is deprecated, skipping process')
|
||||
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
@@ -147,6 +156,47 @@ async def create_new_conversation(
|
||||
return agent_loop_info
|
||||
|
||||
|
||||
async def create_new_conversation(
    user_id: str | None,
    git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
    selected_repository: str | None,
    selected_branch: str | None,
    initial_user_msg: str | None,
    image_urls: list[str] | None,
    replay_json: str | None,
    conversation_instructions: str | None = None,
    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
    git_provider: ProviderType | None = None,
    conversation_id: str | None = None,
    mcp_config: MCPConfig | None = None,
) -> AgentLoopInfo:
    """Initialize a conversation's metadata and then start the conversation.

    This is a two-step wrapper: ``initialize_conversation`` produces the
    metadata, and ``start_conversation`` is then called with it.

    Returns:
        The result of ``start_conversation``.

    Raises:
        Exception: If ``initialize_conversation`` yields no metadata.
    """
    metadata = await initialize_conversation(
        user_id=user_id,
        conversation_id=conversation_id,
        selected_repository=selected_repository,
        selected_branch=selected_branch,
        conversation_trigger=conversation_trigger,
        git_provider=git_provider,
    )
    if not metadata:
        raise Exception('Failed to initialize conversation')

    agent_loop_info = await start_conversation(
        user_id,
        git_provider_tokens,
        custom_secrets,
        initial_user_msg,
        image_urls,
        replay_json,
        metadata.conversation_id,
        metadata,
        conversation_instructions,
        mcp_config,
    )
    return agent_loop_info
|
||||
|
||||
|
||||
def create_provider_tokens_object(
|
||||
providers_set: list[ProviderType],
|
||||
) -> PROVIDER_TOKEN_TYPE:
|
||||
|
||||
77
openhands/server/services/conversation_stats.py
Normal file
77
openhands/server/services/conversation_stats.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import base64
|
||||
import pickle
|
||||
from threading import Lock
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.llm.llm_registry import RegistryEvent
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_stats_filename
|
||||
|
||||
|
||||
class ConversationStats:
    """Tracks per-service LLM metrics for one conversation and persists them.

    Metrics are keyed by LLM service id. Metrics read back from storage are
    held in ``restored_metrics`` until an LLM with the same service id is
    registered, at which point they are handed to that LLM and tracked in
    ``service_to_metrics``.
    """

    def __init__(
        self,
        file_store: FileStore | None,
        conversation_id: str,
        user_id: str | None,
    ):
        """Set up tracking state and attempt to restore persisted metrics.

        Args:
            file_store: Storage backend; when None nothing is saved/restored.
            conversation_id: Conversation these metrics belong to.
            user_id: Owner of the conversation, if any.
        """
        self.metrics_path = get_conversation_stats_filename(conversation_id, user_id)
        self.file_store = file_store
        self.conversation_id = conversation_id
        self.user_id = user_id

        self._save_lock = Lock()

        # Metrics of LLMs registered during this object's lifetime.
        self.service_to_metrics: dict[str, Metrics] = {}
        # Metrics loaded from storage that no registered LLM has claimed yet.
        self.restored_metrics: dict[str, Metrics] = {}

        # Always attempt to restore registry if it exists
        self.maybe_restore_metrics()

    def save_metrics(self):
        """Write all known metrics to the file store (no-op without a store)."""
        if not self.file_store:
            return

        with self._save_lock:
            # Include restored-but-unclaimed metrics: writing only
            # service_to_metrics would erase them from storage on the first
            # save that happens before their LLM is registered again.
            # Active metrics win if a service id appears in both maps.
            combined_metrics = {**self.restored_metrics, **self.service_to_metrics}
            pickled = pickle.dumps(combined_metrics)
            serialized_metrics = base64.b64encode(pickled).decode('utf-8')
            self.file_store.write(self.metrics_path, serialized_metrics)

    def maybe_restore_metrics(self):
        """Load previously saved metrics into ``restored_metrics`` if present.

        Does nothing when there is no file store or conversation id, or when
        no metrics file exists.
        """
        if not self.file_store or not self.conversation_id:
            return

        try:
            encoded = self.file_store.read(self.metrics_path)
            pickled = base64.b64decode(encoded)
            # SECURITY NOTE(review): pickle.loads executes arbitrary code on
            # crafted input; this is only safe while the file store contents
            # are fully trusted - confirm, or move to a data-only format.
            self.restored_metrics = pickle.loads(pickled)
            logger.info(f'restored metrics: {self.conversation_id}')
        except FileNotFoundError:
            # No saved metrics for this conversation yet.
            pass

    def get_combined_metrics(self) -> Metrics:
        """Return a new ``Metrics`` merging every registered service's metrics."""
        total_metrics = Metrics()
        for metrics in self.service_to_metrics.values():
            total_metrics.merge(metrics)

        logger.info(f'metrics by all services: {self.service_to_metrics}')
        logger.info(f'combined metrics\n\n{total_metrics}')
        return total_metrics

    def get_metrics_for_service(self, service_id: str) -> Metrics:
        """Return the metrics tracked for ``service_id``.

        Raises:
            Exception: If no LLM has been registered under ``service_id``.
        """
        if service_id not in self.service_to_metrics:
            raise Exception(f'LLM service does not exist {service_id}')

        return self.service_to_metrics[service_id]

    def register_llm(self, event: RegistryEvent):
        """Start tracking the metrics of the LLM described by ``event``.

        If metrics were restored for the same service id, a copy of them is
        assigned to the LLM first and the restored entry is consumed.
        """
        # Listen for llm creations and track their metrics
        llm = event.llm
        service_id = event.service_id

        if service_id in self.restored_metrics:
            llm.metrics = self.restored_metrics.pop(service_id).copy()

        self.service_to_metrics[service_id] = llm.metrics
|
||||
@@ -21,6 +21,7 @@ from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
)
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroagent
|
||||
@@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
|
||||
@@ -48,6 +50,7 @@ class AgentSession:
|
||||
sid: str
|
||||
user_id: str | None
|
||||
event_stream: EventStream
|
||||
llm_registry: LLMRegistry
|
||||
file_store: FileStore
|
||||
controller: AgentController | None = None
|
||||
runtime: Runtime | None = None
|
||||
@@ -63,6 +66,8 @@ class AgentSession:
|
||||
self,
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
llm_registry: LLMRegistry,
|
||||
convo_stats: ConversationStats,
|
||||
status_callback: Callable | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
@@ -80,6 +85,8 @@ class AgentSession:
|
||||
self.logger = OpenHandsLoggerAdapter(
|
||||
extra={'session_id': sid, 'user_id': user_id}
|
||||
)
|
||||
self.llm_registry = llm_registry
|
||||
self.convo_stats = convo_stats
|
||||
|
||||
async def start(
|
||||
self,
|
||||
@@ -340,6 +347,7 @@ class AgentSession:
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
llm_registry=self.llm_registry,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
@@ -360,6 +368,7 @@ class AgentSession:
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
llm_registry=self.llm_registry,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
@@ -441,6 +450,7 @@ class AgentSession:
|
||||
user_id=self.user_id,
|
||||
file_store=self.file_store,
|
||||
event_stream=self.event_stream,
|
||||
convo_stats=self.convo_stats,
|
||||
agent=agent,
|
||||
iteration_delta=int(max_iterations),
|
||||
budget_per_task_delta=max_budget_per_task,
|
||||
@@ -490,6 +500,15 @@ class AgentSession:
|
||||
)
|
||||
return memory
|
||||
|
||||
def get_state(self) -> AgentState | None:
    """Report the agent state, or an error once start-up has taken too long.

    Returns:
        The controller's agent state when a controller exists;
        ``AgentState.ERROR`` when there is no controller and more than
        ``WAIT_TIME_BEFORE_CLOSE`` seconds have passed since
        ``self._started_at``; otherwise None.
    """
    # Read the attribute once so the check and the use see the same object.
    active_controller = self.controller
    if active_controller:
        return active_controller.state.agent_state

    # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
    timed_out = time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE
    return AgentState.ERROR if timed_out else None
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
"""Helper method to handle state restore logic."""
|
||||
restored_state = None
|
||||
@@ -510,14 +529,5 @@ class AgentSession:
|
||||
self.logger.debug('No events found, no state to restore')
|
||||
return restored_state
|
||||
|
||||
def get_state(self) -> AgentState | None:
|
||||
controller = self.controller
|
||||
if controller:
|
||||
return controller.state.agent_state
|
||||
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
|
||||
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
|
||||
return AgentState.ERROR
|
||||
return None
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
@@ -45,6 +46,7 @@ class ServerConversation:
|
||||
else:
|
||||
runtime_cls = get_runtime_cls(self.config.runtime)
|
||||
runtime = runtime_cls(
|
||||
llm_registry=LLMRegistry(self.config),
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from logging import LoggerAdapter
|
||||
|
||||
import socketio
|
||||
@@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -45,6 +45,7 @@ class Session:
|
||||
agent_session: AgentSession
|
||||
loop: asyncio.AbstractEventLoop
|
||||
config: OpenHandsConfig
|
||||
llm_registry: LLMRegistry
|
||||
file_store: FileStore
|
||||
user_id: str | None
|
||||
logger: LoggerAdapter
|
||||
@@ -53,6 +54,8 @@ class Session:
|
||||
self,
|
||||
sid: str,
|
||||
config: OpenHandsConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
convo_stats: ConversationStats,
|
||||
file_store: FileStore,
|
||||
sio: socketio.AsyncServer | None,
|
||||
user_id: str | None = None,
|
||||
@@ -62,17 +65,21 @@ class Session:
|
||||
self.last_active_ts = int(time.time())
|
||||
self.file_store = file_store
|
||||
self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid})
|
||||
self.llm_registry = llm_registry
|
||||
self.convo_stats = convo_stats
|
||||
self.agent_session = AgentSession(
|
||||
sid,
|
||||
file_store,
|
||||
llm_registry=self.llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
status_callback=self.queue_status_message,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
)
|
||||
# Copying this means that when we update variables they are not applied to the shared global configuration!
|
||||
self.config = deepcopy(config)
|
||||
self.config = config
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
|
||||
@@ -140,13 +147,6 @@ class Session:
|
||||
else self.config.max_budget_per_task
|
||||
)
|
||||
|
||||
# This is a shallow copy of the default LLM config, so changes here will
|
||||
# persist if we retrieve the default LLM config again when constructing
|
||||
# the agent
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = settings.llm_model or ''
|
||||
default_llm_config.api_key = settings.llm_api_key
|
||||
default_llm_config.base_url = settings.llm_base_url
|
||||
self.config.search_api_key = settings.search_api_key
|
||||
if settings.sandbox_api_key:
|
||||
self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
|
||||
@@ -181,10 +181,9 @@ class Session:
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = self._create_llm(agent_cls)
|
||||
agent_config = self.config.get_agent_config(agent_cls)
|
||||
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
llm_config = self.config.get_llm_config_from_agent(agent_name)
|
||||
if settings.enable_default_condenser:
|
||||
# Default condenser chains three condensers together:
|
||||
# 1. a conversation window condenser that handles explicit
|
||||
@@ -200,7 +199,7 @@ class Session:
|
||||
ConversationWindowCondenserConfig(),
|
||||
BrowserOutputCondenserConfig(attention_window=2),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=4, max_size=120
|
||||
llm_config=llm_config, keep_first=4, max_size=120
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -208,12 +207,14 @@ class Session:
|
||||
self.logger.info(
|
||||
f'Enabling pipeline condenser with:'
|
||||
f' browser_output_masking(attention_window=2), '
|
||||
f' llm(model="{llm.config.model}", '
|
||||
f' base_url="{llm.config.base_url}", '
|
||||
f' llm(model="{llm_config.model}", '
|
||||
f' base_url="{llm_config.base_url}", '
|
||||
f' keep_first=4, max_size=80)'
|
||||
)
|
||||
agent_config.condenser = default_condenser_config
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry)
|
||||
|
||||
self.llm_registry.retry_listner = self._notify_on_llm_retry
|
||||
|
||||
git_provider_tokens = None
|
||||
selected_repository = None
|
||||
@@ -269,14 +270,6 @@ class Session:
|
||||
)
|
||||
return
|
||||
|
||||
def _create_llm(self, agent_cls: str | None) -> LLM:
|
||||
"""Initialize LLM, extracted for testing."""
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
return LLM(
|
||||
config=self.config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=self._notify_on_llm_retry,
|
||||
)
|
||||
|
||||
def _notify_on_llm_retry(self, retries: int, max: int) -> None:
|
||||
self.queue_status_message(
|
||||
'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}'
|
||||
|
||||
Reference in New Issue
Block a user