Mirror of https://github.com/All-Hands-AI/OpenHands.git
feat(backend): New "update microagent prompt" API (#8357)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
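This change adds a GET /conversations/{conversation_id}/remember_prompt endpoint that takes an event_id query parameter, gathers the events around that id, and asks the configured LLM for a prompt that updates a microagent file. As a rough usage sketch (not part of the diff; the base URL and conversation id below are placeholders, and any auth the deployment requires is omitted):

import requests

# Placeholder values; adjust for the actual server URL, route prefix, and auth scheme.
BASE_URL = 'http://localhost:3000'
CONVERSATION_ID = 'abc123'
EVENT_ID = 42  # event to build the "remember" prompt around

resp = requests.get(
    f'{BASE_URL}/conversations/{CONVERSATION_ID}/remember_prompt',
    params={'event_id': EVENT_ID},
)
resp.raise_for_status()

# On success the endpoint returns {'status': 'success', 'prompt': '<generated prompt>'}
print(resp.json()['prompt'])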
@@ -1,11 +1,26 @@
import itertools
import re
import os
import uuid
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, Field

from openhands.events.event_filter import EventFilter
from openhands.events.stream import EventStream
from openhands.events.action import (
    ChangeAgentStateAction,
    NullAction,
)
from openhands.events.observation import (
    NullObservation,
    AgentStateChangedObservation,
)

from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import (
    PROVIDER_TOKEN_TYPE,
@@ -16,12 +31,15 @@ from openhands.integrations.service_types import (
    ProviderType,
    SuggestedTask,
)
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
    ConversationInfoResultSet,
)
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.dependencies import get_dependencies
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.shared import (
@@ -35,10 +53,11 @@ from openhands.server.user_auth import (
    get_provider_tokens,
    get_user_id,
    get_user_secrets,
    get_user_settings_store,
    get_user_settings,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
    ConversationMetadata,
@@ -47,6 +66,7 @@ from openhands.storage.data_models.conversation_metadata import (
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.data_models.settings import Settings
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import get_default_conversation_title

@@ -271,6 +291,76 @@ async def delete_conversation(
    return True


@app.get('/conversations/{conversation_id}/remember_prompt')
async def get_prompt(
    event_id: int,
    user_settings: SettingsStore = Depends(get_user_settings_store),
    conversation: ServerConversation | None = Depends(get_conversation_object),
):
    if conversation is None:
        return JSONResponse(
            status_code=404,
            content={'error': 'Conversation not found.'},
        )

    # get event stream for the conversation
    event_stream = conversation.event_stream

    # retrieve the relevant events
    stringified_events = _get_contextual_events(event_stream, event_id)

    # generate a prompt
    settings = await user_settings.load()
    if settings is None:
        # placeholder for error handling
        raise ValueError('Settings not found')

    llm_config = LLMConfig(
        model=settings.llm_model,
        api_key=settings.llm_api_key,
        base_url=settings.llm_base_url,
    )

    prompt_template = generate_prompt_template(stringified_events)
    prompt = generate_prompt(llm_config, prompt_template)

    return JSONResponse(
        {
            'status': 'success',
            'prompt': prompt,
        }
    )


def generate_prompt_template(events: str) -> str:
    env = Environment(loader=FileSystemLoader('openhands/microagent/prompts'))
    template = env.get_template('generate_remember_prompt.j2')
    return template.render(events=events)
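(Aside, not part of the diff: generate_prompt_template above renders generate_remember_prompt.j2 from openhands/microagent/prompts with the stringified events; the template's contents are not shown in this change. A minimal sketch of the same Jinja2 mechanism, using a made-up inline template via DictLoader instead of FileSystemLoader:)

from jinja2 import DictLoader, Environment

# Hypothetical stand-in for generate_remember_prompt.j2; the real template lives in the repo.
templates = {
    'generate_remember_prompt.j2': (
        'You are helping update a microagent file.\n'
        'Recent events:\n'
        '{{ events }}\n'
        'Wrap the suggested update in <update_prompt> tags.'
    )
}

env = Environment(loader=DictLoader(templates))
template = env.get_template('generate_remember_prompt.j2')
print(template.render(events='CmdRunAction: ls\nCmdOutputObservation: 3 files'))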
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
    llm = LLM(llm_config)
    messages = [
        {
            'role': 'system',
            'content': prompt_template,
        },
        {
            'role': 'user',
            'content': 'Please generate a prompt for the AI to update the special file based on the events provided.',
        },
    ]

    response = llm.completion(messages=messages)
    raw_prompt = response['choices'][0]['message']['content'].strip()
    prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)

    if prompt:
        return prompt.group(1).strip()
    else:
        raise ValueError('No valid prompt found in the response.')


async def _get_conversation_info(
    conversation: ConversationMetadata,
    num_connections: int,
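(Aside, not part of the diff: generate_prompt expects the model to wrap its answer in <update_prompt> tags, extracts the tagged span with re.search(..., re.DOTALL), and raises ValueError when the tags are missing. A self-contained illustration of that extraction step on a fabricated completion string:)

import re

# Fabricated LLM output for illustration only.
raw_prompt = (
    'Sure, here is the update you asked for.\n'
    '<update_prompt>\n'
    'Always run the test suite before committing changes.\n'
    '</update_prompt>\n'
    'Let me know if you need anything else.'
)

match = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
if match:
    print(match.group(1).strip())  # only the tagged section survives
else:
    raise ValueError('No valid prompt found in the response.')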
@@ -411,3 +501,38 @@ async def stop_conversation(
        },
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )


def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:
    # find the specified events to learn from
    # Get X events around the target event
    context_size = 4

    agent_event_filter = EventFilter(
        exclude_hidden=True,
        exclude_types=(
            NullAction,
            NullObservation,
            ChangeAgentStateAction,
            AgentStateChangedObservation,
        ),
    )  # the types of events that can be in an agent's history

    # from event_id - context_size to event_id..
    context_before = event_stream.search_events(
        start_id=event_id,
        filter=agent_event_filter,
        reverse=True,
        limit=context_size,
    )

    # from event_id to event_id + context_size + 1
    context_after = event_stream.search_events(
        start_id=event_id + 1,
        filter=agent_event_filter,
        limit=context_size + 1,
    )

    # context_before is in reverse chronological order, so convert to list and reverse it.
    ordered_context_before = list(context_before)
    ordered_context_before.reverse()

    all_events = itertools.chain(ordered_context_before, context_after)
    stringified_events = '\n'.join(str(event) for event in all_events)
    return stringified_events
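(Aside, not part of the diff: _get_contextual_events collects up to context_size events ending at the target event and up to context_size + 1 events after it, then chains the two windows back together in chronological order. A toy version over a plain list, with a simplified, hypothetical stand-in for EventStream.search_events, shows why the "before" window has to be reversed before chaining:)

import itertools

events = [f'event {i}' for i in range(20)]  # stand-in for the event stream


def search_events(start_id: int, limit: int, reverse: bool = False) -> list[str]:
    # Simplified stand-in: walk forward or backward from start_id, yielding up to `limit` events.
    if reverse:
        return events[start_id::-1][:limit]  # newest first
    return events[start_id:start_id + limit]


event_id, context_size = 10, 4
context_before = search_events(start_id=event_id, limit=context_size, reverse=True)
context_after = search_events(start_id=event_id + 1, limit=context_size + 1)

ordered_before = list(context_before)
ordered_before.reverse()  # restore chronological order

print('\n'.join(itertools.chain(ordered_before, context_after)))
# prints 'event 7' through 'event 15': the target event 10 plus context on both sides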