mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 22:35:54 -05:00
Compare commits
5 Commits
fix/code-r
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
179df7a726 | ||
|
|
6467f6734f | ||
|
|
5a30d11416 | ||
|
|
1f4105e8f9 | ||
|
|
caf9ff34e6 |
@@ -45,10 +45,7 @@ async def create_chat_session(
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(data=data)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
|
||||
@@ -266,12 +266,38 @@ async def stream_chat_post(
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
|
||||
# Base log metadata (task_id added after creation)
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
@@ -280,14 +306,46 @@ async def stream_chat_post(
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
import time as time_module
|
||||
|
||||
gen_start_time = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
logger.info(
|
||||
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
||||
* 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[TIMING] Calling stream_chat_completion",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
@@ -296,54 +354,202 @@ async def stream_chat_post(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
ttfc = first_chunk_time - gen_start_time
|
||||
logger.info(
|
||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"time_to_first_chunk_ms": ttfc * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"time_to_first_chunk_ms": (
|
||||
ttfc * 1000 if ttfc is not None else None
|
||||
),
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
elapsed = time_module.perf_counter() - gen_start_time
|
||||
logger.error(
|
||||
f"Error in background AI generation for session {session_id}: {e}"
|
||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscribe_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
"[TIMING] Calling subscribe_to_task",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
subscribe_time = (time_module.perf_counter() - subscribe_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] subscribe_to_task completed in {subscribe_time:.1f}ms, "
|
||||
f"queue_ok={subscriber_queue is not None}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": subscribe_time,
|
||||
"queue_obtained": subscriber_queue is not None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
logger.info(
|
||||
"[TIMING] subscriber_queue is None, yielding finish",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
"[TIMING] Starting to read from subscriber_queue",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
queue_wait_start = time_module.perf_counter()
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
queue_wait_time = (
|
||||
time_module.perf_counter() - queue_wait_start
|
||||
) * 1000
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
first_chunk_yielded = True
|
||||
elapsed = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||
f"type={type(chunk).__name__}, "
|
||||
f"wait={queue_wait_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"queue_wait_ms": queue_wait_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
elif chunks_yielded % 50 == 0:
|
||||
logger.info(
|
||||
f"[TIMING] Chunk #{chunks_yielded}, "
|
||||
f"type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_number": chunks_yielded,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||
f"n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"total_time_ms": total_time * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
logger.info(
|
||||
f"[TIMING] Heartbeat timeout, chunks_so_far={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "chunks_so_far": chunks_yielded}
|
||||
},
|
||||
)
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
@@ -357,6 +563,18 @@ async def stream_chat_post(
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
@@ -425,7 +643,7 @@ async def stream_chat_get(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"n_chunks": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid as uuid_module
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import openai
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
@@ -374,21 +371,45 @@ async def stream_chat_completion(
|
||||
ValueError: If max_context_messages is exceeded
|
||||
|
||||
"""
|
||||
completion_start = time.monotonic()
|
||||
|
||||
# Build log metadata for structured logging
|
||||
log_meta = {"component": "ChatService", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"message_len": len(message) if message else 0,
|
||||
"is_user_message": is_user_message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Only fetch from Redis if session not provided (initial call)
|
||||
if session is None:
|
||||
fetch_start = time.monotonic()
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||
logger.info(
|
||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
||||
f"message_count={len(session.messages) if session else 0}"
|
||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||
f"n_messages={len(session.messages) if session else 0}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": fetch_time,
|
||||
"n_messages": len(session.messages) if session else 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Using provided session object: {session.session_id}, "
|
||||
f"message_count={len(session.messages)}"
|
||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||
)
|
||||
|
||||
if not session:
|
||||
@@ -409,17 +430,25 @@ async def stream_chat_completion(
|
||||
|
||||
# Track user message in PostHog
|
||||
if is_user_message:
|
||||
posthog_start = time.monotonic()
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message),
|
||||
)
|
||||
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
||||
f"message_count={len(session.messages)}"
|
||||
)
|
||||
upsert_start = time.monotonic()
|
||||
session = await upsert_chat_session(session)
|
||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||
)
|
||||
assert session, "Session not found"
|
||||
|
||||
# Generate title for new sessions on first user message (non-blocking)
|
||||
@@ -457,7 +486,13 @@ async def stream_chat_completion(
|
||||
asyncio.create_task(_update_title())
|
||||
|
||||
# Build system prompt with business understanding
|
||||
prompt_start = time.monotonic()
|
||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||
)
|
||||
|
||||
# Initialize variables for streaming
|
||||
assistant_response = ChatMessage(
|
||||
@@ -480,13 +515,24 @@ async def stream_chat_completion(
|
||||
should_retry = False
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
setup_time = (time.monotonic() - completion_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"[TIMING] Calling _stream_chat_chunks",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
@@ -840,6 +886,10 @@ async def _manage_context_window(
|
||||
Returns:
|
||||
CompressResult with compacted messages and metadata
|
||||
"""
|
||||
import openai
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
# Convert messages to dict format
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
@@ -890,9 +940,21 @@ async def _stream_chat_chunks(
|
||||
SSE formatted JSON response objects
|
||||
|
||||
"""
|
||||
import time as time_module
|
||||
|
||||
stream_chunks_start = time_module.perf_counter()
|
||||
model = config.model
|
||||
|
||||
logger.info("Starting pure chat stream")
|
||||
# Build log metadata for structured logging
|
||||
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
||||
if session.user_id:
|
||||
log_meta["user_id"] = session.user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||
)
|
||||
|
||||
messages = session.to_openai_messages()
|
||||
if system_prompt:
|
||||
@@ -903,12 +965,18 @@ async def _stream_chat_chunks(
|
||||
messages = [system_message] + messages
|
||||
|
||||
# Apply context window management
|
||||
context_start = time_module.perf_counter()
|
||||
context_result = await _manage_context_window(
|
||||
messages=messages,
|
||||
model=model,
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
)
|
||||
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||
)
|
||||
|
||||
if context_result.error:
|
||||
if "System prompt dropped" in context_result.error:
|
||||
@@ -943,9 +1011,19 @@ async def _stream_chat_chunks(
|
||||
|
||||
while retry_count <= MAX_RETRIES:
|
||||
try:
|
||||
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||
retry_info = (
|
||||
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||
)
|
||||
logger.info(
|
||||
f"Creating OpenAI chat completion stream..."
|
||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
||||
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"retry_count": retry_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||
@@ -962,6 +1040,7 @@ async def _stream_chat_chunks(
|
||||
:128
|
||||
] # OpenRouter limit
|
||||
|
||||
api_call_start = time_module.perf_counter()
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||
@@ -971,6 +1050,11 @@ async def _stream_chat_chunks(
|
||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||
extra_body=extra_body,
|
||||
)
|
||||
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||
)
|
||||
|
||||
# Variables to accumulate tool calls
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
@@ -981,10 +1065,13 @@ async def _stream_chat_chunks(
|
||||
|
||||
# Track if we've started the text block
|
||||
text_started = False
|
||||
first_content_chunk = True
|
||||
chunk_count = 0
|
||||
|
||||
# Process the stream
|
||||
chunk: ChatCompletionChunk
|
||||
async for chunk in stream:
|
||||
chunk_count += 1
|
||||
if chunk.usage:
|
||||
yield StreamUsage(
|
||||
promptTokens=chunk.usage.prompt_tokens,
|
||||
@@ -1007,6 +1094,23 @@ async def _stream_chat_chunks(
|
||||
if not text_started and text_block_id:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
# Log timing for first content chunk
|
||||
if first_content_chunk:
|
||||
first_content_chunk = False
|
||||
ttfc = (
|
||||
time_module.perf_counter() - api_call_start
|
||||
) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||
f"(since API call), n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"time_to_first_chunk_ms": ttfc,
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Stream the text delta
|
||||
text_response = StreamTextDelta(
|
||||
id=text_block_id or "",
|
||||
@@ -1063,7 +1167,21 @@ async def _stream_chat_chunks(
|
||||
toolName=tool_calls[idx]["function"]["name"],
|
||||
)
|
||||
emitted_start_for_idx.add(idx)
|
||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||
stream_duration = time_module.perf_counter() - api_call_start
|
||||
logger.info(
|
||||
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||
f"duration={stream_duration:.2f}s, "
|
||||
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"stream_duration_ms": stream_duration * 1000,
|
||||
"finish_reason": finish_reason,
|
||||
"n_chunks": chunk_count,
|
||||
"n_tool_calls": len(tool_calls),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Yield all accumulated tool calls after the stream is complete
|
||||
# This ensures all tool call arguments have been fully received
|
||||
@@ -1083,6 +1201,12 @@ async def _stream_chat_chunks(
|
||||
# Re-raise to trigger retry logic in the parent function
|
||||
raise
|
||||
|
||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||
f"session={session.session_id}, user={session.user_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -1150,6 +1274,8 @@ async def _yield_tool_call(
|
||||
KeyError: If expected tool call fields are missing
|
||||
TypeError: If tool call structure is invalid
|
||||
"""
|
||||
import uuid as uuid_module
|
||||
|
||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||
tool_call_id = tool_calls[yield_idx]["id"]
|
||||
|
||||
@@ -1770,6 +1896,8 @@ async def _generate_llm_continuation_with_streaming(
|
||||
after a tool result is saved. Chunks are published to the stream registry
|
||||
so reconnecting clients can receive them.
|
||||
"""
|
||||
import uuid as uuid_module
|
||||
|
||||
try:
|
||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||
await invalidate_session_cache(session_id)
|
||||
@@ -1805,6 +1933,10 @@ async def _generate_llm_continuation_with_streaming(
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
# Make streaming LLM call (no tools - just text response)
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
@@ -104,6 +104,24 @@ async def create_task(
|
||||
Returns:
|
||||
The created ActiveTask instance (metadata only)
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build log metadata for structured logging
|
||||
log_meta = {
|
||||
"component": "StreamRegistry",
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
@@ -114,10 +132,18 @@ async def create_task(
|
||||
)
|
||||
|
||||
# Store metadata in Redis
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||
)
|
||||
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
hset_start = time.perf_counter()
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
@@ -131,12 +157,22 @@ async def create_task(
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||
)
|
||||
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@@ -156,26 +192,60 @@ async def publish_chunk(
|
||||
Returns:
|
||||
The Redis Stream message ID
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
chunk_type = type(chunk).__name__
|
||||
chunk_json = chunk.model_dump_json()
|
||||
message_id = "0-0"
|
||||
|
||||
# Build log metadata
|
||||
log_meta = {
|
||||
"component": "StreamRegistry",
|
||||
"task_id": task_id,
|
||||
"chunk_type": chunk_type,
|
||||
}
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Write to Redis Stream for persistence and real-time delivery
|
||||
xadd_start = time.perf_counter()
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Set TTL on stream to match task metadata TTL
|
||||
await redis.expire(stream_key, config.stream_ttl)
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
# Only log timing for significant chunks or slow operations
|
||||
if (
|
||||
chunk_type
|
||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||
or total_time > 50
|
||||
):
|
||||
logger.info(
|
||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"xadd_time_ms": xadd_time,
|
||||
"message_id": message_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Failed to publish chunk for task {task_id}: {e}",
|
||||
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -200,24 +270,61 @@ async def subscribe_to_task(
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build log metadata
|
||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||
)
|
||||
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
if not meta:
|
||||
logger.debug(f"Task {task_id} not found in Redis")
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"reason": "task_not_found",
|
||||
}
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
log_meta["session_id"] = meta.get("session_id", "")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id:
|
||||
if user_id != task_user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} denied access to task {task_id} "
|
||||
f"owned by {task_user_id}"
|
||||
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"task_owner": task_user_id,
|
||||
"reason": "access_denied",
|
||||
}
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -225,7 +332,19 @@ async def subscribe_to_task(
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
xread_start = time.perf_counter()
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": xread_time,
|
||||
"task_status": task_status,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
@@ -244,19 +363,48 @@ async def subscribe_to_task(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||
logger.info(
|
||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"n_messages_replayed": replayed_count,
|
||||
"replay_last_id": replay_last_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
logger.info(
|
||||
"[TIMING] Task still running, starting _stream_listener",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
listener_task = asyncio.create_task(
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
||||
)
|
||||
# Track listener task for cleanup on unsubscribe
|
||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
logger.info(
|
||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||
f"n_messages_replayed={replayed_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"n_messages_replayed": replayed_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
@@ -264,6 +412,7 @@ async def _stream_listener(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
log_meta: dict | None = None,
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
|
||||
@@ -274,10 +423,27 @@ async def _stream_listener(
|
||||
task_id: Task ID to listen for
|
||||
subscriber_queue: Queue to deliver messages to
|
||||
last_replayed_id: Last message ID from replay (continue from here)
|
||||
log_meta: Structured logging metadata
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Use provided log_meta or build minimal one
|
||||
if log_meta is None:
|
||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||
)
|
||||
|
||||
queue_id = id(subscriber_queue)
|
||||
# Track the last successfully delivered message ID for recovery hints
|
||||
last_delivered_id = last_replayed_id
|
||||
messages_delivered = 0
|
||||
first_message_time = None
|
||||
xread_count = 0
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
@@ -287,9 +453,39 @@ async def _stream_listener(
|
||||
while True:
|
||||
# Block for up to 30 seconds waiting for new messages
|
||||
# This allows periodic checking if task is still running
|
||||
xread_start = time.perf_counter()
|
||||
xread_count += 1
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=30000, count=100
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
|
||||
if messages:
|
||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||
logger.info(
|
||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"xread_count": xread_count,
|
||||
"n_messages": msg_count,
|
||||
"duration_ms": xread_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
elif xread_time > 1000:
|
||||
# Only log timeouts (30s blocking)
|
||||
logger.info(
|
||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"xread_count": xread_count,
|
||||
"duration_ms": xread_time,
|
||||
"reason": "timeout",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if task is still running
|
||||
@@ -326,10 +522,30 @@ async def _stream_listener(
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, "
|
||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
@@ -351,15 +567,44 @@ async def _stream_listener(
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing stream message: {e}")
|
||||
logger.warning(
|
||||
f"Error processing stream message: {e}",
|
||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"messages_delivered": messages_delivered,
|
||||
"reason": "cancelled",
|
||||
}
|
||||
},
|
||||
)
|
||||
raise # Re-raise to propagate cancellation
|
||||
except Exception as e:
|
||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||
)
|
||||
# On error, send finish to unblock subscriber
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
@@ -368,10 +613,24 @@ async def _stream_listener(
|
||||
)
|
||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||
logger.warning(
|
||||
f"Could not deliver finish event for task {task_id} after error"
|
||||
"Could not deliver finish event after error",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
finally:
|
||||
# Clean up listener task mapping on exit
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
"xread_count": xread_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
_listener_tasks.pop(queue_id, None)
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
@@ -30,26 +28,6 @@ from .models import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateAgentInput(BaseModel):
|
||||
"""Input parameters for the create_agent tool."""
|
||||
|
||||
description: str = ""
|
||||
context: str = ""
|
||||
save: bool = True
|
||||
# Internal async processing params (passed by long-running tool handler)
|
||||
_operation_id: str | None = None
|
||||
_task_id: str | None = None
|
||||
|
||||
@field_validator("description", "context", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow _operation_id, _task_id from kwargs
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@@ -107,7 +85,7 @@ class CreateAgentTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
@@ -116,14 +94,16 @@ class CreateAgentTool(BaseTool):
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
params = CreateAgentInput(**kwargs)
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not params.description:
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
@@ -135,7 +115,7 @@ class CreateAgentTool(BaseTool):
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=params.description,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
@@ -146,7 +126,7 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
params.description, params.context, library_agents
|
||||
description, context, library_agents
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -162,7 +142,7 @@ class CreateAgentTool(BaseTool):
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||
error="decomposition_failed",
|
||||
details={"description": params.description[:100]},
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -178,7 +158,7 @@ class CreateAgentTool(BaseTool):
|
||||
message=user_message,
|
||||
error=f"decomposition_failed:{error_type}",
|
||||
details={
|
||||
"description": params.description[:100],
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
@@ -264,7 +244,7 @@ class CreateAgentTool(BaseTool):
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||
error="generation_failed",
|
||||
details={"description": params.description[:100]},
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -286,7 +266,7 @@ class CreateAgentTool(BaseTool):
|
||||
message=user_message,
|
||||
error=f"generation_failed:{error_type}",
|
||||
details={
|
||||
"description": params.description[:100],
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
@@ -311,7 +291,7 @@ class CreateAgentTool(BaseTool):
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
if not params.save:
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
@@ -29,23 +27,6 @@ from .models import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentInput(BaseModel):
|
||||
"""Input parameters for the customize_agent tool."""
|
||||
|
||||
agent_id: str = ""
|
||||
modifications: str = ""
|
||||
context: str = ""
|
||||
save: bool = True
|
||||
|
||||
@field_validator("agent_id", "modifications", "context", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str:
|
||||
"""Strip whitespace from string fields."""
|
||||
if isinstance(v, str):
|
||||
return v.strip()
|
||||
return v if v is not None else ""
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@@ -111,7 +92,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
@@ -121,17 +102,20 @@ class CustomizeAgentTool(BaseTool):
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
params = CustomizeAgentInput(**kwargs)
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not params.agent_id:
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.modifications:
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
@@ -139,11 +123,11 @@ class CustomizeAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in params.agent_id.split("/")]
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{params.agent_id}'. "
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
@@ -161,14 +145,14 @@ class CustomizeAgentTool(BaseTool):
|
||||
except AgentNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{params.agent_id}'. "
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
@@ -178,7 +162,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{params.agent_id}' does not have an available version. "
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
@@ -190,7 +174,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
@@ -201,8 +185,8 @@ class CustomizeAgentTool(BaseTool):
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=params.modifications,
|
||||
context=params.context,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -214,7 +198,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
@@ -235,25 +219,55 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle response using match/case for cleaner pattern matching
|
||||
return await self._handle_customization_result(
|
||||
result=result,
|
||||
params=params,
|
||||
agent_details=agent_details,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _handle_customization_result(
|
||||
self,
|
||||
result: dict[str, Any],
|
||||
params: CustomizeAgentInput,
|
||||
agent_details: Any,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Handle the result from customize_template using pattern matching."""
|
||||
# Ensure result is a dict
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
@@ -262,77 +276,8 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result_type = result.get("type")
|
||||
customized_agent = result
|
||||
|
||||
match result_type:
|
||||
case "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
case "clarifying_questions":
|
||||
questions_data = result.get("questions") or []
|
||||
if not isinstance(questions_data, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions_data)}"
|
||||
)
|
||||
questions_data = []
|
||||
|
||||
questions = [
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", "") if isinstance(q, dict) else "",
|
||||
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
|
||||
example=q.get("example") if isinstance(q, dict) else None,
|
||||
)
|
||||
for q in questions_data
|
||||
if isinstance(q, dict)
|
||||
]
|
||||
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=questions,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
case _:
|
||||
# Default case: result is the customized agent JSON
|
||||
return await self._save_or_preview_agent(
|
||||
customized_agent=result,
|
||||
params=params,
|
||||
agent_details=agent_details,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _save_or_preview_agent(
|
||||
self,
|
||||
customized_agent: dict[str, Any],
|
||||
params: CustomizeAgentInput,
|
||||
agent_details: Any,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Save or preview the customized agent based on params.save."""
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
@@ -342,7 +287,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not params.save:
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
@@ -29,20 +27,6 @@ from .models import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditAgentInput(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
agent_id: str = ""
|
||||
changes: str = ""
|
||||
context: str = ""
|
||||
save: bool = True
|
||||
|
||||
@field_validator("agent_id", "changes", "context", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str:
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@@ -106,7 +90,7 @@ class EditAgentTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
@@ -115,32 +99,35 @@ class EditAgentTool(BaseTool):
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
params = EditAgentInput(**kwargs)
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not params.agent_id:
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.changes:
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
current_agent = await get_agent_as_json(params.agent_id, user_id)
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent '{params.agent_id}' in your library.",
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -151,7 +138,7 @@ class EditAgentTool(BaseTool):
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=params.changes,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
)
|
||||
@@ -161,11 +148,9 @@ class EditAgentTool(BaseTool):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
update_request = params.changes
|
||||
if params.context:
|
||||
update_request = (
|
||||
f"{params.changes}\n\nAdditional context:\n{params.context}"
|
||||
)
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
@@ -189,7 +174,7 @@ class EditAgentTool(BaseTool):
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
error="update_generation_failed",
|
||||
details={"agent_id": params.agent_id, "changes": params.changes[:100]},
|
||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -221,8 +206,8 @@ class EditAgentTool(BaseTool):
|
||||
message=user_message,
|
||||
error=f"update_generation_failed:{error_type}",
|
||||
details={
|
||||
"agent_id": params.agent_id,
|
||||
"changes": params.changes[:100],
|
||||
"agent_id": agent_id,
|
||||
"changes": changes[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
@@ -254,7 +239,7 @@ class EditAgentTool(BaseTool):
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
if not params.save:
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. "
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -11,18 +9,6 @@ from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindAgentInput(BaseModel):
|
||||
"""Input parameters for the find_agent tool."""
|
||||
|
||||
query: str = ""
|
||||
|
||||
@field_validator("query", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, v: Any) -> str:
|
||||
"""Strip whitespace from query."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents from the marketplace."""
|
||||
|
||||
@@ -50,11 +36,10 @@ class FindAgentTool(BaseTool):
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs: Any
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
params = FindAgentInput(**kwargs)
|
||||
return await search_agents(
|
||||
query=params.query,
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
@@ -41,18 +40,6 @@ COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
}
|
||||
|
||||
|
||||
class FindBlockInput(BaseModel):
|
||||
"""Input parameters for the find_block tool."""
|
||||
|
||||
query: str = ""
|
||||
|
||||
@field_validator("query", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, v: Any) -> str:
|
||||
"""Strip whitespace from query."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@@ -94,24 +81,24 @@ class FindBlockTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
**kwargs: Tool parameters
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
params = FindBlockInput(**kwargs)
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not params.query:
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
@@ -120,7 +107,7 @@ class FindBlockTool(BaseTool):
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
query=params.query,
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=_OVERFETCH_PAGE_SIZE,
|
||||
@@ -128,7 +115,7 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{params.query}'",
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
"Check spelling of technical terms",
|
||||
@@ -230,7 +217,7 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{params.query}'",
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
],
|
||||
@@ -239,13 +226,13 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block(s) matching '{params.query}'. "
|
||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||
"To execute a block, use run_block with the block's 'id' field "
|
||||
"and provide 'input_data' matching the block's input_schema."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=params.query,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -11,15 +9,6 @@ from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindLibraryAgentInput(BaseModel):
|
||||
query: str = ""
|
||||
|
||||
@field_validator("query", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, v: Any) -> str:
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@@ -53,11 +42,10 @@ class FindLibraryAgentTool(BaseTool):
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs: Any
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
params = FindLibraryAgentInput(**kwargs)
|
||||
return await search_agents(
|
||||
query=params.query,
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
|
||||
@@ -4,8 +4,6 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
@@ -20,18 +18,6 @@ logger = logging.getLogger(__name__)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
|
||||
class GetDocPageInput(BaseModel):
|
||||
"""Input parameters for the get_doc_page tool."""
|
||||
|
||||
path: str = ""
|
||||
|
||||
@field_validator("path", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, v: Any) -> str:
|
||||
"""Strip whitespace from path."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class GetDocPageTool(BaseTool):
|
||||
"""Tool for fetching full content of a documentation page."""
|
||||
|
||||
@@ -89,23 +75,23 @@ class GetDocPageTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Fetch full content of a documentation page.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
**kwargs: Tool parameters
|
||||
path: Path to the documentation file
|
||||
|
||||
Returns:
|
||||
DocPageResponse: Full document content
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
params = GetDocPageInput(**kwargs)
|
||||
path = kwargs.get("path", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not params.path:
|
||||
if not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide a documentation path.",
|
||||
error="Missing path parameter",
|
||||
@@ -113,7 +99,7 @@ class GetDocPageTool(BaseTool):
|
||||
)
|
||||
|
||||
# Sanitize path to prevent directory traversal
|
||||
if ".." in params.path or params.path.startswith("/"):
|
||||
if ".." in path or path.startswith("/"):
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
@@ -121,11 +107,11 @@ class GetDocPageTool(BaseTool):
|
||||
)
|
||||
|
||||
docs_root = self._get_docs_root()
|
||||
full_path = docs_root / params.path
|
||||
full_path = docs_root / path
|
||||
|
||||
if not full_path.exists():
|
||||
return ErrorResponse(
|
||||
message=f"Documentation page not found: {params.path}",
|
||||
message=f"Documentation page not found: {path}",
|
||||
error="not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -142,19 +128,19 @@ class GetDocPageTool(BaseTool):
|
||||
|
||||
try:
|
||||
content = full_path.read_text(encoding="utf-8")
|
||||
title = self._extract_title(content, params.path)
|
||||
title = self._extract_title(content, path)
|
||||
|
||||
return DocPageResponse(
|
||||
message=f"Retrieved documentation page: {title}",
|
||||
title=title,
|
||||
path=params.path,
|
||||
path=path,
|
||||
content=content,
|
||||
doc_url=self._make_doc_url(params.path),
|
||||
doc_url=self._make_doc_url(path),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read documentation page {params.path}: {e}")
|
||||
logger.error(f"Failed to read documentation page {path}: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read documentation page: {str(e)}",
|
||||
error="read_failed",
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Extract input field info from JSON schema."""
|
||||
if not isinstance(input_schema, dict):
|
||||
return []
|
||||
|
||||
exclude = exclude_fields or set()
|
||||
properties = input_schema.get("properties", {})
|
||||
required = set(input_schema.get("required", []))
|
||||
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
"title": schema.get("title", name),
|
||||
"type": schema.get("type", "string"),
|
||||
"description": schema.get("description", ""),
|
||||
"required": name in required,
|
||||
"default": schema.get("default"),
|
||||
}
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
),
|
||||
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract inputs list from schema."""
|
||||
inputs_list = []
|
||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||
for field_name, field_schema in input_schema["properties"].items():
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in input_schema.get("required", []),
|
||||
}
|
||||
)
|
||||
return inputs_list
|
||||
|
||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||
"""Get available execution modes for the graph."""
|
||||
trigger_info = graph.trigger_setup_info
|
||||
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
|
||||
suffix: str,
|
||||
) -> str:
|
||||
"""Build a message describing available inputs for an agent."""
|
||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -13,14 +12,15 @@ from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.data.block import get_block
|
||||
from backend.data.block import AnyBlockSchema, get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
@@ -29,30 +29,14 @@ from .models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockInput(BaseModel):
|
||||
"""Input parameters for the run_block tool."""
|
||||
|
||||
block_id: str = ""
|
||||
input_data: dict[str, Any] = {}
|
||||
|
||||
@field_validator("block_id", mode="before")
|
||||
@classmethod
|
||||
def strip_block_id(cls, v: Any) -> str:
|
||||
"""Strip whitespace from block_id."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
@field_validator("input_data", mode="before")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v: Any) -> dict[str, Any]:
|
||||
"""Ensure input_data is a dict."""
|
||||
return v if isinstance(v, dict) else {}
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@@ -97,118 +81,41 @@ class RunBlockTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to check credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
input_data = input_data or {}
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
# Get discriminator from input, falling back to schema default
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in effective_field_info.provider
|
||||
and cred.type in effective_field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(effective_field_info.provider), "unknown")
|
||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
**kwargs: Tool parameters
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
params = RunBlockInput(**kwargs)
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not params.block_id:
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
@@ -216,15 +123,15 @@ class RunBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(params.block_id)
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{params.block_id}' not found",
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
if block.disabled:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{params.block_id}' is disabled",
|
||||
message=f"Block '{block_id}' is disabled",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -244,8 +151,8 @@ class RunBlockTool(BaseTool):
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block, params.input_data
|
||||
matched_credentials, missing_credentials = (
|
||||
await self._resolve_block_credentials(user_id, block, input_data)
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
@@ -263,7 +170,7 @@ class RunBlockTool(BaseTool):
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=params.block_id,
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
@@ -292,7 +199,7 @@ class RunBlockTool(BaseTool):
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{params.block_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
@@ -327,8 +234,8 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in params.input_data:
|
||||
params.input_data[field_name] = cred_meta.model_dump()
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
@@ -345,14 +252,14 @@ class RunBlockTool(BaseTool):
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
params.input_data,
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=params.block_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
@@ -374,29 +281,75 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
async def _resolve_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to resolve credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||
are used for block execution, missing ones indicate setup requirements.
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
def _resolve_discriminated_credentials(
|
||||
self,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
return inputs_list
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
# For host-scoped credentials, add the discriminator value
|
||||
# (e.g., URL) so _credential_is_for_host can match it
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
@@ -29,18 +28,6 @@ MAX_RESULTS = 5
|
||||
SNIPPET_LENGTH = 200
|
||||
|
||||
|
||||
class SearchDocsInput(BaseModel):
|
||||
"""Input parameters for the search_docs tool."""
|
||||
|
||||
query: str = ""
|
||||
|
||||
@field_validator("query", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, v: Any) -> str:
|
||||
"""Strip whitespace from query."""
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@@ -104,24 +91,24 @@ class SearchDocsTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation and return relevant sections.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
**kwargs: Tool parameters
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
params = SearchDocsInput(**kwargs)
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not params.query:
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query.",
|
||||
error="Missing query parameter",
|
||||
@@ -131,7 +118,7 @@ class SearchDocsTool(BaseTool):
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await unified_hybrid_search(
|
||||
query=params.query,
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
||||
@@ -140,7 +127,7 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{params.query}'.",
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
@@ -175,7 +162,7 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
if not deduplicated:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{params.query}'.",
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
@@ -208,7 +195,7 @@ class SearchDocsTool(BaseTool):
|
||||
message=f"Found {len(doc_results)} relevant documentation sections.",
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=params.query,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
@@ -223,6 +224,99 @@ async def get_or_create_library_agent(
|
||||
return library_agents[0]
|
||||
|
||||
|
||||
async def match_credentials_to_requirements(
|
||||
user_id: str,
|
||||
requirements: dict[str, CredentialsFieldInfo],
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Match user's credentials against a dictionary of credential requirements.
|
||||
|
||||
This is the core matching logic shared by both graph and block credential matching.
|
||||
"""
|
||||
matched: dict[str, CredentialsMetaInput] = {}
|
||||
missing: list[CredentialsMetaInput] = []
|
||||
|
||||
if not requirements:
|
||||
return matched, missing
|
||||
|
||||
available_creds = await get_user_credentials(user_id)
|
||||
|
||||
for field_name, field_info in requirements.items():
|
||||
matching_cred = find_matching_credential(available_creds, field_info)
|
||||
|
||||
if matching_cred:
|
||||
try:
|
||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||
f"credential_id={matching_cred.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=f"{field_name} (validation failed: {e})",
|
||||
)
|
||||
)
|
||||
else:
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched, missing
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||
"""Get all available credentials for a user."""
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
return await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
|
||||
def find_matching_credential(
|
||||
available_creds: list[Credentials],
|
||||
field_info: CredentialsFieldInfo,
|
||||
) -> Credentials | None:
|
||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||
for cred in available_creds:
|
||||
if cred.provider not in field_info.provider:
|
||||
continue
|
||||
if cred.type not in field_info.supported_types:
|
||||
continue
|
||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||
cred, field_info
|
||||
):
|
||||
continue
|
||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||
continue
|
||||
return cred
|
||||
return None
|
||||
|
||||
|
||||
def create_credential_meta_from_match(
|
||||
matching_cred: Credentials,
|
||||
) -> CredentialsMetaInput:
|
||||
"""Create a CredentialsMetaInput from a matched credential."""
|
||||
return CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
|
||||
|
||||
async def match_user_credentials_to_graph(
|
||||
user_id: str,
|
||||
graph: GraphModel,
|
||||
@@ -331,8 +425,6 @@ def _credential_has_required_scopes(
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
@@ -78,65 +78,6 @@ class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
success: bool
|
||||
|
||||
|
||||
# Input models for workspace tools
|
||||
class ListWorkspaceFilesInput(BaseModel):
|
||||
"""Input parameters for list_workspace_files tool."""
|
||||
|
||||
path_prefix: str | None = None
|
||||
limit: int = 50
|
||||
include_all_sessions: bool = False
|
||||
|
||||
@field_validator("path_prefix", mode="before")
|
||||
@classmethod
|
||||
def strip_path(cls, v: Any) -> str | None:
|
||||
return v.strip() if isinstance(v, str) else None
|
||||
|
||||
|
||||
class ReadWorkspaceFileInput(BaseModel):
|
||||
"""Input parameters for read_workspace_file tool."""
|
||||
|
||||
file_id: str | None = None
|
||||
path: str | None = None
|
||||
force_download_url: bool = False
|
||||
|
||||
@field_validator("file_id", "path", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str | None:
|
||||
return v.strip() if isinstance(v, str) else None
|
||||
|
||||
|
||||
class WriteWorkspaceFileInput(BaseModel):
|
||||
"""Input parameters for write_workspace_file tool."""
|
||||
|
||||
filename: str = ""
|
||||
content_base64: str = ""
|
||||
path: str | None = None
|
||||
mime_type: str | None = None
|
||||
overwrite: bool = False
|
||||
|
||||
@field_validator("filename", "content_base64", mode="before")
|
||||
@classmethod
|
||||
def strip_required(cls, v: Any) -> str:
|
||||
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||
|
||||
@field_validator("path", "mime_type", mode="before")
|
||||
@classmethod
|
||||
def strip_optional(cls, v: Any) -> str | None:
|
||||
return v.strip() if isinstance(v, str) else None
|
||||
|
||||
|
||||
class DeleteWorkspaceFileInput(BaseModel):
|
||||
"""Input parameters for delete_workspace_file tool."""
|
||||
|
||||
file_id: str | None = None
|
||||
path: str | None = None
|
||||
|
||||
@field_validator("file_id", "path", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str | None:
|
||||
return v.strip() if isinstance(v, str) else None
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@@ -190,9 +131,8 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
params = ListWorkspaceFilesInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
@@ -201,7 +141,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
limit = min(params.limit, 100)
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
@@ -209,13 +151,13 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
files = await manager.list_files(
|
||||
path=params.path_prefix,
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=params.include_all_sessions,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=params.path_prefix,
|
||||
include_all_sessions=params.include_all_sessions,
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
@@ -229,9 +171,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
for f in files
|
||||
]
|
||||
|
||||
scope_msg = (
|
||||
"all sessions" if params.include_all_sessions else "current session"
|
||||
)
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
@@ -319,9 +259,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
params = ReadWorkspaceFileInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
@@ -330,7 +269,11 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.file_id and not params.path:
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
@@ -342,21 +285,21 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Get file info
|
||||
if params.file_id:
|
||||
file_info = await manager.get_file_info(params.file_id)
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {params.file_id}",
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = params.file_id
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert params.path is not None
|
||||
file_info = await manager.get_file_info_by_path(params.path)
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {params.path}",
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
@@ -366,7 +309,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||
|
||||
# Return inline content for small text files (unless force_download_url)
|
||||
if is_small_file and is_text_file and not params.force_download_url:
|
||||
if is_small_file and is_text_file and not force_download_url:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
@@ -486,9 +429,8 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
params = WriteWorkspaceFileInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
@@ -497,13 +439,19 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.filename:
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.content_base64:
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
@@ -511,7 +459,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(params.content_base64)
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
@@ -528,7 +476,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=params.filename)
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
@@ -536,10 +484,10 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
file_record = await manager.write_file(
|
||||
content=content,
|
||||
filename=params.filename,
|
||||
path=params.path,
|
||||
mime_type=params.mime_type,
|
||||
overwrite=params.overwrite,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
@@ -609,9 +557,8 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
params = DeleteWorkspaceFileInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
@@ -620,7 +567,10 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not params.file_id and not params.path:
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
@@ -633,15 +583,15 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if params.file_id:
|
||||
target_file_id = params.file_id
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert params.path is not None
|
||||
file_info = await manager.get_file_info_by_path(params.path)
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {params.path}",
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
class AsyncRabbitMQ(RabbitMQBase):
|
||||
"""Asynchronous RabbitMQ client"""
|
||||
|
||||
def __init__(self, config: RabbitMQConfig):
|
||||
super().__init__(config)
|
||||
self._reconnect_lock: asyncio.Lock | None = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return bool(self._connection and not self._connection.is_closed)
|
||||
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
|
||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||
async def connect(self):
|
||||
if self.is_connected:
|
||||
if self.is_connected and self._channel and not self._channel.is_closed:
|
||||
return
|
||||
|
||||
if (
|
||||
self.is_connected
|
||||
and self._connection
|
||||
and (self._channel is None or self._channel.is_closed)
|
||||
):
|
||||
self._channel = await self._connection.channel()
|
||||
await self._channel.set_qos(prefetch_count=1)
|
||||
await self.declare_infrastructure()
|
||||
return
|
||||
|
||||
self._connection = await aio_pika.connect_robust(
|
||||
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
@property
|
||||
def _lock(self) -> asyncio.Lock:
|
||||
if self._reconnect_lock is None:
|
||||
self._reconnect_lock = asyncio.Lock()
|
||||
return self._reconnect_lock
|
||||
|
||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
"""Get a valid channel, reconnecting if the current one is stale.
|
||||
|
||||
Uses a lock to prevent concurrent reconnection attempts from racing.
|
||||
"""
|
||||
if self.is_ready:
|
||||
return self._channel # type: ignore # is_ready guarantees non-None
|
||||
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self.is_ready:
|
||||
return self._channel # type: ignore
|
||||
|
||||
self._channel = None
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
|
||||
return self._channel
|
||||
|
||||
async def _publish_once(
|
||||
self,
|
||||
routing_key: str,
|
||||
message: str,
|
||||
exchange: Optional[Exchange] = None,
|
||||
persistent: bool = True,
|
||||
) -> None:
|
||||
if not self.is_ready:
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
channel = await self._ensure_channel()
|
||||
|
||||
if exchange:
|
||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||
exchange_obj = await channel.get_exchange(exchange.name)
|
||||
else:
|
||||
exchange_obj = self._channel.default_exchange
|
||||
exchange_obj = channel.default_exchange
|
||||
|
||||
await exchange_obj.publish(
|
||||
aio_pika.Message(
|
||||
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=routing_key,
|
||||
)
|
||||
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
message: str,
|
||||
exchange: Optional[Exchange] = None,
|
||||
persistent: bool = True,
|
||||
) -> None:
|
||||
try:
|
||||
await self._publish_once(routing_key, message, exchange, persistent)
|
||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
||||
logger.warning(
|
||||
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
||||
)
|
||||
async with self._lock:
|
||||
self._channel = None
|
||||
await self._publish_once(routing_key, message, exchange, persistent)
|
||||
|
||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
if not self.is_ready:
|
||||
await self.connect()
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
return self._channel
|
||||
return await self._ensure_channel()
|
||||
|
||||
99
autogpt_platform/backend/poetry.lock
generated
99
autogpt_platform/backend/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aio-pika"
|
||||
@@ -374,7 +374,7 @@ description = "LTS Port of Python audioop"
|
||||
optional = false
|
||||
python-versions = ">=3.13"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd3d4602dc64914d462924a08c1a9816435a2155d74f325853c1f1ac3b2d9800"},
|
||||
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:550c114a8df0aafe9a05442a1162dfc8fec37e9af1d625ae6060fed6e756f303"},
|
||||
@@ -474,7 +474,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["main"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -487,7 +487,7 @@ description = "Backport of CPython tarfile module"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "python_version <= \"3.11\""
|
||||
markers = "python_version < \"3.12\""
|
||||
files = [
|
||||
{file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"},
|
||||
{file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"},
|
||||
@@ -563,6 +563,18 @@ webencodings = "*"
|
||||
[package.extras]
|
||||
css = ["tinycss2 (>=1.1.0,<1.5)"]
|
||||
|
||||
[[package]]
|
||||
name = "bracex"
|
||||
version = "2.6"
|
||||
description = "Bash style brace expander."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952"},
|
||||
{file = "bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "browserbase"
|
||||
version = "1.4.0"
|
||||
@@ -659,7 +671,6 @@ description = "Foreign Function Interface for Python calling C code."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
markers = "platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\""
|
||||
files = [
|
||||
{file = "cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44"},
|
||||
{file = "cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49"},
|
||||
@@ -1149,6 +1160,18 @@ idna = ["idna (>=3.10)"]
|
||||
trio = ["trio (>=0.30)"]
|
||||
wmi = ["wmi (>=1.5.1) ; platform_system == \"Windows\""]
|
||||
|
||||
[[package]]
|
||||
name = "dockerfile-parse"
|
||||
version = "2.0.1"
|
||||
description = "Python library for Dockerfile manipulation"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "dockerfile-parse-2.0.1.tar.gz", hash = "sha256:3184ccdc513221983e503ac00e1aa504a2aa8f84e5de673c46b0b6eee99ec7bc"},
|
||||
{file = "dockerfile_parse-2.0.1-py2.py3-none-any.whl", hash = "sha256:bdffd126d2eb26acf1066acb54cb2e336682e1d72b974a40894fac76a4df17f6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.17.0"
|
||||
@@ -1235,40 +1258,43 @@ pgp = ["gpg"]
|
||||
|
||||
[[package]]
|
||||
name = "e2b"
|
||||
version = "1.11.1"
|
||||
version = "2.13.2"
|
||||
description = "E2B SDK that give agents cloud environments"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b-1.11.1-py3-none-any.whl", hash = "sha256:1ecb123873788472731c101939a494ab852cbcce0f913df6f7ecb194ae932130"},
|
||||
{file = "e2b-1.11.1.tar.gz", hash = "sha256:7f7b6f238208d0a23353bb0da01f91a924321b57c61b176506862cbc1493ce8c"},
|
||||
{file = "e2b-2.13.2-py3-none-any.whl", hash = "sha256:d91d5293bc0dd1917c72a6e6b35e86513607be2666a14ae18c57b921e7864de4"},
|
||||
{file = "e2b-2.13.2.tar.gz", hash = "sha256:c0e81a3920091874fdf73c0b8f376b28766212db9f1cea5d8bd56a2e95d2436c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=23.2.0"
|
||||
dockerfile-parse = ">=2.0.1,<3.0.0"
|
||||
httpcore = ">=1.0.5,<2.0.0"
|
||||
httpx = ">=0.27.0,<1.0.0"
|
||||
packaging = ">=24.1"
|
||||
protobuf = ">=4.21.0"
|
||||
python-dateutil = ">=2.8.2"
|
||||
rich = ">=14.0.0"
|
||||
typing-extensions = ">=4.1.0"
|
||||
wcmatch = ">=10.1,<11.0"
|
||||
|
||||
[[package]]
|
||||
name = "e2b-code-interpreter"
|
||||
version = "1.5.2"
|
||||
version = "2.4.1"
|
||||
description = "E2B Code Interpreter - Stateful code execution"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3"},
|
||||
{file = "e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547"},
|
||||
{file = "e2b_code_interpreter-2.4.1-py3-none-any.whl", hash = "sha256:15d35f025b4a15033e119f2e12e7ac65657ad2b5a013fa9149e74581fbee778a"},
|
||||
{file = "e2b_code_interpreter-2.4.1.tar.gz", hash = "sha256:4b15014ee0d0dfcdc3072e1f409cbb87ca48f48d53d75629b7257e5513b9e7dd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=21.3.0"
|
||||
e2b = ">=1.5.4,<2.0.0"
|
||||
e2b = ">=2.7.0,<3.0.0"
|
||||
httpx = ">=0.20.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
@@ -1338,7 +1364,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"},
|
||||
{file = "exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219"},
|
||||
@@ -1820,16 +1846,16 @@ files = [
|
||||
google-auth = ">=2.14.1,<3.0.0"
|
||||
googleapis-common-protos = ">=1.56.2,<2.0.0"
|
||||
grpcio = [
|
||||
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
|
||||
]
|
||||
grpcio-status = [
|
||||
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
|
||||
]
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
requests = ">=2.18.0,<3.0.0"
|
||||
@@ -1940,8 +1966,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
|
||||
grpcio = ">=1.33.2,<2.0.0"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
|
||||
@@ -2001,9 +2027,9 @@ google-cloud-core = ">=2.0.0,<3.0.0"
|
||||
grpc-google-iam-v1 = ">=0.12.4,<1.0.0"
|
||||
opentelemetry-api = ">=1.9.0"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.0,<2.0.0"},
|
||||
{version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\" and python_version < \"3.13\""},
|
||||
{version = ">=1.22.0,<2.0.0", markers = "python_version < \"3.11\""},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
|
||||
@@ -3802,7 +3828,7 @@ description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "numpy-2.2.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb"},
|
||||
{file = "numpy-2.2.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90"},
|
||||
@@ -4287,9 +4313,9 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@@ -4532,8 +4558,8 @@ pinecone-plugin-interface = ">=0.0.7,<0.0.8"
|
||||
python-dateutil = ">=2.5.3"
|
||||
typing-extensions = ">=3.7.4"
|
||||
urllib3 = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -5361,7 +5387,7 @@ description = "C parser in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
markers = "(platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\") and implementation_name != \"PyPy\""
|
||||
markers = "implementation_name != \"PyPy\""
|
||||
files = [
|
||||
{file = "pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992"},
|
||||
{file = "pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29"},
|
||||
@@ -6130,10 +6156,10 @@ files = [
|
||||
grpcio = ">=1.41.0"
|
||||
httpx = {version = ">=0.20.0", extras = ["http2"]}
|
||||
numpy = [
|
||||
{version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
|
||||
{version = ">=1.21", markers = "python_version == \"3.11\""},
|
||||
{version = ">=2.1.0", markers = "python_version == \"3.13\""},
|
||||
{version = ">=1.21", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26", markers = "python_version == \"3.12\""},
|
||||
{version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
|
||||
]
|
||||
portalocker = ">=2.7.0,<4.0"
|
||||
protobuf = ">=3.20.0"
|
||||
@@ -7317,7 +7343,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867"},
|
||||
{file = "tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9"},
|
||||
@@ -7841,6 +7867,21 @@ files = [
|
||||
[package.dependencies]
|
||||
anyio = ">=3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "wcmatch"
|
||||
version = "10.1"
|
||||
description = "Wildcard/glob file name matcher."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a"},
|
||||
{file = "wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
bracex = ">=2.1.1"
|
||||
|
||||
[[package]]
|
||||
name = "webencodings"
|
||||
version = "0.5.1"
|
||||
@@ -8440,4 +8481,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "14686ee0e2dc446a75d0db145b08dc410dc31c357e25085bb0f9b0174711c4b1"
|
||||
content-hash = "f1f229017d133bab1cb5b787a93f8d6d652c2712e07d4966358e725a57e35e80"
|
||||
|
||||
@@ -19,7 +19,7 @@ bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
e2b-code-interpreter = "^2.4.1"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.128.5"
|
||||
feedparser = "^6.0.11"
|
||||
|
||||
@@ -104,7 +104,31 @@ export function FileInput(props: Props) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const getFileLabelFromValue = (val: string) => {
|
||||
const getFileLabelFromValue = (val: unknown): string => {
|
||||
// Handle object format from external API: { name, type, size, data }
|
||||
if (val && typeof val === "object") {
|
||||
const obj = val as Record<string, unknown>;
|
||||
if (typeof obj.name === "string") {
|
||||
return getFileLabel(
|
||||
obj.name,
|
||||
typeof obj.type === "string" ? obj.type : "",
|
||||
);
|
||||
}
|
||||
if (typeof obj.type === "string") {
|
||||
const mimeParts = obj.type.split("/");
|
||||
if (mimeParts.length > 1) {
|
||||
return `${mimeParts[1].toUpperCase()} file`;
|
||||
}
|
||||
return `${obj.type} file`;
|
||||
}
|
||||
return "File";
|
||||
}
|
||||
|
||||
// Handle string values (data URIs or file paths)
|
||||
if (typeof val !== "string") {
|
||||
return "File";
|
||||
}
|
||||
|
||||
if (val.startsWith("data:")) {
|
||||
const matches = val.match(/^data:([^;]+);/);
|
||||
if (matches?.[1]) {
|
||||
|
||||
Reference in New Issue
Block a user