mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-24 03:00:28 -05:00
Compare commits
5 Commits
fix/messed
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef42b17e3b | ||
|
|
a18ffd0b21 | ||
|
|
e40c8c70ce | ||
|
|
9cdcd6793f | ||
|
|
fc64f83331 |
@@ -27,7 +27,6 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
@@ -39,8 +38,10 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
default=3600,
|
||||
description="TTL in seconds for long-running operation deduplication lock "
|
||||
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
|
||||
"For longer operations, the stream_registry heartbeat keeps them alive.",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
@@ -48,6 +49,11 @@ class ChatConfig(BaseSettings):
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_lock_ttl: int = Field(
|
||||
default=120,
|
||||
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
|
||||
"reconnection after refresh/crash without long waits.",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
@@ -92,10 +93,9 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> ChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
@@ -123,7 +123,7 @@ async def add_chat_message(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
PrismaChatMessage.prisma().create(data=data),
|
||||
)
|
||||
return ChatMessage.from_db(message)
|
||||
|
||||
@@ -132,58 +132,93 @@ async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[ChatMessage]:
|
||||
) -> int:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||
streaming loop and long-running callback race), queries the latest
|
||||
sequence and retries with the correct offset. This avoids unnecessary
|
||||
upserts and DB queries in the common case (no collision).
|
||||
|
||||
Returns:
|
||||
Next sequence number for the next message to be inserted. This equals
|
||||
start_sequence + len(messages) and allows callers to update their
|
||||
counters even when collision detection adjusts start_sequence.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
# No messages to add - return current count
|
||||
return start_sequence
|
||||
|
||||
created_messages = []
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Single timestamp for all messages and session update
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
messages_data.append(data)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
await asyncio.gather(
|
||||
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
|
||||
PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": now},
|
||||
),
|
||||
)
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
# Return next sequence number for counter sync
|
||||
return start_sequence + len(messages)
|
||||
|
||||
except UniqueViolationError:
|
||||
if attempt < max_retries - 1:
|
||||
# Collision detected - query MAX(sequence)+1 and retry with correct offset
|
||||
logger.info(
|
||||
f"Collision detected for session {session_id} at sequence "
|
||||
f"{start_sequence}, querying DB for latest sequence"
|
||||
)
|
||||
start_sequence = await get_next_sequence(session_id)
|
||||
logger.info(
|
||||
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - propagate error
|
||||
raise
|
||||
|
||||
# Should never reach here due to raise in exception handler
|
||||
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
@@ -237,10 +272,20 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
async def get_next_sequence(session_id: str) -> int:
|
||||
"""Get the next sequence number for a new message in this session.
|
||||
|
||||
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
|
||||
More robust than COUNT(*) because it's immune to deleted messages.
|
||||
|
||||
Optimized to select only the sequence column using raw SQL.
|
||||
The unique index on (sessionId, sequence) makes this query fast.
|
||||
"""
|
||||
results = await db.query_raw_with_schema(
|
||||
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
|
||||
session_id,
|
||||
)
|
||||
return 0 if not results else results[0]["sequence"] + 1
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
|
||||
@@ -266,7 +266,11 @@ class CoPilotProcessor:
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id,
|
||||
status="failed",
|
||||
error_message="Task was cancelled",
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -432,7 +432,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
@@ -449,16 +451,18 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
existing_message_count,
|
||||
skip_existence_check=existing_message_count > 0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
@@ -489,21 +493,31 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
session: ChatSession,
|
||||
existing_message_count: int,
|
||||
*,
|
||||
skip_existence_check: bool = False,
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
"""Save or update a chat session in the database.
|
||||
|
||||
Args:
|
||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||
and assume the session row already exists. Saves one DB round trip
|
||||
for incremental saves during streaming.
|
||||
"""
|
||||
db = chat_db()
|
||||
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
if not skip_existence_check:
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
@@ -562,9 +576,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
|
||||
@@ -331,3 +331,96 @@ def test_to_openai_messages_merges_split_assistants():
|
||||
tc_list = merged.get("tool_calls")
|
||||
assert tc_list is not None and len(list(tc_list)) == 1
|
||||
assert list(tc_list)[0]["id"] == "tc1"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Concurrent save collision detection #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
|
||||
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
|
||||
|
||||
Simulates the race condition where:
|
||||
1. Streaming loop starts with saved_msg_count=5
|
||||
2. Long-running callback appends message #5 and saves
|
||||
3. Streaming loop tries to save with stale count=5
|
||||
|
||||
The collision detection should handle this gracefully.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
|
||||
)
|
||||
)
|
||||
|
||||
# Save initial messages
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Simulate streaming loop and callback saving concurrently
|
||||
async def streaming_loop_save():
|
||||
"""Simulates streaming loop saving messages."""
|
||||
# Add 2 messages
|
||||
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content="Streaming message 2")
|
||||
)
|
||||
|
||||
# Wait a bit to let callback potentially save first
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
async def callback_save():
|
||||
"""Simulates long-running callback saving a message."""
|
||||
# Add 1 message
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
|
||||
)
|
||||
|
||||
# Save immediately (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
# Run both saves concurrently - one will hit collision detection
|
||||
results = await asyncio.gather(streaming_loop_save(), callback_save())
|
||||
|
||||
# Both should succeed
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
# Reload session from DB to verify
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key) # Clear cache to force DB load
|
||||
|
||||
loaded_session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert loaded_session is not None
|
||||
|
||||
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
|
||||
assert len(loaded_session.messages) == 6
|
||||
|
||||
# Verify no duplicate sequences
|
||||
sequences = []
|
||||
for i, msg in enumerate(loaded_session.messages):
|
||||
# Messages should have sequential sequence numbers starting from 0
|
||||
sequences.append(i)
|
||||
|
||||
# All sequences should be unique and sequential
|
||||
assert sequences == list(range(6))
|
||||
|
||||
# Verify message content is preserved
|
||||
contents = [m.content for m in loaded_session.messages]
|
||||
assert "Message 0" in contents
|
||||
assert "Message 1" in contents
|
||||
assert "Message 2" in contents
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
|
||||
@@ -47,8 +47,9 @@ class SDKResponseAdapter:
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
def __init__(self, message_id: str | None = None, session_id: str | None = None):
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.session_id = session_id
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
@@ -61,6 +62,11 @@ class SDKResponseAdapter:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
"""True when there are tool calls that haven't received output yet."""
|
||||
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
@@ -77,7 +83,12 @@ class SDKResponseAdapter:
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# Flush any SDK built-in tool calls that didn't get a UserMessage
|
||||
# result (e.g. WebSearch, Read handled internally by the CLI).
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
# BUT skip flush when this AssistantMessage is a parallel tool
|
||||
# continuation (contains only ToolUseBlocks) — the prior tools
|
||||
# are still executing concurrently and haven't finished yet.
|
||||
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
|
||||
if not is_tool_only:
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
@@ -118,8 +129,24 @@ class SDKResponseAdapter:
|
||||
blocks = content if isinstance(content, list) else []
|
||||
resolved_in_blocks: set[str] = set()
|
||||
|
||||
sid = (self.session_id or "?")[:12]
|
||||
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
|
||||
logger.info(
|
||||
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
|
||||
"parent_tool_use_id=%s",
|
||||
sid,
|
||||
len(blocks),
|
||||
type(content).__name__,
|
||||
parent_id_preview[:12] if parent_id_preview else "None",
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
# Skip if already resolved (e.g. by flush) — the real
|
||||
# result supersedes the empty flush, but re-emitting
|
||||
# would confuse the frontend's state machine.
|
||||
if block.tool_use_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
@@ -144,7 +171,11 @@ class SDKResponseAdapter:
|
||||
# Handle SDK built-in tool results carried via parent_tool_use_id
|
||||
# instead of (or in addition to) ToolResultBlock content.
|
||||
parent_id = sdk_message.parent_tool_use_id
|
||||
if parent_id and parent_id not in resolved_in_blocks:
|
||||
if (
|
||||
parent_id
|
||||
and parent_id not in resolved_in_blocks
|
||||
and parent_id not in self.resolved_tool_calls
|
||||
):
|
||||
tool_info = self.current_tool_calls.get(parent_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
@@ -228,11 +259,28 @@ class SDKResponseAdapter:
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
"""
|
||||
unresolved = [
|
||||
(tid, info.get("name", "unknown"))
|
||||
for tid, info in self.current_tool_calls.items()
|
||||
if tid not in self.resolved_tool_calls
|
||||
]
|
||||
sid = (self.session_id or "?")[:12]
|
||||
if not unresolved:
|
||||
logger.info(
|
||||
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
|
||||
sid,
|
||||
len(self.current_tool_calls),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
|
||||
sid,
|
||||
len(unresolved),
|
||||
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
|
||||
)
|
||||
|
||||
flushed = False
|
||||
for tool_id, tool_info in self.current_tool_calls.items():
|
||||
if tool_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
for tool_id, tool_name in unresolved:
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if output is not None:
|
||||
responses.append(
|
||||
@@ -245,9 +293,12 @@ class SDKResponseAdapter:
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed pending output for built-in tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
len(output),
|
||||
)
|
||||
else:
|
||||
# No output available — emit an empty output so the frontend
|
||||
@@ -263,9 +314,14 @@ class SDKResponseAdapter:
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.debug(
|
||||
f"Flushed empty output for unresolved tool {tool_name} "
|
||||
f"(call {tool_id})"
|
||||
logger.warning(
|
||||
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
|
||||
"(call %s) — stash was empty (likely SDK hook race "
|
||||
"condition: PostToolUse hook hadn't completed before "
|
||||
"flush was triggered)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Unit tests for the SDK response adapter."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
@@ -27,6 +30,10 @@ from backend.copilot.response_model import (
|
||||
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
from .tool_adapter import _pending_tool_outputs as _pto
|
||||
from .tool_adapter import _stash_event
|
||||
from .tool_adapter import stash_pending_tool_output as _stash
|
||||
from .tool_adapter import wait_for_stash
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
@@ -364,3 +371,310 @@ def test_full_conversation_flow():
|
||||
"StreamFinishStep", # step 2 closed
|
||||
"StreamFinish",
|
||||
]
|
||||
|
||||
|
||||
# -- Flush unresolved tool calls --------------------------------------------
|
||||
|
||||
|
||||
def test_flush_unresolved_at_result_message():
|
||||
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in tool — no MCP prefix)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. No UserMessage for this tool — go straight to ResultMessage
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # flushed with empty output
|
||||
"StreamFinishStep", # step closed by flush
|
||||
"StreamFinish",
|
||||
]
|
||||
# The flushed output should be empty (no stash available)
|
||||
output_event = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
][0]
|
||||
assert output_event.toolCallId == "ws-1"
|
||||
assert output_event.toolName == "WebSearch"
|
||||
assert output_event.output == ""
|
||||
|
||||
|
||||
def test_flush_unresolved_at_next_assistant_message():
|
||||
"""Built-in tools get flushed when the next AssistantMessage arrives."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in — no UserMessage will come)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. Next AssistantMessage triggers flush before processing its blocks
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")], model="test"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
# Flush at next AssistantMessage:
|
||||
"StreamToolOutputAvailable",
|
||||
"StreamFinishStep", # step closed by flush
|
||||
# New step for continuation text:
|
||||
"StreamStartStep",
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta",
|
||||
]
|
||||
|
||||
|
||||
def test_flush_with_stashed_output():
|
||||
"""Stashed output from PostToolUse hook is used when flushing."""
|
||||
adapter = _adapter()
|
||||
|
||||
# Simulate PostToolUse hook stashing output
|
||||
_pto.set({})
|
||||
_stash("WebSearch", "Search result: 5 items found")
|
||||
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# ResultMessage triggers flush
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_events = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].output == "Search result: 5 items found"
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# -- wait_for_stash synchronisation tests --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_signaled():
|
||||
"""wait_for_stash returns True when stash_pending_tool_output signals."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Simulate a PostToolUse hook that stashes output after a short delay
|
||||
async def delayed_stash():
|
||||
await asyncio.sleep(0.01)
|
||||
_stash("WebSearch", "result data")
|
||||
|
||||
asyncio.create_task(delayed_stash())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
|
||||
assert result is True
|
||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_timeout():
|
||||
"""wait_for_stash returns False on timeout when no stash occurs."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_already_stashed():
|
||||
"""wait_for_stash picks up a stash that happened just before the wait."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Stash before waiting — simulates hook completing before message arrives
|
||||
_stash("Read", "file contents")
|
||||
# Event is now set; wait_for_stash detects the fast path and returns
|
||||
# immediately without timing out.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is True
|
||||
|
||||
# But the stash itself is populated
|
||||
assert _pto.get({}).get("Read") == ["file contents"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
# -- Parallel tool call tests --
|
||||
|
||||
|
||||
def test_parallel_tool_calls_not_flushed_prematurely():
|
||||
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
|
||||
only contains ToolUseBlocks (parallel continuation)."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# First AssistantMessage: tool call #1
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
r1 = adapter.convert_message(msg1)
|
||||
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Second AssistantMessage: tool call #2 (parallel continuation)
|
||||
msg2 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 0, (
|
||||
f"Tool-only AssistantMessage should not flush prior tools, "
|
||||
f"but got {len(output_events)} output events"
|
||||
)
|
||||
|
||||
# Both t1 and t2 should still be unresolved
|
||||
assert "t1" not in adapter.resolved_tool_calls
|
||||
assert "t2" not in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_text_assistant_message_flushes_prior_tools():
|
||||
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(msg1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Text AssistantMessage (new turn after tools completed)
|
||||
msg2 = AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# Flush SHOULD have happened — t1 gets empty output
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].toolCallId == "t1"
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_already_resolved_tool_skipped_in_user_message():
|
||||
"""A tool result in UserMessage should be skipped if already resolved by flush."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call + flush via text message
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Done")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
# Now UserMessage arrives with the real result — should be skipped
|
||||
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
|
||||
r = adapter.convert_message(user_msg)
|
||||
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
|
||||
assert (
|
||||
len(output_events) == 0
|
||||
), "Already-resolved tool should not emit duplicate output"
|
||||
|
||||
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
|
||||
|
||||
Instead of pinning to a narrow version range, these tests verify that the
|
||||
installed SDK exposes every class, function, attribute, and method the copilot
|
||||
integration relies on. If an SDK upgrade removes or renames something these
|
||||
tests will catch it immediately.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types & factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sdk_exports_client_and_options():
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
assert inspect.isclass(ClaudeSDKClient)
|
||||
assert inspect.isclass(ClaudeAgentOptions)
|
||||
|
||||
|
||||
def test_sdk_exports_message_types():
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
# Message is a Union type alias, just verify it's importable
|
||||
assert Message is not None
|
||||
|
||||
|
||||
def test_sdk_exports_content_block_types():
|
||||
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
|
||||
|
||||
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
|
||||
|
||||
def test_sdk_exports_mcp_helpers():
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
assert callable(create_sdk_mcp_server)
|
||||
assert callable(tool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeSDKClient interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_has_required_methods():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
required = ["connect", "disconnect", "query", "receive_messages"]
|
||||
for name in required:
|
||||
attr = getattr(ClaudeSDKClient, name, None)
|
||||
assert attr is not None, f"ClaudeSDKClient.{name} missing"
|
||||
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
|
||||
|
||||
|
||||
def test_client_supports_async_context_manager():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "__aenter__")
|
||||
assert hasattr(ClaudeSDKClient, "__aexit__")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeAgentOptions fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_options_accepts_required_fields():
|
||||
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
opts = ClaudeAgentOptions(
|
||||
system_prompt="test",
|
||||
cwd="/tmp",
|
||||
)
|
||||
assert opts.system_prompt == "test"
|
||||
assert opts.cwd == "/tmp"
|
||||
|
||||
|
||||
def test_agent_options_accepts_all_our_fields():
|
||||
"""Comprehensive check of every field we use in service.py."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
fields_we_use = [
|
||||
"system_prompt",
|
||||
"mcp_servers",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"hooks",
|
||||
"cwd",
|
||||
"model",
|
||||
"env",
|
||||
"resume",
|
||||
"max_buffer_size",
|
||||
]
|
||||
sig = inspect.signature(ClaudeAgentOptions)
|
||||
for field in fields_we_use:
|
||||
assert field in sig.parameters, (
|
||||
f"ClaudeAgentOptions no longer accepts '{field}' — "
|
||||
f"available params: {list(sig.parameters.keys())}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_assistant_message_has_content_and_model():
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock
|
||||
|
||||
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
|
||||
assert hasattr(msg, "content")
|
||||
assert hasattr(msg, "model")
|
||||
|
||||
|
||||
def test_result_message_has_required_attrs():
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
assert msg.subtype == "success"
|
||||
assert hasattr(msg, "result")
|
||||
|
||||
|
||||
def test_system_message_has_subtype_and_data():
|
||||
from claude_agent_sdk import SystemMessage
|
||||
|
||||
msg = SystemMessage(subtype="init", data={})
|
||||
assert msg.subtype == "init"
|
||||
assert msg.data == {}
|
||||
|
||||
|
||||
def test_user_message_has_parent_tool_use_id():
|
||||
from claude_agent_sdk import UserMessage
|
||||
|
||||
msg = UserMessage(content="test")
|
||||
assert hasattr(msg, "parent_tool_use_id")
|
||||
assert hasattr(msg, "tool_use_result")
|
||||
|
||||
|
||||
def test_tool_use_block_has_id_name_input():
|
||||
from claude_agent_sdk import ToolUseBlock
|
||||
|
||||
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
|
||||
assert block.id == "t1"
|
||||
assert block.name == "test"
|
||||
assert block.input == {"key": "val"}
|
||||
|
||||
|
||||
def test_tool_result_block_has_required_attrs():
|
||||
from claude_agent_sdk import ToolResultBlock
|
||||
|
||||
block = ToolResultBlock(tool_use_id="t1", content="result")
|
||||
assert block.tool_use_id == "t1"
|
||||
assert block.content == "result"
|
||||
assert hasattr(block, "is_error")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hook_event",
|
||||
["PreToolUse", "PostToolUse", "Stop"],
|
||||
)
|
||||
def test_sdk_exports_hook_event_type(hook_event: str):
|
||||
"""Verify HookEvent literal includes the events our security_hooks use."""
|
||||
from claude_agent_sdk.types import HookEvent
|
||||
|
||||
# HookEvent is a Literal type — check that our events are valid values.
|
||||
# We can't easily inspect Literal at runtime, so just verify the type exists.
|
||||
assert HookEvent is not None
|
||||
@@ -246,15 +246,33 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
logger.info(
|
||||
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
|
||||
tool_name,
|
||||
is_builtin,
|
||||
(tool_use_id or "")[:12],
|
||||
)
|
||||
|
||||
# Stash output for SDK built-in tools so the response adapter can
|
||||
# emit StreamToolOutputAvailable even when the CLI doesn't surface
|
||||
# a separate UserMessage with ToolResultBlock content.
|
||||
if not tool_name.startswith(MCP_TOOL_PREFIX):
|
||||
if is_builtin:
|
||||
tool_response = input_data.get("tool_response")
|
||||
if tool_response is not None:
|
||||
resp_preview = str(tool_response)[:100]
|
||||
logger.info(
|
||||
"[SDK] Stashing builtin output for %s (%d chars): %s...",
|
||||
tool_name,
|
||||
len(str(tool_response)),
|
||||
resp_preview,
|
||||
)
|
||||
stash_pending_tool_output(tool_name, tool_response)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] PostToolUse for builtin %s but tool_response is None",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import stream_registry
|
||||
@@ -24,6 +26,7 @@ from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
@@ -46,6 +49,7 @@ from .tool_adapter import (
|
||||
LongRunningCallback,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
cleanup_cli_project_dir,
|
||||
@@ -59,6 +63,7 @@ from .transcript import (
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@@ -130,8 +135,12 @@ is delivered to the user via a background stream.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||
|
||||
def _build_long_running_callback(
|
||||
user_id: str | None,
|
||||
) -> LongRunningCallback:
|
||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
||||
|
||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
||||
@@ -140,6 +149,9 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
||||
widget with progress updates.
|
||||
|
||||
Args:
|
||||
user_id: User ID for the session
|
||||
|
||||
The returned callback matches the ``LongRunningCallback`` signature:
|
||||
``(tool_name, args, session) -> MCP response dict``.
|
||||
"""
|
||||
@@ -205,7 +217,8 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
# Collision detection happens in add_chat_messages_batch (db.py)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||
bg_task = asyncio.create_task(
|
||||
@@ -344,15 +357,15 @@ async def _compress_conversation_history(
|
||||
|
||||
Returns the compressed prior messages (everything except the current message).
|
||||
"""
|
||||
prior = session.messages[:-1]
|
||||
if len(prior) < 2:
|
||||
return prior
|
||||
messages = session.messages[:-1]
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
# Convert ChatMessages to dicts for compress_context
|
||||
messages_dict = []
|
||||
for msg in prior:
|
||||
for msg in messages:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
@@ -400,7 +413,7 @@ async def _compress_conversation_history(
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return prior
|
||||
return messages
|
||||
|
||||
|
||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
@@ -442,8 +455,8 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
def _is_tool_error_or_denial(content: str | None) -> bool:
|
||||
"""Check if a tool message content indicates an error or denial.
|
||||
|
||||
We include these in conversation context so the agent doesn't
|
||||
hallucinate success for operations that actually failed.
|
||||
Currently unused — ``_format_conversation_context`` includes all tool
|
||||
results. Kept as a utility for future selective filtering.
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
@@ -458,7 +471,7 @@ def _is_tool_error_or_denial(content: str | None) -> bool:
|
||||
"maximum", # subtask-limit denial
|
||||
"denied",
|
||||
"blocked",
|
||||
"failed", # internal tool execution failures
|
||||
"failed to", # internal tool execution failures
|
||||
'"iserror": true', # MCP protocol error flag
|
||||
)
|
||||
)
|
||||
@@ -525,6 +538,9 @@ async def stream_chat_completion_sdk(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Type narrowing: session is guaranteed ChatSession after the check above
|
||||
session = cast(ChatSession, session)
|
||||
|
||||
# Append the new message to the session if it's not already there
|
||||
new_message_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
@@ -562,6 +578,29 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||
message_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
stream_id = task_id # Use task_id as unique stream identifier
|
||||
|
||||
# Acquire stream lock to prevent concurrent streams to the same session
|
||||
lock = AsyncClusterLock(
|
||||
redis=await get_redis_async(),
|
||||
key=f"{STREAM_LOCK_PREFIX}{session_id}",
|
||||
owner_id=stream_id,
|
||||
timeout=config.stream_lock_ttl,
|
||||
)
|
||||
|
||||
lock_owner = await lock.try_acquire()
|
||||
if lock_owner != stream_id:
|
||||
# Another stream is active
|
||||
logger.warning(
|
||||
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="Another stream is already active for this session. "
|
||||
"Please wait or stop it.",
|
||||
code="stream_already_active",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
@@ -674,7 +713,7 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
@@ -699,10 +738,13 @@ async def stream_chat_completion_sdk(
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Sending query ({len(session.messages)} msgs, "
|
||||
f"resume={use_resume})"
|
||||
"[SDK] [%s] Sending query — resume=%s, "
|
||||
"total_msgs=%d, query_len=%d",
|
||||
session_id[:12],
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
)
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
@@ -711,97 +753,288 @@ async def stream_chat_completion_sdk(
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
|
||||
# Use an explicit async iterator with timeout to send
|
||||
# heartbeats when the CLI is idle (e.g. executing tools).
|
||||
# This prevents proxies/LBs from closing the SSE connection.
|
||||
# asyncio.timeout() is preferred over asyncio.wait_for()
|
||||
# because wait_for wraps in a separate Task whose cancellation
|
||||
# can leave the async generator in a broken state.
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
|
||||
# anyio memory stream, causing StopAsyncIteration on the next
|
||||
# call and silently dropping all in-flight tool results.
|
||||
# Instead, wrap __anext__() in a Task and use asyncio.wait()
|
||||
# with a timeout. On timeout we emit a heartbeat but keep the
|
||||
# Task alive so it can deliver the next message.
|
||||
msg_iter = client.receive_messages().__aiter__()
|
||||
while not stream_completed:
|
||||
try:
|
||||
async with asyncio.timeout(_HEARTBEAT_INTERVAL):
|
||||
sdk_msg = await msg_iter.__anext__()
|
||||
except TimeoutError:
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
pending_task: asyncio.Task[Any] | None = None
|
||||
try:
|
||||
while not stream_completed:
|
||||
if pending_task is None:
|
||||
|
||||
logger.debug(
|
||||
f"[SDK] Received: {type(sdk_msg).__name__} "
|
||||
f"{getattr(sdk_msg, 'subtype', '')}"
|
||||
)
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
async def _next_msg() -> Any:
|
||||
return await msg_iter.__anext__()
|
||||
|
||||
pending_task = asyncio.create_task(_next_msg())
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{pending_task}, timeout=_HEARTBEAT_INTERVAL
|
||||
)
|
||||
|
||||
if not done:
|
||||
# Timeout — emit heartbeat but keep the task alive
|
||||
# Also refresh lock TTL to keep it alive
|
||||
await lock.refresh()
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
|
||||
# Log tool events for debugging visibility issues
|
||||
# Task completed — get result
|
||||
pending_task = None
|
||||
try:
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended normally "
|
||||
"(StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
# SDK sends {"type": "error"} which raises
|
||||
# Exception in receive_messages() — capture it
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
"[SDK] [%s] Stream error from SDK: %s",
|
||||
session_id[:12],
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
len(adapter.current_tool_calls),
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
# wait_for_stash() awaits an asyncio.Event signaled by
|
||||
# stash_pending_tool_output(), completing as soon as
|
||||
# the hook finishes (typically <1ms). The sleep(0)
|
||||
# after lets any remaining concurrent hooks complete.
|
||||
#
|
||||
# Skip for parallel tool continuations: when the SDK
|
||||
# sends parallel tool calls as separate
|
||||
# AssistantMessages (each containing only
|
||||
# ToolUseBlocks), we must NOT wait/flush — the prior
|
||||
# tools are still executing concurrently.
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
ToolUseBlock,
|
||||
)
|
||||
|
||||
is_parallel_continuation = isinstance(
|
||||
sdk_msg, AssistantMessage
|
||||
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
|
||||
|
||||
if (
|
||||
adapter.has_unresolved_tool_calls
|
||||
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
|
||||
and not is_parallel_continuation
|
||||
):
|
||||
if await wait_for_stash(timeout=0.5):
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Timed out waiting for "
|
||||
"PostToolUse hook stash "
|
||||
"(%d unresolved tool calls)",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
|
||||
# Log tool events for debugging
|
||||
if isinstance(
|
||||
response,
|
||||
(
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
),
|
||||
):
|
||||
extra = ""
|
||||
if isinstance(response, StreamToolOutputAvailable):
|
||||
out_len = len(str(response.output))
|
||||
extra = f", output_len={out_len}"
|
||||
logger.info(
|
||||
"[SDK] [%s] Tool event: %s, tool=%s%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
extra,
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(
|
||||
response.input or {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
# Save before tool execution starts so the
|
||||
# pending tool call is visible on refresh /
|
||||
# other devices. Collision detection happens
|
||||
# in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
session = await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
session_id[:12],
|
||||
save_err,
|
||||
)
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
# Save after tool completes so the result is
|
||||
# visible on refresh / other devices.
|
||||
# Collision detection happens in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
session = await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
session_id[:12],
|
||||
save_err,
|
||||
)
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task/generator was cancelled (e.g. client disconnect,
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled "
|
||||
"(asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Cancel the pending __anext__ task to avoid a leaked
|
||||
# coroutine. This is safe even if the task already
|
||||
# completed.
|
||||
if pending_task is not None and not pending_task.done():
|
||||
pending_task.cancel()
|
||||
try:
|
||||
await pending_task
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
pass
|
||||
|
||||
# Safety net: if tools are still unresolved after the
|
||||
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
|
||||
# or SDK not sending UserMessages for built-in tools), flush
|
||||
# them now so the frontend stops showing spinners.
|
||||
if adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
|
||||
"flushing as safety net",
|
||||
session_id[:12],
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
adapter._flush_unresolved_tool_calls(safety_responses)
|
||||
for response in safety_responses:
|
||||
if isinstance(
|
||||
response,
|
||||
(StreamToolInputAvailable, StreamToolOutputAvailable),
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] Tool event: %s, tool=%s",
|
||||
"[SDK] [%s] Safety flush: %s, tool=%s",
|
||||
session_id[:12],
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
# If the stream ended without a ResultMessage (no
|
||||
# StreamFinish), the SDK CLI exited unexpectedly. Close
|
||||
# the open step and emit StreamFinish so the frontend
|
||||
# transitions to the "ready" state.
|
||||
if not stream_completed:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Stream ended without ResultMessage "
|
||||
"(StopAsyncIteration) — emitting StreamFinish",
|
||||
session_id[:12],
|
||||
)
|
||||
if adapter.step_open:
|
||||
yield StreamFinishStep()
|
||||
adapter.step_open = False
|
||||
closing_responses: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(closing_responses)
|
||||
for r in closing_responses:
|
||||
yield r
|
||||
yield StreamFinish()
|
||||
stream_completed = True
|
||||
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
@@ -856,19 +1089,28 @@ async def stream_chat_completion_sdk(
|
||||
"to use the OpenAI-compatible fallback."
|
||||
)
|
||||
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
logger.debug(
|
||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||
session = cast(ChatSession, await asyncio.shield(upsert_chat_session(session)))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session saved with %d messages",
|
||||
session_id[:12],
|
||||
len(session.messages),
|
||||
)
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnect / server shutdown — log but re-raise so
|
||||
# the framework can clean up. The finally block still runs
|
||||
# for transcript upload.
|
||||
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
if session:
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
@@ -890,7 +1132,7 @@ async def stream_chat_completion_sdk(
|
||||
if not raw_transcript and use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
|
||||
if raw_transcript:
|
||||
if raw_transcript and session is not None:
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
@@ -910,6 +1152,9 @@ async def stream_chat_completion_sdk(
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
# Release stream lock to allow new streams for this session
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
|
||||
@@ -9,6 +9,7 @@ via a callback provided by the service layer. This avoids wasteful SDK polling
|
||||
and makes results survive page refreshes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -44,6 +45,14 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
|
||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||
)
|
||||
# Event signaled whenever stash_pending_tool_output() adds a new entry.
|
||||
# Used by the streaming loop to wait for PostToolUse hooks to complete
|
||||
# instead of sleeping an arbitrary duration. The SDK fires hooks via
|
||||
# start_soon (fire-and-forget) so the next message can arrive before
|
||||
# the hook stashes its output — this event bridges that gap.
|
||||
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
||||
@@ -76,6 +85,7 @@ def set_execution_context(
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
_long_running_callback.set(long_running_callback)
|
||||
|
||||
|
||||
@@ -134,6 +144,43 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
except (TypeError, ValueError):
|
||||
text = str(output)
|
||||
pending.setdefault(tool_name, []).append(text)
|
||||
# Signal any waiters that new output is available.
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
the next message (AssistantMessage/ResultMessage) can arrive before the
|
||||
hook completes and stashes its output. This function bridges that gap
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
return False
|
||||
# Fast path: hook already completed before we got here.
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
event.clear()
|
||||
return True
|
||||
except TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
|
||||
@@ -352,7 +352,8 @@ async def assign_user_to_session(
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
session.user_id = user_id
|
||||
return await upsert_chat_session(session)
|
||||
session = await upsert_chat_session(session)
|
||||
return session
|
||||
|
||||
|
||||
async def stream_chat_completion(
|
||||
@@ -1563,7 +1564,11 @@ async def _yield_tool_call(
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
# Mark stream registry task as failed if it was created
|
||||
try:
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
await stream_registry.mark_task_completed(
|
||||
task_id,
|
||||
status="failed",
|
||||
error_message=f"Failed to setup tool {tool_name}: {e}",
|
||||
)
|
||||
except Exception as mark_err:
|
||||
logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
|
||||
logger.error(
|
||||
@@ -1731,7 +1736,11 @@ async def _execute_long_running_tool_with_streaming(
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for background tool")
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
await stream_registry.mark_task_completed(
|
||||
task_id,
|
||||
status="failed",
|
||||
error_message=f"Session {session_id} not found",
|
||||
)
|
||||
return
|
||||
|
||||
# Pass operation_id and task_id to the tool for async processing
|
||||
|
||||
@@ -644,6 +644,8 @@ async def _stream_listener(
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
*,
|
||||
error_message: str | None = None,
|
||||
) -> bool:
|
||||
"""Mark a task as completed and publish finish event.
|
||||
|
||||
@@ -654,6 +656,10 @@ async def mark_task_completed(
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
error_message: If provided and status="failed", publish a StreamError
|
||||
before StreamFinish so connected clients see why the task ended.
|
||||
If not provided, no StreamError is published (caller should publish
|
||||
manually if needed to avoid duplicates).
|
||||
|
||||
Returns:
|
||||
True if task was newly marked completed, False if already completed/failed
|
||||
@@ -669,6 +675,17 @@ async def mark_task_completed(
|
||||
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# Publish error event before finish so connected clients know WHY the
|
||||
# task ended. Only publish if caller provided an explicit error message
|
||||
# to avoid duplicates with code paths that manually publish StreamError.
|
||||
# This is best-effort — if it fails, the StreamFinish still ensures
|
||||
# listeners clean up.
|
||||
if status == "failed" and error_message:
|
||||
try:
|
||||
await publish_chunk(task_id, StreamError(errorText=error_message))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish error event for task {task_id}: {e}")
|
||||
|
||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||
try:
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
@@ -821,27 +838,6 @@ async def get_active_task_for_session(
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Auto-expire stale tasks that exceeded stream_timeout
|
||||
created_at_str = meta.get("created_at", "")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str)
|
||||
age_seconds = (
|
||||
datetime.now(timezone.utc) - created_at
|
||||
).total_seconds()
|
||||
if age_seconds > config.stream_timeout:
|
||||
logger.warning(
|
||||
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
||||
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
||||
)
|
||||
await mark_task_completed(task_id, "failed")
|
||||
continue
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning(
|
||||
f"[TASK_LOOKUP] Failed to parse created_at "
|
||||
f"for task {task_id[:8]}...: {exc}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
|
||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||
get_user_session_count = _(chat_db.get_user_session_count)
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
|
||||
|
||||
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_user_chat_sessions = d.get_user_chat_sessions
|
||||
get_user_session_count = d.get_user_session_count
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_chat_session_message_count = d.get_chat_session_message_count
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -7,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,3 +128,124 @@ class ClusterLock:
|
||||
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
|
||||
class AsyncClusterLock:
|
||||
"""Async Redis-based distributed lock for preventing duplicate execution."""
|
||||
|
||||
def __init__(
|
||||
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
|
||||
):
|
||||
self.redis = redis
|
||||
self.key = key
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
|
||||
Returns:
|
||||
- owner_id (self.owner_id) if successfully acquired
|
||||
- different owner_id if someone else holds the lock
|
||||
- None if Redis is unavailable or other error
|
||||
"""
|
||||
try:
|
||||
success = await self.redis.set(
|
||||
self.key, self.owner_id, nx=True, ex=self.timeout
|
||||
)
|
||||
if success:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
current_value = await self.redis.get(self.key)
|
||||
if current_value:
|
||||
current_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
return current_owner
|
||||
|
||||
# Key doesn't exist but we failed to set it - race condition or Redis issue
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
|
||||
return None
|
||||
|
||||
async def refresh(self) -> bool:
|
||||
"""Refresh lock TTL if we still own it.
|
||||
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
|
||||
Async-safe: uses asyncio.Lock to protect _last_refresh access.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period (async-safe read)
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
async with self._refresh_lock:
|
||||
last_refresh = self._last_refresh
|
||||
is_rate_limited = (
|
||||
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = await self.redis.get(self.key)
|
||||
if not current_value:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
if is_rate_limited:
|
||||
return True
|
||||
|
||||
# Perform actual refresh
|
||||
if await self.redis.expire(self.key, self.timeout):
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
async def release(self):
|
||||
"""Release the lock."""
|
||||
async with self._refresh_lock:
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
14
autogpt_platform/backend/poetry.lock
generated
14
autogpt_platform/backend/poetry.lock
generated
@@ -899,17 +899,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.35"
|
||||
version = "0.1.39"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ed6a79781f545b761b9fe467bc5ae213a103c9d3f0fe7a9dad3c01790ed58fa"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0c03b5a3772eaec42e29ea39240c7d24b760358082f2e36336db9e71dde3dda4"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d2665c9e87b6ffece590bcdd6eb9def47cde4809b0d2f66e0a61a719189be7c9"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-win_amd64.whl", hash = "sha256:d03324daf7076be79d2dd05944559aabf4cc11c98d3a574b992a442a7c7a26d6"},
|
||||
{file = "claude_agent_sdk-0.1.39.tar.gz", hash = "sha256:dcf0ebd5a638c9a7d9f3af7640932a9212b2705b7056e4f08bd3968a865b4268"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -8530,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
|
||||
content-hash = "3ef62836d8321b9a3b8e897dade8dc6ca9022fd9468c53f384b0871b521ab343"
|
||||
|
||||
@@ -16,7 +16,7 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
claude-agent-sdk = "^0.1.39" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
|
||||
|
Before Width: | Height: | Size: 8.0 KiB After Width: | Height: | Size: 8.0 KiB |
@@ -58,6 +58,7 @@ function toToolInput(rawArguments: unknown): unknown {
|
||||
export function convertChatSessionMessagesToUiMessages(
|
||||
sessionId: string,
|
||||
rawMessages: unknown[],
|
||||
options?: { isComplete?: boolean },
|
||||
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||
const messages = coerceSessionChatMessages(rawMessages);
|
||||
const toolOutputsByCallId = new Map<string, unknown>();
|
||||
@@ -104,6 +105,16 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
input,
|
||||
output: typeof output === "string" ? safeJsonParse(output) : output,
|
||||
});
|
||||
} else if (options?.isComplete) {
|
||||
// Session is complete (no active stream) but this tool call has
|
||||
// no output in the DB — mark as completed to stop stale spinners.
|
||||
parts.push({
|
||||
type: `tool-${toolName}`,
|
||||
toolCallId,
|
||||
state: "output-available",
|
||||
input,
|
||||
output: "",
|
||||
});
|
||||
} else {
|
||||
parts.push({
|
||||
type: `tool-${toolName}`,
|
||||
|
||||
@@ -11,6 +11,11 @@ import {
|
||||
MessageResponse,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
CredentialsProvidersContext,
|
||||
type CredentialsProviderData,
|
||||
type CredentialsProvidersContextType,
|
||||
} from "@/providers/agent-credentials/credentials-provider";
|
||||
import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../tools/EditAgent/EditAgent";
|
||||
@@ -97,6 +102,65 @@ function uid() {
|
||||
return `sg-${++_id}`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock credential providers for setup-requirements demos
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const noop = () => Promise.reject(new Error("Styleguide mock"));
|
||||
|
||||
function makeMockProvider(
|
||||
provider: string,
|
||||
providerName: string,
|
||||
savedCredentials: CredentialsProviderData["savedCredentials"] = [],
|
||||
): CredentialsProviderData {
|
||||
return {
|
||||
provider,
|
||||
providerName,
|
||||
savedCredentials,
|
||||
isSystemProvider: false,
|
||||
oAuthCallback: noop as CredentialsProviderData["oAuthCallback"],
|
||||
mcpOAuthCallback: noop as CredentialsProviderData["mcpOAuthCallback"],
|
||||
createAPIKeyCredentials:
|
||||
noop as CredentialsProviderData["createAPIKeyCredentials"],
|
||||
createUserPasswordCredentials:
|
||||
noop as CredentialsProviderData["createUserPasswordCredentials"],
|
||||
createHostScopedCredentials:
|
||||
noop as CredentialsProviderData["createHostScopedCredentials"],
|
||||
deleteCredentials: noop as CredentialsProviderData["deleteCredentials"],
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider context where the user already has saved credentials
|
||||
* so the credential picker shows a selection list.
|
||||
*/
|
||||
const MOCK_PROVIDERS_WITH_CREDENTIALS: CredentialsProvidersContextType = {
|
||||
google: makeMockProvider("google", "Google", [
|
||||
{
|
||||
id: "cred-google-1",
|
||||
provider: "google",
|
||||
type: "oauth2",
|
||||
title: "work@company.com",
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
{
|
||||
id: "cred-google-2",
|
||||
provider: "google",
|
||||
type: "oauth2",
|
||||
title: "personal@gmail.com",
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider context where the user has NO saved credentials,
|
||||
* so the credential picker shows an "add new" flow.
|
||||
*/
|
||||
const MOCK_PROVIDERS_WITHOUT_CREDENTIALS: CredentialsProvidersContextType = {
|
||||
openweathermap: makeMockProvider("openweathermap", "OpenWeatherMap"),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -554,45 +618,80 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (setup requirements)">
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "weather-block-123" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires API credentials to run. Please configure them below.",
|
||||
setup_info: {
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {
|
||||
inputs: [
|
||||
{
|
||||
name: "city",
|
||||
title: "City",
|
||||
type: "string",
|
||||
required: true,
|
||||
description: "The city to get weather for",
|
||||
},
|
||||
],
|
||||
},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap: {
|
||||
provider: "openweathermap",
|
||||
credentials_type: "api_key",
|
||||
title: "OpenWeatherMap API Key",
|
||||
description:
|
||||
"Required to access weather data. Get your key at openweathermap.org",
|
||||
<SubSection label="Setup requirements — no credentials (add new)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITHOUT_CREDENTIALS}
|
||||
>
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "weather-block-123" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires API credentials to run. Please configure them below.",
|
||||
setup_info: {
|
||||
agent_id: "agent-weather-1",
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {
|
||||
inputs: [
|
||||
{
|
||||
name: "city",
|
||||
title: "City",
|
||||
type: "string",
|
||||
required: true,
|
||||
description: "The city to get weather for",
|
||||
},
|
||||
],
|
||||
},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap_key: {
|
||||
provider: "openweathermap",
|
||||
types: ["api_key"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Setup requirements — has credentials (pick from list)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITH_CREDENTIALS}
|
||||
>
|
||||
<RunBlockTool
|
||||
part={{
|
||||
type: "tool-run_block",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { block_id: "calendar-block-456" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This block requires Google credentials. Pick an account below or connect a new one.",
|
||||
setup_info: {
|
||||
agent_id: "agent-calendar-1",
|
||||
agent_name: "Calendar Agent",
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
google_oauth: {
|
||||
provider: "google",
|
||||
types: ["oauth2"],
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (error)">
|
||||
@@ -849,34 +948,71 @@ export default function StyleguidePage() {
|
||||
/>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (setup requirements)">
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/my-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message: "This agent requires additional setup.",
|
||||
setup_info: {
|
||||
agent_name: "YouTube Summarizer",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
youtube_api: {
|
||||
provider: "youtube",
|
||||
credentials_type: "api_key",
|
||||
title: "YouTube Data API Key",
|
||||
description:
|
||||
"Required to access YouTube video data.",
|
||||
<SubSection label="Setup requirements — no credentials (add new)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITHOUT_CREDENTIALS}
|
||||
>
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/weather-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This agent requires an API key. Add your credentials below.",
|
||||
setup_info: {
|
||||
agent_id: "agent-weather-1",
|
||||
agent_name: "Weather Agent",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
openweathermap_key: {
|
||||
provider: "openweathermap",
|
||||
types: ["api_key"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Setup requirements — has credentials (pick from list)">
|
||||
<CredentialsProvidersContext.Provider
|
||||
value={MOCK_PROVIDERS_WITH_CREDENTIALS}
|
||||
>
|
||||
<RunAgentTool
|
||||
part={{
|
||||
type: "tool-run_agent",
|
||||
toolCallId: uid(),
|
||||
state: "output-available",
|
||||
input: { username_agent_slug: "creator/calendar-agent" },
|
||||
output: {
|
||||
type: ResponseType.setup_requirements,
|
||||
message:
|
||||
"This agent needs Google credentials. Pick an account or connect a new one.",
|
||||
setup_info: {
|
||||
agent_id: "agent-calendar-1",
|
||||
agent_name: "Google Calendar Agent",
|
||||
requirements: {},
|
||||
user_readiness: {
|
||||
missing_credentials: {
|
||||
google_oauth: {
|
||||
provider: "google",
|
||||
types: ["oauth2"],
|
||||
scopes: ["email", "calendar"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</CredentialsProvidersContext.Provider>
|
||||
</SubSection>
|
||||
|
||||
<SubSection label="Output available (need login)">
|
||||
|
||||
@@ -16,7 +16,6 @@ import {
|
||||
ContentCardDescription,
|
||||
ContentCodeBlock,
|
||||
ContentGrid,
|
||||
ContentHint,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
@@ -24,8 +23,8 @@ import {
|
||||
ClarificationQuestionsCard,
|
||||
ClarifyingQuestion,
|
||||
} from "./components/ClarificationQuestionsCard";
|
||||
import sparklesImg from "./components/MiniGame/assets/sparkles.png";
|
||||
import { MiniGame } from "./components/MiniGame/MiniGame";
|
||||
import sparklesImg from "../../components/MiniGame/assets/sparkles.png";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import { SuggestedGoalCard } from "./components/SuggestedGoalCard";
|
||||
import {
|
||||
AccordionIcon,
|
||||
@@ -93,9 +92,7 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
|
||||
) {
|
||||
return {
|
||||
icon,
|
||||
title:
|
||||
"Creating agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
title: output.message || "Agent creation started",
|
||||
};
|
||||
}
|
||||
return {
|
||||
@@ -169,15 +166,22 @@ export function CreateAgentTool({ part }: Props) {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isStreaming && (
|
||||
<ToolAccordion
|
||||
icon={<AccordionIcon />}
|
||||
title="Creating agent, this may take a few minutes. Play while you wait."
|
||||
expanded
|
||||
>
|
||||
<ContentGrid>
|
||||
<MiniGame />
|
||||
</ContentGrid>
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isOperating && (
|
||||
<ContentGrid>
|
||||
<MiniGame />
|
||||
<ContentHint>
|
||||
This could take a few minutes — play while you wait!
|
||||
</ContentHint>
|
||||
</ContentGrid>
|
||||
{isOperating && output.message && (
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
)}
|
||||
|
||||
{isAgentSavedOutput(output) && (
|
||||
|
||||
@@ -4,17 +4,15 @@ import { WarningDiamondIcon } from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import {
|
||||
ContentCardDescription,
|
||||
ContentCodeBlock,
|
||||
ContentGrid,
|
||||
ContentHint,
|
||||
ContentLink,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import {
|
||||
ClarificationQuestionsCard,
|
||||
ClarifyingQuestion,
|
||||
@@ -81,9 +79,8 @@ function getAccordionMeta(output: EditAgentToolOutput): {
|
||||
isOperationInProgressOutput(output)
|
||||
) {
|
||||
return {
|
||||
icon: <OrbitLoader size={32} />,
|
||||
title: "Editing agent, this may take a few minutes. Play while you wait.",
|
||||
expanded: true,
|
||||
icon,
|
||||
title: output.message || "Agent editing started",
|
||||
};
|
||||
}
|
||||
return {
|
||||
@@ -148,15 +145,22 @@ export function EditAgentTool({ part }: Props) {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isStreaming && (
|
||||
<ToolAccordion
|
||||
icon={<AccordionIcon />}
|
||||
title="Editing agent, this may take a few minutes. Play while you wait."
|
||||
expanded
|
||||
>
|
||||
<ContentGrid>
|
||||
<MiniGame />
|
||||
</ContentGrid>
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isOperating && (
|
||||
<ContentGrid>
|
||||
<MiniGame />
|
||||
<ContentHint>
|
||||
This could take a few minutes — play while you wait!
|
||||
</ContentHint>
|
||||
</ContentGrid>
|
||||
{isOperating && output.message && (
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
)}
|
||||
|
||||
{isAgentSavedOutput(output) && (
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
ContentHint,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
|
||||
import { MiniGame } from "../../components/MiniGame/MiniGame";
|
||||
import {
|
||||
getAccordionMeta,
|
||||
getAnimationText,
|
||||
@@ -47,14 +47,25 @@ export function RunAgentTool({ part }: Props) {
|
||||
const isError =
|
||||
part.state === "output-error" ||
|
||||
(!!output && isRunAgentErrorOutput(output));
|
||||
const isOutputAvailable = part.state === "output-available" && !!output;
|
||||
|
||||
const setupRequirementsOutput =
|
||||
isOutputAvailable && isRunAgentSetupRequirementsOutput(output)
|
||||
? output
|
||||
: null;
|
||||
|
||||
const agentDetailsOutput =
|
||||
isOutputAvailable && isRunAgentAgentDetailsOutput(output) ? output : null;
|
||||
|
||||
const needLoginOutput =
|
||||
isOutputAvailable && isRunAgentNeedLoginOutput(output) ? output : null;
|
||||
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
(isRunAgentExecutionStartedOutput(output) ||
|
||||
isRunAgentAgentDetailsOutput(output) ||
|
||||
isRunAgentSetupRequirementsOutput(output) ||
|
||||
isRunAgentNeedLoginOutput(output) ||
|
||||
isRunAgentErrorOutput(output));
|
||||
isOutputAvailable &&
|
||||
!setupRequirementsOutput &&
|
||||
!agentDetailsOutput &&
|
||||
!needLoginOutput &&
|
||||
(isRunAgentExecutionStartedOutput(output) || isRunAgentErrorOutput(output));
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
@@ -81,24 +92,30 @@ export function RunAgentTool({ part }: Props) {
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
{setupRequirementsOutput && (
|
||||
<div className="mt-2">
|
||||
<SetupRequirementsCard output={setupRequirementsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{agentDetailsOutput && (
|
||||
<div className="mt-2">
|
||||
<AgentDetailsCard output={agentDetailsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{needLoginOutput && (
|
||||
<div className="mt-2">
|
||||
<ContentMessage>{needLoginOutput.message}</ContentMessage>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isRunAgentExecutionStartedOutput(output) && (
|
||||
<ExecutionStartedCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentAgentDetailsOutput(output) && (
|
||||
<AgentDetailsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentSetupRequirementsOutput(output) && (
|
||||
<SetupRequirementsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunAgentNeedLoginOutput(output) && (
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
)}
|
||||
|
||||
{isRunAgentErrorOutput(output) && <ErrorCard output={output} />}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { useState } from "react";
|
||||
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import {
|
||||
ContentBadge,
|
||||
@@ -38,40 +39,40 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
setInputCredentials((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const isAllComplete =
|
||||
credentialFields.length > 0 &&
|
||||
const needsCredentials = credentialFields.length > 0;
|
||||
const isAllCredentialsComplete =
|
||||
needsCredentials &&
|
||||
[...requiredCredentials].every((key) => !!inputCredentials[key]);
|
||||
|
||||
const canProceed =
|
||||
!hasSent && (!needsCredentials || isAllCredentialsComplete);
|
||||
|
||||
function handleProceed() {
|
||||
setHasSent(true);
|
||||
onSend(
|
||||
"I've configured the required credentials. Please check if everything is ready and proceed with running the agent.",
|
||||
);
|
||||
const message = needsCredentials
|
||||
? "I've configured the required credentials. Please check if everything is ready and proceed with running the agent."
|
||||
: "Please proceed with running the agent.";
|
||||
onSend(message);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
|
||||
{credentialFields.length > 0 && (
|
||||
{needsCredentials && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
{isAllComplete && !hasSent && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-3 w-full"
|
||||
onClick={handleProceed}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Agent credentials
|
||||
</Text>
|
||||
<div className="mt-6">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -100,6 +101,18 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(needsCredentials || expectedInputs.length > 0) && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-4 w-fit"
|
||||
disabled={!canProceed}
|
||||
onClick={handleProceed}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -39,12 +39,19 @@ export function RunBlockTool({ part }: Props) {
|
||||
const isError =
|
||||
part.state === "output-error" ||
|
||||
(!!output && isRunBlockErrorOutput(output));
|
||||
const setupRequirementsOutput =
|
||||
part.state === "output-available" &&
|
||||
output &&
|
||||
isRunBlockSetupRequirementsOutput(output)
|
||||
? output
|
||||
: null;
|
||||
|
||||
const hasExpandableContent =
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
!setupRequirementsOutput &&
|
||||
(isRunBlockBlockOutput(output) ||
|
||||
isRunBlockDetailsOutput(output) ||
|
||||
isRunBlockSetupRequirementsOutput(output) ||
|
||||
isRunBlockErrorOutput(output));
|
||||
|
||||
return (
|
||||
@@ -57,6 +64,12 @@ export function RunBlockTool({ part }: Props) {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{setupRequirementsOutput && (
|
||||
<div className="mt-2">
|
||||
<SetupRequirementsCard output={setupRequirementsOutput} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{hasExpandableContent && output && (
|
||||
<ToolAccordion {...getAccordionMeta(output)}>
|
||||
{isRunBlockBlockOutput(output) && <BlockOutputCard output={output} />}
|
||||
@@ -65,10 +78,6 @@ export function RunBlockTool({ part }: Props) {
|
||||
<BlockDetailsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunBlockSetupRequirementsOutput(output) && (
|
||||
<SetupRequirementsCard output={output} />
|
||||
)}
|
||||
|
||||
{isRunBlockErrorOutput(output) && <ErrorCard output={output} />}
|
||||
</ToolAccordion>
|
||||
)}
|
||||
|
||||
@@ -6,15 +6,9 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { useState } from "react";
|
||||
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import {
|
||||
ContentBadge,
|
||||
ContentCardDescription,
|
||||
ContentCardTitle,
|
||||
ContentMessage,
|
||||
} from "../../../../components/ToolAccordion/AccordionContent";
|
||||
import { ContentMessage } from "../../../../components/ToolAccordion/AccordionContent";
|
||||
import {
|
||||
buildExpectedInputsSchema,
|
||||
coerceCredentialFields,
|
||||
@@ -31,10 +25,8 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
const [inputCredentials, setInputCredentials] = useState<
|
||||
Record<string, CredentialsMetaInput | undefined>
|
||||
>({});
|
||||
const [hasSentCredentials, setHasSentCredentials] = useState(false);
|
||||
|
||||
const [showInputForm, setShowInputForm] = useState(false);
|
||||
const [inputValues, setInputValues] = useState<Record<string, unknown>>({});
|
||||
const [hasSent, setHasSent] = useState(false);
|
||||
|
||||
const { credentialFields, requiredCredentials } = coerceCredentialFields(
|
||||
output.setup_info.user_readiness?.missing_credentials,
|
||||
@@ -50,27 +42,49 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
setInputCredentials((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const needsCredentials = credentialFields.length > 0;
|
||||
const isAllCredentialsComplete =
|
||||
credentialFields.length > 0 &&
|
||||
needsCredentials &&
|
||||
[...requiredCredentials].every((key) => !!inputCredentials[key]);
|
||||
|
||||
function handleProceedCredentials() {
|
||||
setHasSentCredentials(true);
|
||||
onSend(
|
||||
"I've configured the required credentials. Please re-run the block now.",
|
||||
);
|
||||
}
|
||||
const needsInputs = inputSchema !== null;
|
||||
const requiredInputNames = expectedInputs
|
||||
.filter((i) => i.required)
|
||||
.map((i) => i.name);
|
||||
const isAllInputsComplete =
|
||||
needsInputs &&
|
||||
requiredInputNames.every((name) => {
|
||||
const v = inputValues[name];
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
function handleRunWithInputs() {
|
||||
const nonEmpty = Object.fromEntries(
|
||||
Object.entries(inputValues).filter(
|
||||
([, v]) => v !== undefined && v !== null && v !== "",
|
||||
),
|
||||
);
|
||||
onSend(
|
||||
`Run the block with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
|
||||
);
|
||||
setShowInputForm(false);
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
parts.push("I've configured the required credentials.");
|
||||
}
|
||||
|
||||
if (needsInputs) {
|
||||
const nonEmpty = Object.fromEntries(
|
||||
Object.entries(inputValues).filter(
|
||||
([, v]) => v !== undefined && v !== null && v !== "",
|
||||
),
|
||||
);
|
||||
parts.push(
|
||||
`Run the block with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
|
||||
);
|
||||
} else {
|
||||
parts.push("Please re-run the block now.");
|
||||
}
|
||||
|
||||
onSend(parts.join(" "));
|
||||
setInputValues({});
|
||||
}
|
||||
|
||||
@@ -78,119 +92,54 @@ export function SetupRequirementsCard({ output }: Props) {
|
||||
<div className="grid gap-2">
|
||||
<ContentMessage>{output.message}</ContentMessage>
|
||||
|
||||
{credentialFields.length > 0 && (
|
||||
{needsCredentials && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
{isAllCredentialsComplete && !hasSentCredentials && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mt-3 w-full"
|
||||
onClick={handleProceedCredentials}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Block credentials
|
||||
</Text>
|
||||
<div className="mt-6">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={{}}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{inputSchema && (
|
||||
<div className="flex gap-2 pt-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={() => setShowInputForm((prev) => !prev)}
|
||||
>
|
||||
{showInputForm ? "Hide inputs" : "Fill in inputs"}
|
||||
</Button>
|
||||
<div className="rounded-2xl border bg-background p-3 pt-4">
|
||||
<Text variant="small" className="w-fit border-b text-zinc-500">
|
||||
Block inputs
|
||||
</Text>
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema}
|
||||
className="mb-3 mt-3"
|
||||
handleChange={(v) => setInputValues(v.formData ?? {})}
|
||||
uiSchema={{
|
||||
"ui:submitButtonOptions": { norender: true },
|
||||
}}
|
||||
initialValues={inputValues}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "small",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AnimatePresence initial={false}>
|
||||
{showInputForm && inputSchema && (
|
||||
<motion.div
|
||||
initial={{ height: 0, opacity: 0, filter: "blur(6px)" }}
|
||||
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
|
||||
exit={{ height: 0, opacity: 0, filter: "blur(6px)" }}
|
||||
transition={{
|
||||
height: { type: "spring", bounce: 0.15, duration: 0.5 },
|
||||
opacity: { duration: 0.25 },
|
||||
filter: { duration: 0.2 },
|
||||
}}
|
||||
className="overflow-hidden"
|
||||
style={{ willChange: "height, opacity, filter" }}
|
||||
>
|
||||
<div className="rounded-2xl border bg-background p-3 pt-4">
|
||||
<Text variant="body-medium">Block inputs</Text>
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema}
|
||||
handleChange={(v) => setInputValues(v.formData ?? {})}
|
||||
uiSchema={{
|
||||
"ui:submitButtonOptions": { norender: true },
|
||||
}}
|
||||
initialValues={inputValues}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "small",
|
||||
}}
|
||||
/>
|
||||
<div className="-mt-8 flex gap-2">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={handleRunWithInputs}
|
||||
>
|
||||
Run
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
onClick={() => {
|
||||
setShowInputForm(false);
|
||||
setInputValues({});
|
||||
}}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
{expectedInputs.length > 0 && !inputSchema && (
|
||||
<div className="rounded-2xl border bg-background p-3">
|
||||
<ContentCardTitle className="text-xs">
|
||||
Expected inputs
|
||||
</ContentCardTitle>
|
||||
<div className="mt-2 grid gap-2">
|
||||
{expectedInputs.map((input) => (
|
||||
<div key={input.name} className="rounded-xl border p-2">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<ContentCardTitle className="text-xs">
|
||||
{input.title}
|
||||
</ContentCardTitle>
|
||||
<ContentBadge>
|
||||
{input.required ? "Required" : "Optional"}
|
||||
</ContentBadge>
|
||||
</div>
|
||||
<ContentCardDescription className="mt-1">
|
||||
{input.name} • {input.type}
|
||||
{input.description ? ` \u2022 ${input.description}` : ""}
|
||||
</ContentCardDescription>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
{(needsCredentials || needsInputs) && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="w-fit"
|
||||
disabled={!canRun}
|
||||
onClick={handleRun}
|
||||
>
|
||||
Proceed
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -40,16 +40,6 @@ export function useChatSession() {
|
||||
}
|
||||
}, [sessionId, queryClient]);
|
||||
|
||||
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
const hydratedMessages = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
sessionQuery.data.data.messages ?? [],
|
||||
);
|
||||
}, [sessionQuery.data, sessionId]);
|
||||
|
||||
// Expose active_stream info so the caller can trigger manual resume
|
||||
// after hydration completes (rather than relying on AI SDK's built-in
|
||||
// resume which fires before hydration).
|
||||
@@ -58,6 +48,19 @@ export function useChatSession() {
|
||||
return !!sessionQuery.data.data.active_stream;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
// calls as completed so stale spinners don't persist after refresh.
|
||||
const hydratedMessages = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
sessionQuery.data.data.messages ?? [],
|
||||
{ isComplete: !hasActiveStream },
|
||||
);
|
||||
}, [sessionQuery.data, sessionId, hasActiveStream]);
|
||||
|
||||
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
|
||||
usePostV2CreateSession({
|
||||
mutation: {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
getGetV2GetSessionQueryKey,
|
||||
getGetV2ListSessionsQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
useDeleteV2DeleteSession,
|
||||
@@ -187,11 +188,35 @@ export function useCopilotPage() {
|
||||
});
|
||||
}, [hydratedMessages, setMessages, status]);
|
||||
|
||||
// Ref: tracks whether we've already resumed for a given session.
|
||||
// Reset when the stream ends so re-resume is possible if the backend
|
||||
// task is still running (SSE dropped but executor didn't finish).
|
||||
const hasResumedRef = useRef<string | null>(null);
|
||||
|
||||
// When the stream ends (or drops), invalidate the session cache so the
|
||||
// next hydration fetches fresh messages from the backend. Without this,
|
||||
// staleTime: Infinity means the cache keeps the pre-stream data forever,
|
||||
// and any messages added during streaming are lost on remount/navigation.
|
||||
const prevStatusRef = useRef(status);
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
prevStatusRef.current = status;
|
||||
|
||||
const wasActive = prev === "streaming" || prev === "submitted";
|
||||
const isIdle = status === "ready" || status === "error";
|
||||
if (wasActive && isIdle && sessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
// Allow re-resume if the backend task is still running.
|
||||
hasResumedRef.current = null;
|
||||
}
|
||||
}, [status, sessionId, queryClient]);
|
||||
|
||||
// Resume an active stream AFTER hydration completes.
|
||||
// The backend returns active_stream info when a task is still running.
|
||||
// We wait for hydration so the AI SDK has the conversation history
|
||||
// before the resumed stream appends the in-progress assistant message.
|
||||
const hasResumedRef = useRef<string | null>(null);
|
||||
useEffect(() => {
|
||||
if (!hasActiveStream || !sessionId) return;
|
||||
if (!hydratedMessages || hydratedMessages.length === 0) return;
|
||||
@@ -202,18 +227,6 @@ export function useCopilotPage() {
|
||||
resumeStream();
|
||||
}, [hasActiveStream, sessionId, hydratedMessages, status, resumeStream]);
|
||||
|
||||
// When the stream finishes, resolve any tool parts still showing spinners.
|
||||
// This can happen if the backend didn't emit StreamToolOutputAvailable for
|
||||
// a tool call before sending StreamFinish (e.g. SDK built-in tools).
|
||||
const prevStatusRef = useRef(status);
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
prevStatusRef.current = status;
|
||||
if (prev === "streaming" && status === "ready") {
|
||||
setMessages((msgs) => resolveInProgressTools(msgs, "completed"));
|
||||
}
|
||||
}, [status, setMessages]);
|
||||
|
||||
// Poll session endpoint when a long-running tool (create_agent, edit_agent)
|
||||
// is in progress. When the backend completes, the session data will contain
|
||||
// the final tool output — this hook detects the change and updates messages.
|
||||
|
||||
@@ -119,7 +119,7 @@ export function CredentialsFlatView({
|
||||
) : (
|
||||
!readOnly && (
|
||||
<Button
|
||||
variant="secondary"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onAddCredential}
|
||||
className="w-fit"
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { preprocessInputSchema } from "./utils/input-schema-pre-processor";
|
||||
import { useMemo } from "react";
|
||||
import { customValidator } from "./utils/custom-validator";
|
||||
import Form from "./registry";
|
||||
import { ExtendedFormContextType } from "./types";
|
||||
import { customValidator } from "./utils/custom-validator";
|
||||
import { generateUiSchemaForCustomFields } from "./utils/generate-ui-schema";
|
||||
import { preprocessInputSchema } from "./utils/input-schema-pre-processor";
|
||||
|
||||
type FormRendererProps = {
|
||||
jsonSchema: RJSFSchema;
|
||||
@@ -12,15 +13,17 @@ type FormRendererProps = {
|
||||
uiSchema: any;
|
||||
initialValues: any;
|
||||
formContext: ExtendedFormContextType;
|
||||
className?: string;
|
||||
};
|
||||
|
||||
export const FormRenderer = ({
|
||||
export function FormRenderer({
|
||||
jsonSchema,
|
||||
handleChange,
|
||||
uiSchema,
|
||||
initialValues,
|
||||
formContext,
|
||||
}: FormRendererProps) => {
|
||||
className,
|
||||
}: FormRendererProps) {
|
||||
const preprocessedSchema = useMemo(() => {
|
||||
return preprocessInputSchema(jsonSchema);
|
||||
}, [jsonSchema]);
|
||||
@@ -31,7 +34,10 @@ export const FormRenderer = ({
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
return (
|
||||
<div className={"mb-6 mt-4"} data-tutorial-id="input-handles">
|
||||
<div
|
||||
className={cn("mb-6 mt-4", className)}
|
||||
data-tutorial-id="input-handles"
|
||||
>
|
||||
<Form
|
||||
formContext={formContext}
|
||||
idPrefix="agpt"
|
||||
@@ -45,4 +51,4 @@ export const FormRenderer = ({
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -218,6 +218,17 @@ If you initially installed Docker with Hyper-V, you **don’t need to reinstall*
|
||||
|
||||
For more details, refer to [Docker's official documentation](https://docs.docker.com/desktop/windows/wsl/).
|
||||
|
||||
### ⚠️ Podman Not Supported
|
||||
|
||||
AutoGPT requires **Docker** (Docker Desktop or Docker Engine). **Podman and podman-compose are not supported** and may cause path resolution issues, particularly on Windows.
|
||||
|
||||
If you see errors like:
|
||||
```text
|
||||
Error: the specified Containerfile or Dockerfile does not exist, ..\..\autogpt_platform\backend\Dockerfile
|
||||
```
|
||||
|
||||
This indicates you're using Podman instead of Docker. Please install [Docker Desktop](https://docs.docker.com/desktop/) and use `docker compose` instead of `podman-compose`.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
Reference in New Issue
Block a user