diff --git a/autogpt_platform/backend/backend/api/features/chat/db.py b/autogpt_platform/backend/backend/api/features/chat/db.py index d34b4e5b07..303ea0a698 100644 --- a/autogpt_platform/backend/backend/api/features/chat/db.py +++ b/autogpt_platform/backend/backend/api/features/chat/db.py @@ -45,10 +45,7 @@ async def create_chat_session( successfulAgentRuns=SafeJson({}), successfulAgentSchedules=SafeJson({}), ) - return await PrismaChatSession.prisma().create( - data=data, - include={"Messages": True}, - ) + return await PrismaChatSession.prisma().create(data=data) async def update_chat_session( diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index e4b679b286..e10d9a3fd6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -377,21 +377,45 @@ async def stream_chat_completion( ValueError: If max_context_messages is exceeded """ + completion_start = time.monotonic() + + # Build log metadata for structured logging + log_meta = {"component": "ChatService", "session_id": session_id} + if user_id: + log_meta["user_id"] = user_id + logger.info( - f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}" + f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, " + f"message_len={len(message) if message else 0}, is_user={is_user_message}", + extra={ + "json_fields": { + **log_meta, + "message_len": len(message) if message else 0, + "is_user_message": is_user_message, + } + }, ) # Only fetch from Redis if session not provided (initial call) if session is None: + fetch_start = time.monotonic() session = await get_chat_session(session_id, user_id) + fetch_time = (time.monotonic() - fetch_start) * 1000 logger.info( - f"Fetched session from Redis: {session.session_id if session else 'None'}, " - f"message_count={len(session.messages) if session else 0}" + f"[TIMING] get_chat_session took {fetch_time:.1f}ms, " + f"n_messages={len(session.messages) if session else 0}", + extra={ + "json_fields": { + **log_meta, + "duration_ms": fetch_time, + "n_messages": len(session.messages) if session else 0, + } + }, ) else: logger.info( - f"Using provided session object: {session.session_id}, " - f"message_count={len(session.messages)}" + f"[TIMING] Using provided session, messages={len(session.messages)}", + extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}}, ) if not session: @@ -412,17 +436,25 @@ async def stream_chat_completion( # Track user message in PostHog if is_user_message: + posthog_start = time.monotonic() track_user_message( user_id=user_id, session_id=session_id, message_length=len(message), ) + posthog_time = (time.monotonic() - posthog_start) * 1000 + logger.info( + f"[TIMING] track_user_message took {posthog_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": posthog_time}}, + ) - logger.info( - f"Upserting session: {session.session_id} with user id {session.user_id}, " - f"message_count={len(session.messages)}" - ) + upsert_start = time.monotonic() session = await upsert_chat_session(session) + upsert_time = (time.monotonic() - upsert_start) * 1000 + logger.info( + f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": upsert_time}}, + ) assert session, "Session not found" # Generate title for new sessions on first user 
message (non-blocking) @@ -460,7 +492,13 @@ async def stream_chat_completion( asyncio.create_task(_update_title()) # Build system prompt with business understanding + prompt_start = time.monotonic() system_prompt, understanding = await _build_system_prompt(user_id) + prompt_time = (time.monotonic() - prompt_start) * 1000 + logger.info( + f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": prompt_time}}, + ) # Initialize variables for streaming assistant_response = ChatMessage( @@ -490,7 +528,11 @@ async def stream_chat_completion( text_block_id = str(uuid_module.uuid4()) # Only yield message start for the initial call, not for continuations. - # This is the single place where StreamStart is emitted (removed from routes.py). + setup_time = (time.monotonic() - completion_start) * 1000 + logger.info( + f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms", + extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, + ) if not is_continuation: yield StreamStart(messageId=message_id, taskId=_task_id) @@ -498,6 +540,10 @@ async def stream_chat_completion( yield StreamStartStep() try: + logger.info( + "[TIMING] Calling _stream_chat_chunks", + extra={"json_fields": log_meta}, + ) async for chunk in _stream_chat_chunks( session=session, tools=tools, @@ -916,9 +962,21 @@ async def _stream_chat_chunks( SSE formatted JSON response objects """ + import time as time_module + + stream_chunks_start = time_module.perf_counter() model = config.model - logger.info("Starting pure chat stream") + # Build log metadata for structured logging + log_meta = {"component": "ChatService", "session_id": session.session_id} + if session.user_id: + log_meta["user_id"] = session.user_id + + logger.info( + f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, " + f"user={session.user_id}, n_messages={len(session.messages)}", + extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}}, + ) messages = session.to_openai_messages() if system_prompt: @@ -929,12 +987,18 @@ async def _stream_chat_chunks( messages = [system_message] + messages # Apply context window management + context_start = time_module.perf_counter() context_result = await _manage_context_window( messages=messages, model=model, api_key=config.api_key, base_url=config.base_url, ) + context_time = (time_module.perf_counter() - context_start) * 1000 + logger.info( + f"[TIMING] _manage_context_window took {context_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": context_time}}, + ) if context_result.error: if "System prompt dropped" in context_result.error: @@ -969,9 +1033,19 @@ async def _stream_chat_chunks( while retry_count <= MAX_RETRIES: try: + elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000 + retry_info = ( + f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else "" + ) logger.info( - f"Creating OpenAI chat completion stream..." 
- f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}" + f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": elapsed, + "retry_count": retry_count, + } + }, ) # Build extra_body for OpenRouter tracing and PostHog analytics @@ -988,6 +1062,7 @@ async def _stream_chat_chunks( :128 ] # OpenRouter limit + api_call_start = time_module.perf_counter() stream = await client.chat.completions.create( model=model, messages=cast(list[ChatCompletionMessageParam], messages), @@ -997,6 +1072,11 @@ async def _stream_chat_chunks( stream_options=ChatCompletionStreamOptionsParam(include_usage=True), extra_body=extra_body, ) + api_init_time = (time_module.perf_counter() - api_call_start) * 1000 + logger.info( + f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": api_init_time}}, + ) # Variables to accumulate tool calls tool_calls: list[dict[str, Any]] = [] @@ -1007,10 +1087,13 @@ async def _stream_chat_chunks( # Track if we've started the text block text_started = False + first_content_chunk = True + chunk_count = 0 # Process the stream chunk: ChatCompletionChunk async for chunk in stream: + chunk_count += 1 if chunk.usage: yield StreamUsage( promptTokens=chunk.usage.prompt_tokens, @@ -1033,6 +1116,23 @@ async def _stream_chat_chunks( if not text_started and text_block_id: yield StreamTextStart(id=text_block_id) text_started = True + # Log timing for first content chunk + if first_content_chunk: + first_content_chunk = False + ttfc = ( + time_module.perf_counter() - api_call_start + ) * 1000 + logger.info( + f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms " + f"(since API call), n_chunks={chunk_count}", + extra={ + "json_fields": { + **log_meta, + "time_to_first_chunk_ms": ttfc, + "n_chunks": chunk_count, + } + }, + ) # Stream the text delta text_response = StreamTextDelta( id=text_block_id or "", @@ -1089,7 +1189,21 @@ async def _stream_chat_chunks( toolName=tool_calls[idx]["function"]["name"], ) emitted_start_for_idx.add(idx) - logger.info(f"Stream complete. 
Finish reason: {finish_reason}") + stream_duration = time_module.perf_counter() - api_call_start + logger.info( + f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, " + f"duration={stream_duration:.2f}s, " + f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}", + extra={ + "json_fields": { + **log_meta, + "stream_duration_ms": stream_duration * 1000, + "finish_reason": finish_reason, + "n_chunks": chunk_count, + "n_tool_calls": len(tool_calls), + } + }, + ) # Yield all accumulated tool calls after the stream is complete # This ensures all tool call arguments have been fully received @@ -1109,6 +1223,12 @@ async def _stream_chat_chunks( # Re-raise to trigger retry logic in the parent function raise + total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 + logger.info( + f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; " + f"session={session.session_id}, user={session.user_id}", + extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, + ) yield StreamFinish() return except Exception as e: diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index 739ccdfe4b..abc34b1fc9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -104,6 +104,24 @@ async def create_task( Returns: The created ActiveTask instance (metadata only) """ + import time + + start_time = time.perf_counter() + + # Build log metadata for structured logging + log_meta = { + "component": "StreamRegistry", + "task_id": task_id, + "session_id": session_id, + } + if user_id: + log_meta["user_id"] = user_id + + logger.info( + f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}", + extra={"json_fields": log_meta}, + ) + task = ActiveTask( task_id=task_id, session_id=session_id, @@ -114,10 +132,18 @@ async def create_task( ) # Store metadata in Redis + redis_start = time.perf_counter() redis = await get_redis_async() + redis_time = (time.perf_counter() - redis_start) * 1000 + logger.info( + f"[TIMING] get_redis_async took {redis_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": redis_time}}, + ) + meta_key = _get_task_meta_key(task_id) op_key = _get_operation_mapping_key(operation_id) + hset_start = time.perf_counter() await redis.hset( # type: ignore[misc] meta_key, mapping={ @@ -131,12 +157,22 @@ async def create_task( "created_at": task.created_at.isoformat(), }, ) + hset_time = (time.perf_counter() - hset_start) * 1000 + logger.info( + f"[TIMING] redis.hset took {hset_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": hset_time}}, + ) + await redis.expire(meta_key, config.stream_ttl) # Create operation_id -> task_id mapping for webhook lookups await redis.set(op_key, task_id, ex=config.stream_ttl) - logger.debug(f"Created task {task_id} for session {session_id}") + total_time = (time.perf_counter() - start_time) * 1000 + logger.info( + f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}", + extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, + ) return task @@ -156,26 +192,60 @@ async def publish_chunk( Returns: The Redis Stream message ID """ + import time + + start_time = time.perf_counter() + chunk_type = type(chunk).__name__ chunk_json = chunk.model_dump_json() message_id = "0-0" + # Build log metadata + log_meta = { + "component": 
"StreamRegistry", + "task_id": task_id, + "chunk_type": chunk_type, + } + try: redis = await get_redis_async() stream_key = _get_task_stream_key(task_id) # Write to Redis Stream for persistence and real-time delivery + xadd_start = time.perf_counter() raw_id = await redis.xadd( stream_key, {"data": chunk_json}, maxlen=config.stream_max_length, ) + xadd_time = (time.perf_counter() - xadd_start) * 1000 message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() # Set TTL on stream to match task metadata TTL await redis.expire(stream_key, config.stream_ttl) + + total_time = (time.perf_counter() - start_time) * 1000 + # Only log timing for significant chunks or slow operations + if ( + chunk_type + in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd") + or total_time > 50 + ): + logger.info( + f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)", + extra={ + "json_fields": { + **log_meta, + "total_time_ms": total_time, + "xadd_time_ms": xadd_time, + "message_id": message_id, + } + }, + ) except Exception as e: + elapsed = (time.perf_counter() - start_time) * 1000 logger.error( - f"Failed to publish chunk for task {task_id}: {e}", + f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}", + extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}}, exc_info=True, ) @@ -200,24 +270,61 @@ async def subscribe_to_task( An asyncio Queue that will receive stream chunks, or None if task not found or user doesn't have access """ + import time + + start_time = time.perf_counter() + + # Build log metadata + log_meta = {"component": "StreamRegistry", "task_id": task_id} + if user_id: + log_meta["user_id"] = user_id + + logger.info( + f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}", + extra={"json_fields": {**log_meta, "last_message_id": last_message_id}}, + ) + + redis_start = time.perf_counter() redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + hgetall_time = (time.perf_counter() - redis_start) * 1000 + logger.info( + f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms", + extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}}, + ) if not meta: - logger.debug(f"Task {task_id} not found in Redis") + elapsed = (time.perf_counter() - start_time) * 1000 + logger.info( + f"[TIMING] Task not found in Redis after {elapsed:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": elapsed, + "reason": "task_not_found", + } + }, + ) return None # Note: Redis client uses decode_responses=True, so keys are strings task_status = meta.get("status", "") task_user_id = meta.get("user_id", "") or None + log_meta["session_id"] = meta.get("session_id", "") # Validate ownership - if task has an owner, requester must match if task_user_id: if user_id != task_user_id: logger.warning( - f"User {user_id} denied access to task {task_id} " - f"owned by {task_user_id}" + f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}", + extra={ + "json_fields": { + **log_meta, + "task_owner": task_user_id, + "reason": "access_denied", + } + }, ) return None @@ -225,7 +332,19 @@ async def subscribe_to_task( stream_key = _get_task_stream_key(task_id) # Step 1: Replay messages from Redis Stream + xread_start = time.perf_counter() messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + xread_time = (time.perf_counter() - 
xread_start) * 1000 + logger.info( + f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}", + extra={ + "json_fields": { + **log_meta, + "duration_ms": xread_time, + "task_status": task_status, + } + }, + ) replayed_count = 0 replay_last_id = last_message_id @@ -244,19 +363,48 @@ async def subscribe_to_task( except Exception as e: logger.warning(f"Failed to replay message: {e}") - logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + logger.info( + f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}", + extra={ + "json_fields": { + **log_meta, + "n_messages_replayed": replayed_count, + "replay_last_id": replay_last_id, + } + }, + ) # Step 2: If task is still running, start stream listener for live updates if task_status == "running": + logger.info( + "[TIMING] Task still running, starting _stream_listener", + extra={"json_fields": {**log_meta, "task_status": task_status}}, + ) listener_task = asyncio.create_task( - _stream_listener(task_id, subscriber_queue, replay_last_id) + _stream_listener(task_id, subscriber_queue, replay_last_id, log_meta) ) # Track listener task for cleanup on unsubscribe _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) else: # Task is completed/failed - add finish marker + logger.info( + f"[TIMING] Task already {task_status}, adding StreamFinish", + extra={"json_fields": {**log_meta, "task_status": task_status}}, + ) await subscriber_queue.put(StreamFinish()) + total_time = (time.perf_counter() - start_time) * 1000 + logger.info( + f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, " + f"n_messages_replayed={replayed_count}", + extra={ + "json_fields": { + **log_meta, + "total_time_ms": total_time, + "n_messages_replayed": replayed_count, + } + }, + ) return subscriber_queue @@ -264,6 +412,7 @@ async def _stream_listener( task_id: str, subscriber_queue: asyncio.Queue[StreamBaseResponse], last_replayed_id: str, + log_meta: dict | None = None, ) -> None: """Listen to Redis Stream for new messages using blocking XREAD. 
@@ -274,10 +423,27 @@ async def _stream_listener( task_id: Task ID to listen for subscriber_queue: Queue to deliver messages to last_replayed_id: Last message ID from replay (continue from here) + log_meta: Structured logging metadata """ + import time + + start_time = time.perf_counter() + + # Use provided log_meta or build minimal one + if log_meta is None: + log_meta = {"component": "StreamRegistry", "task_id": task_id} + + logger.info( + f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}", + extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}}, + ) + queue_id = id(subscriber_queue) # Track the last successfully delivered message ID for recovery hints last_delivered_id = last_replayed_id + messages_delivered = 0 + first_message_time = None + xread_count = 0 try: redis = await get_redis_async() @@ -287,9 +453,39 @@ async def _stream_listener( while True: # Block for up to 30 seconds waiting for new messages # This allows periodic checking if task is still running + xread_start = time.perf_counter() + xread_count += 1 messages = await redis.xread( {stream_key: current_id}, block=30000, count=100 ) + xread_time = (time.perf_counter() - xread_start) * 1000 + + if messages: + msg_count = sum(len(msgs) for _, msgs in messages) + logger.info( + f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "xread_count": xread_count, + "n_messages": msg_count, + "duration_ms": xread_time, + } + }, + ) + elif xread_time > 1000: + # Only log timeouts (30s blocking) + logger.info( + f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "xread_count": xread_count, + "duration_ms": xread_time, + "reason": "timeout", + } + }, + ) if not messages: # Timeout - check if task is still running @@ -326,10 +522,30 @@ async def _stream_listener( ) # Update last delivered ID on successful delivery last_delivered_id = current_id + messages_delivered += 1 + if first_message_time is None: + first_message_time = time.perf_counter() + elapsed = (first_message_time - start_time) * 1000 + logger.info( + f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": elapsed, + "chunk_type": type(chunk).__name__, + } + }, + ) except asyncio.TimeoutError: logger.warning( - f"Subscriber queue full for task {task_id}, " - f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s", + extra={ + "json_fields": { + **log_meta, + "timeout_s": QUEUE_PUT_TIMEOUT, + "reason": "queue_full", + } + }, ) # Send overflow error with recovery info try: @@ -351,15 +567,44 @@ async def _stream_listener( # Stop listening on finish if isinstance(chunk, StreamFinish): + total_time = (time.perf_counter() - start_time) * 1000 + logger.info( + f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}", + extra={ + "json_fields": { + **log_meta, + "total_time_ms": total_time, + "messages_delivered": messages_delivered, + } + }, + ) return except Exception as e: - logger.warning(f"Error processing stream message: {e}") + logger.warning( + f"Error processing stream message: {e}", + extra={"json_fields": {**log_meta, "error": str(e)}}, + ) except asyncio.CancelledError: - logger.debug(f"Stream listener cancelled for task {task_id}") + elapsed = (time.perf_counter() - start_time) * 1000 + 
logger.info( + f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": elapsed, + "messages_delivered": messages_delivered, + "reason": "cancelled", + } + }, + ) raise # Re-raise to propagate cancellation except Exception as e: - logger.error(f"Stream listener error for task {task_id}: {e}") + elapsed = (time.perf_counter() - start_time) * 1000 + logger.error( + f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}", + extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}}, + ) # On error, send finish to unblock subscriber try: await asyncio.wait_for( @@ -368,10 +613,24 @@ async def _stream_listener( ) except (asyncio.TimeoutError, asyncio.QueueFull): logger.warning( - f"Could not deliver finish event for task {task_id} after error" + "Could not deliver finish event after error", + extra={"json_fields": log_meta}, ) finally: # Clean up listener task mapping on exit + total_time = (time.perf_counter() - start_time) * 1000 + logger.info( + f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, " + f"delivered={messages_delivered}, xread_count={xread_count}", + extra={ + "json_fields": { + **log_meta, + "total_time_ms": total_time, + "messages_delivered": messages_delivered, + "xread_count": xread_count, + } + }, + ) _listener_tasks.pop(queue_id, None) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py new file mode 100644 index 0000000000..cf53605ac0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py @@ -0,0 +1,29 @@ +"""Shared helpers for chat tools.""" + +from typing import Any + + +def get_inputs_from_schema( + input_schema: dict[str, Any], + exclude_fields: set[str] | None = None, +) -> list[dict[str, Any]]: + """Extract input field info from JSON schema.""" + if not isinstance(input_schema, dict): + return [] + + exclude = exclude_fields or set() + properties = input_schema.get("properties", {}) + required = set(input_schema.get("required", [])) + + return [ + { + "name": name, + "title": schema.get("title", name), + "type": schema.get("type", "string"), + "description": schema.get("description", ""), + "required": name in required, + "default": schema.get("default"), + } + for name, schema in properties.items() + if name not in exclude + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index 73d4cf81f2..a9f19bcf62 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -24,6 +24,7 @@ from backend.util.timezone_utils import ( ) from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( AgentDetails, AgentDetailsResponse, @@ -261,7 +262,7 @@ class RunAgentTool(BaseTool): ), requirements={ "credentials": requirements_creds_list, - "inputs": self._get_inputs_list(graph.input_schema), + "inputs": get_inputs_from_schema(graph.input_schema), "execution_modes": self._get_execution_modes(graph), }, ), @@ -369,22 +370,6 @@ class RunAgentTool(BaseTool): session_id=session_id, ) - def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]: - """Extract inputs list from schema.""" - inputs_list = [] - if isinstance(input_schema, dict) and "properties" 
in input_schema: - for field_name, field_schema in input_schema["properties"].items(): - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in input_schema.get("required", []), - } - ) - return inputs_list - def _get_execution_modes(self, graph: GraphModel) -> list[str]: """Get available execution modes for the graph.""" trigger_info = graph.trigger_setup_info @@ -398,7 +383,7 @@ class RunAgentTool(BaseTool): suffix: str, ) -> str: """Build a message describing available inputs for an agent.""" - inputs_list = self._get_inputs_list(graph.input_schema) + inputs_list = get_inputs_from_schema(graph.input_schema) required_names = [i["name"] for i in inputs_list if i["required"]] optional_names = [i["name"] for i in inputs_list if not i["required"]] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 590f81ff23..fc4a470fdd 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -12,14 +12,15 @@ from backend.api.features.chat.tools.find_block import ( COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES, ) -from backend.data.block import get_block +from backend.data.block import AnyBlockSchema, get_block from backend.data.execution import ExecutionContext -from backend.data.model import CredentialsMetaInput +from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError from .base import BaseTool +from .helpers import get_inputs_from_schema from .models import ( BlockOutputResponse, ErrorResponse, @@ -28,7 +29,10 @@ from .models import ( ToolResponseBase, UserReadiness, ) -from .utils import build_missing_credentials_from_field_info +from .utils import ( + build_missing_credentials_from_field_info, + match_credentials_to_requirements, +) logger = logging.getLogger(__name__) @@ -77,91 +81,6 @@ class RunBlockTool(BaseTool): def requires_auth(self) -> bool: return True - async def _check_block_credentials( - self, - user_id: str, - block: Any, - input_data: dict[str, Any] | None = None, - ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: - """ - Check if user has required credentials for a block. 
- - Args: - user_id: User ID - block: Block to check credentials for - input_data: Input data for the block (used to determine provider via discriminator) - - Returns: - tuple[matched_credentials, missing_credentials] - """ - matched_credentials: dict[str, CredentialsMetaInput] = {} - missing_credentials: list[CredentialsMetaInput] = [] - input_data = input_data or {} - - # Get credential field info from block's input schema - credentials_fields_info = block.input_schema.get_credentials_fields_info() - - if not credentials_fields_info: - return matched_credentials, missing_credentials - - # Get user's available credentials - creds_manager = IntegrationCredentialsManager() - available_creds = await creds_manager.store.get_all_creds(user_id) - - for field_name, field_info in credentials_fields_info.items(): - effective_field_info = field_info - if field_info.discriminator and field_info.discriminator_mapping: - # Get discriminator from input, falling back to schema default - discriminator_value = input_data.get(field_info.discriminator) - if discriminator_value is None: - field = block.input_schema.model_fields.get( - field_info.discriminator - ) - if field and field.default is not PydanticUndefined: - discriminator_value = field.default - - if ( - discriminator_value - and discriminator_value in field_info.discriminator_mapping - ): - effective_field_info = field_info.discriminate(discriminator_value) - logger.debug( - f"Discriminated provider for {field_name}: " - f"{discriminator_value} -> {effective_field_info.provider}" - ) - - matching_cred = next( - ( - cred - for cred in available_creds - if cred.provider in effective_field_info.provider - and cred.type in effective_field_info.supported_types - ), - None, - ) - - if matching_cred: - matched_credentials[field_name] = CredentialsMetaInput( - id=matching_cred.id, - provider=matching_cred.provider, # type: ignore - type=matching_cred.type, - title=matching_cred.title, - ) - else: - # Create a placeholder for the missing credential - provider = next(iter(effective_field_info.provider), "unknown") - cred_type = next(iter(effective_field_info.supported_types), "api_key") - missing_credentials.append( - CredentialsMetaInput( - id=field_name, - provider=provider, # type: ignore - type=cred_type, # type: ignore - title=field_name.replace("_", " ").title(), - ) - ) - - return matched_credentials, missing_credentials - async def _execute( self, user_id: str | None, @@ -232,8 +151,8 @@ class RunBlockTool(BaseTool): logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") creds_manager = IntegrationCredentialsManager() - matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block, input_data + matched_credentials, missing_credentials = ( + await self._resolve_block_credentials(user_id, block, input_data) ) if missing_credentials: @@ -362,29 +281,75 @@ class RunBlockTool(BaseTool): session_id=session_id, ) - def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: + async def _resolve_block_credentials( + self, + user_id: str, + block: AnyBlockSchema, + input_data: dict[str, Any] | None = None, + ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Resolve credentials for a block by matching user's available credentials. 
+ + Args: + user_id: User ID + block: Block to resolve credentials for + input_data: Input data for the block (used to determine provider via discriminator) + + Returns: + tuple of (matched_credentials, missing_credentials) - matched credentials + are used for block execution, missing ones indicate setup requirements. + """ + input_data = input_data or {} + requirements = self._resolve_discriminated_credentials(block, input_data) + + if not requirements: + return {}, [] + + return await match_credentials_to_requirements(user_id, requirements) + + def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]: """Extract non-credential inputs from block schema.""" - inputs_list = [] schema = block.input_schema.jsonschema() - properties = schema.get("properties", {}) - required_fields = set(schema.get("required", [])) - - # Get credential field names to exclude credentials_fields = set(block.input_schema.get_credentials_fields().keys()) + return get_inputs_from_schema(schema, exclude_fields=credentials_fields) - for field_name, field_schema in properties.items(): - # Skip credential fields - if field_name in credentials_fields: - continue + def _resolve_discriminated_credentials( + self, + block: AnyBlockSchema, + input_data: dict[str, Any], + ) -> dict[str, CredentialsFieldInfo]: + """Resolve credential requirements, applying discriminator logic where needed.""" + credentials_fields_info = block.input_schema.get_credentials_fields_info() + if not credentials_fields_info: + return {} - inputs_list.append( - { - "name": field_name, - "title": field_schema.get("title", field_name), - "type": field_schema.get("type", "string"), - "description": field_schema.get("description", ""), - "required": field_name in required_fields, - } - ) + resolved: dict[str, CredentialsFieldInfo] = {} - return inputs_list + for field_name, field_info in credentials_fields_info.items(): + effective_field_info = field_info + + if field_info.discriminator and field_info.discriminator_mapping: + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + # For host-scoped credentials, add the discriminator value + # (e.g., URL) so _credential_is_for_host can match it + effective_field_info.discriminator_values.add(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + + resolved[field_name] = effective_field_info + + return resolved diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index cda0914809..80a842bf36 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -8,6 +8,7 @@ from backend.api.features.library import model as library_model from backend.api.features.store import db as store_db from backend.data.graph import GraphModel from backend.data.model import ( + Credentials, CredentialsFieldInfo, CredentialsMetaInput, HostScopedCredentials, @@ -223,6 +224,99 @@ async def get_or_create_library_agent( return library_agents[0] +async def 
match_credentials_to_requirements( + user_id: str, + requirements: dict[str, CredentialsFieldInfo], +) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: + """ + Match user's credentials against a dictionary of credential requirements. + + This is the core matching logic shared by both graph and block credential matching. + """ + matched: dict[str, CredentialsMetaInput] = {} + missing: list[CredentialsMetaInput] = [] + + if not requirements: + return matched, missing + + available_creds = await get_user_credentials(user_id) + + for field_name, field_info in requirements.items(): + matching_cred = find_matching_credential(available_creds, field_info) + + if matching_cred: + try: + matched[field_name] = create_credential_meta_from_match(matching_cred) + except Exception as e: + logger.error( + f"Failed to create CredentialsMetaInput for field '{field_name}': " + f"provider={matching_cred.provider}, type={matching_cred.type}, " + f"credential_id={matching_cred.id}", + exc_info=True, + ) + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=f"{field_name} (validation failed: {e})", + ) + ) + else: + provider = next(iter(field_info.provider), "unknown") + cred_type = next(iter(field_info.supported_types), "api_key") + missing.append( + CredentialsMetaInput( + id=field_name, + provider=provider, # type: ignore + type=cred_type, # type: ignore + title=field_name.replace("_", " ").title(), + ) + ) + + return matched, missing + + +async def get_user_credentials(user_id: str) -> list[Credentials]: + """Get all available credentials for a user.""" + creds_manager = IntegrationCredentialsManager() + return await creds_manager.store.get_all_creds(user_id) + + +def find_matching_credential( + available_creds: list[Credentials], + field_info: CredentialsFieldInfo, +) -> Credentials | None: + """Find a credential that matches the required provider, type, scopes, and host.""" + for cred in available_creds: + if cred.provider not in field_info.provider: + continue + if cred.type not in field_info.supported_types: + continue + if cred.type == "oauth2" and not _credential_has_required_scopes( + cred, field_info + ): + continue + if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info): + continue + return cred + return None + + +def create_credential_meta_from_match( + matching_cred: Credentials, +) -> CredentialsMetaInput: + """Create a CredentialsMetaInput from a matched credential.""" + return CredentialsMetaInput( + id=matching_cred.id, + provider=matching_cred.provider, # type: ignore + type=matching_cred.type, + title=matching_cred.title, + ) + + async def match_user_credentials_to_graph( user_id: str, graph: GraphModel, @@ -331,8 +425,6 @@ def _credential_has_required_scopes( # If no scopes are required, any credential matches if not requirements.required_scopes: return True - - # Check that credential scopes are a superset of required scopes return set(credential.scopes).issuperset(requirements.required_scopes)
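
A minimal, self-contained sketch of how the shared helper get_inputs_from_schema introduced in tools/helpers.py above is expected to behave. The helper body is copied from that hunk so the snippet runs on its own; the example schema and its field names (query, limit, credentials) are hypothetical and not taken from the codebase.

"""Standalone sketch of the get_inputs_from_schema helper from tools/helpers.py."""

from typing import Any


def get_inputs_from_schema(
    input_schema: dict[str, Any],
    exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
    # Same logic as the helper added in this diff: walk the JSON schema's
    # properties and summarize each field that is not explicitly excluded.
    if not isinstance(input_schema, dict):
        return []

    exclude = exclude_fields or set()
    properties = input_schema.get("properties", {})
    required = set(input_schema.get("required", []))

    return [
        {
            "name": name,
            "title": schema.get("title", name),
            "type": schema.get("type", "string"),
            "description": schema.get("description", ""),
            "required": name in required,
            "default": schema.get("default"),
        }
        for name, schema in properties.items()
        if name not in exclude
    ]


if __name__ == "__main__":
    # Hypothetical block input schema: one credentials field (excluded, as
    # RunBlockTool._get_inputs_list now does) plus two ordinary inputs.
    example_schema = {
        "properties": {
            "credentials": {"title": "Credentials", "type": "object"},
            "query": {"title": "Query", "type": "string", "description": "Search text"},
            "limit": {"title": "Limit", "type": "integer", "default": 10},
        },
        "required": ["query"],
    }
    for field in get_inputs_from_schema(example_schema, exclude_fields={"credentials"}):
        print(field)
    # Expected: 'query' is required with no default, 'limit' is optional with
    # default 10, and 'credentials' is omitted entirely.

Exercising the helper this way also illustrates why run_agent.py and run_block.py can now share it: run_agent passes the graph input schema unchanged, while run_block passes exclude_fields with the block's credential field names.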