fix(copilot): remove task_id concept entirely, use session_id for all streaming

- Remove separate task_id creation for long-running tools
- Update stream_registry to use session_id as primary identifier
- Update all stream_registry calls across codebase to use session_id
- Keep ActiveTask.task_id field for backwards compatibility (equals session_id)
- Fix mini game showing forever by ensuring results reach correct stream
- Remove _task_id parameter from stream_chat_completion
- Update processor, service, sdk/service, routes, completion handlers

Root cause: Long-running tools were creating separate task_ids and publishing
to different streams than the frontend was subscribed to. Now everything uses
session_id, ensuring results reach the frontend properly.

Files modified:
- backend/copilot/stream_registry.py
- backend/api/features/chat/routes.py
- backend/copilot/executor/processor.py
- backend/copilot/service.py
- backend/copilot/sdk/service.py
- backend/copilot/completion_handler.py
- backend/copilot/completion_consumer.py
- frontend EditAgent/CreateAgent components (remove duplicate loader)
- frontend ChatMessagesContainer (remove unused imports)

Tests: 139/140 passed (1 SDK initialization failure unrelated to changes)
This commit is contained in:
Zamil Majdy
2026-02-22 17:35:39 +07:00
parent 5cca95b78c
commit dd10e1b339
100 changed files with 1690 additions and 1260 deletions

View File

@@ -304,7 +304,7 @@ async def get_session(
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
task_id=active_task.session_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
@@ -342,7 +342,7 @@ async def cancel_session_task(
if not active_task:
return CancelTaskResponse(cancelled=False, reason="no_active_task")
task_id = active_task.task_id
task_id = active_task.session_id
await enqueue_cancel_task(task_id)
logger.info(
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
@@ -357,7 +357,7 @@ async def cancel_session_task(
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
task = await stream_registry.get_task(task_id)
task = await stream_registry.get_task(session_id)
if task is None or task.status != "running":
logger.info(
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
@@ -444,13 +444,13 @@ async def stream_chat_post(
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
# Note: task_id = session_id (no longer generating random UUIDs)
task_id = session_id # For backwards compatibility in responses
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
@@ -499,7 +499,7 @@ async def stream_chat_post(
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
@@ -585,7 +585,7 @@ async def stream_chat_post(
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
@@ -651,7 +651,7 @@ async def resume_session_stream(
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
@@ -690,11 +690,11 @@ async def resume_session_stream(
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
f"Error unsubscribing from task {active_task.session_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -778,7 +778,9 @@ async def stream_task(
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
task, error_code = await stream_registry.get_task_with_expiry_info(
session_id=task_id
)
if error_code == "TASK_EXPIRED":
raise HTTPException(
@@ -810,7 +812,7 @@ async def stream_task(
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
session_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
@@ -887,7 +889,7 @@ async def get_task_status(
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
task = await stream_registry.get_task(session_id=task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
@@ -897,7 +899,7 @@ async def get_task_status(
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"task_id": task.session_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
@@ -957,7 +959,7 @@ async def complete_operation(
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
f"(task_id={task.session_id}, success={request.success})"
)
if request.success:
@@ -965,7 +967,7 @@ async def complete_operation(
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
return {"status": "ok", "task_id": task.session_id}
# ========== Configuration ==========

View File

@@ -57,7 +57,7 @@ async def postmark_webhook_handler(
webhook: Annotated[
PostmarkWebhook,
Body(discriminator="RecordType"),
]
],
):
logger.info(f"Received webhook from Postmark: {webhook}")
match webhook:

View File

@@ -164,7 +164,7 @@ class BlockHandler(ContentHandler):
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
placeholders = ",".join([f"${i + 1}" for i in range(len(block_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
@@ -265,7 +265,7 @@ class BlockHandler(ContentHandler):
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = enabled_block_ids
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
placeholders = ",".join([f"${i + 1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
f"""
@@ -508,7 +508,7 @@ class DocumentationHandler(ContentHandler):
]
# Check which ones have embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
placeholders = ",".join([f"${i + 1}" for i in range(len(section_content_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"

View File

@@ -47,7 +47,7 @@ def mock_storage_client(mocker):
async def test_upload_media_success(mock_settings, mock_storage_client):
# Create test JPEG data with valid signature
test_data = b"\xFF\xD8\xFF" + b"test data"
test_data = b"\xff\xd8\xff" + b"test data"
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
@@ -85,7 +85,7 @@ async def test_upload_media_missing_credentials(monkeypatch):
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"), # Valid JPEG signature
file=io.BytesIO(b"\xff\xd8\xff" + b"test data"), # Valid JPEG signature
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
@@ -110,7 +110,7 @@ async def test_upload_media_video_type(mock_settings, mock_storage_client):
async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
large_data = b"\xFF\xD8\xFF" + b"x" * (
large_data = b"\xff\xd8\xff" + b"x" * (
50 * 1024 * 1024 + 1
) # 50MB + 1 byte with valid JPEG signature
test_file = fastapi.UploadFile(

View File

@@ -499,10 +499,12 @@ async def test_upload_file_success(test_user_id: str):
)
# Mock dependencies
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.return_value = "gcs://test-bucket/uploads/123/test.txt"
@@ -551,10 +553,12 @@ async def test_upload_file_no_filename(test_user_id: str):
),
)
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.return_value = (
@@ -632,10 +636,12 @@ async def test_upload_file_cloud_storage_failure(test_user_id: str):
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
)
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.store_file.side_effect = RuntimeError("Storage error!")
@@ -679,10 +685,12 @@ async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
)
with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter:
with (
patch("backend.api.features.v1.scan_content_safe") as mock_scan,
patch(
"backend.api.features.v1.get_cloud_storage_handler"
) as mock_handler_getter,
):
mock_scan.return_value = None
mock_handler = AsyncMock()
mock_handler.config.gcs_bucket_name = "" # Simulate no GCS bucket configured

View File

@@ -457,7 +457,8 @@ async def test_api_key_with_unicode_characters_normalization_attack(mock_request
"""Test that Unicode normalization doesn't bypass validation."""
# Create auth with composed Unicode character
auth = APIKeyAuthenticator(
header_name="X-API-Key", expected_token="café" # é is composed
header_name="X-API-Key",
expected_token="café", # é is composed
)
# Try with decomposed version (c + a + f + e + ´)
@@ -522,8 +523,8 @@ async def test_api_keys_with_newline_variations(mock_request):
"valid\r\ntoken", # Windows newline
"valid\rtoken", # Mac newline
"valid\x85token", # NEL (Next Line)
"valid\x0Btoken", # Vertical Tab
"valid\x0Ctoken", # Form Feed
"valid\x0btoken", # Vertical Tab
"valid\x0ctoken", # Form Feed
]
for api_key in newline_variations:

View File

@@ -44,9 +44,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
"backend.api.ws_api.build_cors_params", return_value=cors_params
)
with override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
), override_config(settings, "app_env", AppEnvironment.LOCAL):
with (
override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
),
override_config(settings, "app_env", AppEnvironment.LOCAL),
):
WebsocketServer().run()
build_cors.assert_called_once_with(
@@ -65,9 +68,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
mocker.patch("backend.api.ws_api.uvicorn.run")
with override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
with (
override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
),
override_config(settings, "app_env", AppEnvironment.PRODUCTION),
):
with pytest.raises(ValueError):
WebsocketServer().run()

View File

@@ -179,7 +179,9 @@ class AIImageGeneratorBlock(Block):
],
test_mock={
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
"_run_client": lambda *args, **kwargs: (
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
)
},
)

View File

@@ -142,7 +142,9 @@ class AIMusicGeneratorBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: (
"https://replicate.com/output/generated-audio-url.wav"
),
},
test_credentials=TEST_CREDENTIALS,
)

View File

@@ -69,12 +69,18 @@ class PostToBlueskyBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate character limit for Bluesky
if len(input_data.post) > 300:
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)",
)
return
# Validate media constraints for Bluesky

View File

@@ -131,7 +131,10 @@ class PostToFacebookBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Convert datetime to ISO format if provided

View File

@@ -120,12 +120,18 @@ class PostToGMBBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate GMB constraints
if len(input_data.media_urls) > 1:
yield "error", "Google My Business supports only one image or video per post"
yield (
"error",
"Google My Business supports only one image or video per post",
)
return
# Validate offer coupon code length

View File

@@ -123,16 +123,25 @@ class PostToInstagramBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate Instagram constraints
if len(input_data.post) > 2200:
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)",
)
return
if len(input_data.media_urls) > 10:
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
yield (
"error",
"Instagram supports a maximum of 10 images/videos in a carousel",
)
return
if len(input_data.collaborators) > 3:
@@ -147,7 +156,10 @@ class PostToInstagramBlock(Block):
]
if any(reel_options) and not all(reel_options):
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
yield (
"error",
"When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset",
)
return
# Count hashtags and mentions
@@ -155,11 +167,17 @@ class PostToInstagramBlock(Block):
mention_count = input_data.post.count("@")
if hashtag_count > 30:
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
yield (
"error",
f"Instagram allows maximum 30 hashtags ({hashtag_count} found)",
)
return
if mention_count > 3:
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
yield (
"error",
f"Instagram allows maximum 3 @mentions ({mention_count} found)",
)
return
# Convert datetime to ISO format if provided
@@ -191,7 +209,10 @@ class PostToInstagramBlock(Block):
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
yield (
"error",
f"Alt text {i + 1} exceeds 1,000 character limit ({len(alt)} characters)",
)
return
instagram_options["altText"] = input_data.alt_text
@@ -206,13 +227,19 @@ class PostToInstagramBlock(Block):
try:
tag_obj = InstagramUserTag(**tag)
except Exception as e:
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
yield (
"error",
f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)",
)
return
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
if tag_obj.x is not None and tag_obj.y is not None:
# Validate coordinates
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
yield (
"error",
f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})",
)
return
tag_dict["x"] = tag_obj.x
tag_dict["y"] = tag_obj.y

View File

@@ -123,12 +123,18 @@ class PostToLinkedInBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate LinkedIn constraints
if len(input_data.post) > 3000:
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)",
)
return
if len(input_data.media_urls) > 9:
@@ -136,13 +142,19 @@ class PostToLinkedInBlock(Block):
return
if input_data.document_title and len(input_data.document_title) > 400:
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
yield (
"error",
f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)",
)
return
# Validate visibility option
valid_visibility = ["public", "connections", "loggedin"]
if input_data.visibility not in valid_visibility:
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
yield (
"error",
f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}",
)
return
# Check for document extensions

View File

@@ -103,20 +103,32 @@ class PostToPinterestBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate Pinterest constraints
if len(input_data.post) > 500:
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)",
)
return
if len(input_data.pin_title) > 100:
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
yield (
"error",
f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)",
)
return
if len(input_data.link) > 2048:
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
yield (
"error",
f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)",
)
return
if len(input_data.media_urls) == 0:
@@ -141,7 +153,10 @@ class PostToPinterestBlock(Block):
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 500:
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
yield (
"error",
f"Pinterest alt text {i + 1} exceeds 500 character limit ({len(alt)} characters)",
)
return
# Convert datetime to ISO format if provided

View File

@@ -73,7 +73,10 @@ class PostToSnapchatBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate Snapchat constraints
@@ -88,7 +91,10 @@ class PostToSnapchatBlock(Block):
# Validate story type
valid_story_types = ["story", "saved_story", "spotlight"]
if input_data.story_type not in valid_story_types:
yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
yield (
"error",
f"Snapchat story type must be one of: {', '.join(valid_story_types)}",
)
return
# Convert datetime to ISO format if provided

View File

@@ -68,7 +68,10 @@ class PostToTelegramBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate Telegram constraints

View File

@@ -61,22 +61,34 @@ class PostToThreadsBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate Threads constraints
if len(input_data.post) > 500:
yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)",
)
return
if len(input_data.media_urls) > 20:
yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
yield (
"error",
"Threads supports a maximum of 20 images/videos in a carousel",
)
return
# Count hashtags (only 1 allowed)
hashtag_count = input_data.post.count("#")
if hashtag_count > 1:
yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
yield (
"error",
f"Threads allows only 1 hashtag per post ({hashtag_count} found)",
)
return
# Convert datetime to ISO format if provided

View File

@@ -123,16 +123,25 @@ class PostToTikTokBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate TikTok constraints
if len(input_data.post) > 2200:
yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)",
)
return
if not input_data.media_urls:
yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
yield (
"error",
"TikTok requires at least one media URL (either 1 video or up to 35 images)",
)
return
# Check for video vs image constraints
@@ -150,7 +159,10 @@ class PostToTikTokBlock(Block):
)
if has_video and has_images:
yield "error", "TikTok does not support mixing video and images in the same post"
yield (
"error",
"TikTok does not support mixing video and images in the same post",
)
return
if has_video and len(input_data.media_urls) > 1:
@@ -163,13 +175,19 @@ class PostToTikTokBlock(Block):
# Validate image cover index
if has_images and input_data.image_cover_index >= len(input_data.media_urls):
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
yield (
"error",
f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})",
)
return
# Check for PNG files (not supported)
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
if has_png:
yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
yield (
"error",
"TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images.",
)
return
# Convert datetime to ISO format if provided

View File

@@ -126,16 +126,25 @@ class PostToXBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate X constraints
if not input_data.long_post and len(input_data.post) > 280:
yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
yield (
"error",
f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts.",
)
return
if input_data.long_post and len(input_data.post) > 25000:
yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)",
)
return
if len(input_data.media_urls) > 4:
@@ -149,14 +158,20 @@ class PostToXBlock(Block):
return
if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
yield (
"error",
"X poll duration must be between 1 and 10,080 minutes (7 days)",
)
return
# Validate alt text
if input_data.alt_text:
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
yield (
"error",
f"X alt text {i + 1} exceeds 1,000 character limit ({len(alt)} characters)",
)
return
# Validate subtitle settings
@@ -168,7 +183,10 @@ class PostToXBlock(Block):
return
if len(input_data.subtitle_name) > 150:
yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
yield (
"error",
f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)",
)
return
# Convert datetime to ISO format if provided

View File

@@ -149,7 +149,10 @@ class PostToYouTubeBlock(Block):
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
yield (
"error",
"Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.",
)
return
# Validate YouTube constraints
@@ -158,11 +161,17 @@ class PostToYouTubeBlock(Block):
return
if len(input_data.title) > 100:
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
yield (
"error",
f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)",
)
return
if len(input_data.post) > 5000:
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
yield (
"error",
f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)",
)
return
# Check for forbidden characters
@@ -186,7 +195,10 @@ class PostToYouTubeBlock(Block):
# Validate visibility option
valid_visibility = ["private", "public", "unlisted"]
if input_data.visibility not in valid_visibility:
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
yield (
"error",
f"YouTube visibility must be one of: {', '.join(valid_visibility)}",
)
return
# Validate thumbnail URL format
@@ -202,12 +214,18 @@ class PostToYouTubeBlock(Block):
if input_data.tags:
total_tag_length = sum(len(tag) for tag in input_data.tags)
if total_tag_length > 500:
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
yield (
"error",
f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)",
)
return
for tag in input_data.tags:
if len(tag) < 2:
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
yield (
"error",
f"YouTube tag '{tag}' is too short (minimum 2 characters)",
)
return
# Validate subtitle URL
@@ -225,12 +243,18 @@ class PostToYouTubeBlock(Block):
return
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
yield (
"error",
f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)",
)
return
# Validate publish_at format if provided
if input_data.publish_at and input_data.schedule_date:
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
yield (
"error",
"Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing.",
)
return
# Convert datetime to ISO format if provided (only if not using publish_at)

View File

@@ -59,10 +59,13 @@ class FileStoreBlock(Block):
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
file=input_data.file_in,
execution_context=execution_context,
return_format=return_format,
yield (
"file_out",
await store_media_file(
file=input_data.file_in,
execution_context=execution_context,
return_format=return_format,
),
)

View File

@@ -110,8 +110,10 @@ class DataForSeoKeywordSuggestionsBlock(Block):
test_output=[
(
"suggestion",
lambda x: hasattr(x, "keyword")
and x.keyword == "digital marketing strategy",
lambda x: (
hasattr(x, "keyword")
and x.keyword == "digital marketing strategy"
),
),
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),

View File

@@ -137,47 +137,71 @@ class SendEmailBlock(Block):
)
yield "status", status
except socket.gaierror:
yield "error", (
f"Cannot connect to SMTP server '{input_data.config.smtp_server}'. "
"Please verify the server address is correct."
yield (
"error",
(
f"Cannot connect to SMTP server '{input_data.config.smtp_server}'. "
"Please verify the server address is correct."
),
)
except socket.timeout:
yield "error", (
f"Connection timeout to '{input_data.config.smtp_server}' "
f"on port {input_data.config.smtp_port}. "
"The server may be down or unreachable."
yield (
"error",
(
f"Connection timeout to '{input_data.config.smtp_server}' "
f"on port {input_data.config.smtp_port}. "
"The server may be down or unreachable."
),
)
except ConnectionRefusedError:
yield "error", (
f"Connection refused to '{input_data.config.smtp_server}' "
f"on port {input_data.config.smtp_port}. "
"Common SMTP ports are: 587 (TLS), 465 (SSL), 25 (plain). "
"Please verify the port is correct."
yield (
"error",
(
f"Connection refused to '{input_data.config.smtp_server}' "
f"on port {input_data.config.smtp_port}. "
"Common SMTP ports are: 587 (TLS), 465 (SSL), 25 (plain). "
"Please verify the port is correct."
),
)
except smtplib.SMTPNotSupportedError:
yield "error", (
f"STARTTLS not supported by server '{input_data.config.smtp_server}'. "
"Try using port 465 for SSL or port 25 for unencrypted connection."
yield (
"error",
(
f"STARTTLS not supported by server '{input_data.config.smtp_server}'. "
"Try using port 465 for SSL or port 25 for unencrypted connection."
),
)
except ssl.SSLError as e:
yield "error", (
f"SSL/TLS error when connecting to '{input_data.config.smtp_server}': {str(e)}. "
"The server may require a different security protocol."
yield (
"error",
(
f"SSL/TLS error when connecting to '{input_data.config.smtp_server}': {str(e)}. "
"The server may require a different security protocol."
),
)
except smtplib.SMTPAuthenticationError:
yield "error", (
"Authentication failed. Please verify your username and password are correct."
yield (
"error",
(
"Authentication failed. Please verify your username and password are correct."
),
)
except smtplib.SMTPRecipientsRefused:
yield "error", (
f"Recipient email address '{input_data.to_email}' was rejected by the server. "
"Please verify the email address is valid."
yield (
"error",
(
f"Recipient email address '{input_data.to_email}' was rejected by the server. "
"Please verify the email address is valid."
),
)
except smtplib.SMTPSenderRefused:
yield "error", (
"Sender email address defined in the credentials that where used"
"was rejected by the server. "
"Please verify your account is authorized to send emails."
yield (
"error",
(
"Sender email address defined in the credentials that where used"
"was rejected by the server. "
"Please verify your account is authorized to send emails."
),
)
except smtplib.SMTPDataError as e:
yield "error", f"Email data rejected by server: {str(e)}"

View File

@@ -490,7 +490,9 @@ class GetLinkedinProfilePictureBlock(Block):
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
"_get_profile_picture": lambda *args, **kwargs: (
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk"
),
},
)

View File

@@ -319,7 +319,7 @@ class CostDollars(BaseModel):
# Helper functions for payload processing
def process_text_field(
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
) -> Optional[Union[bool, Dict[str, Any]]]:
"""Process text field for API payload."""
if text is None:
@@ -400,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
def process_context_field(
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
) -> Optional[Union[bool, Dict[str, int]]]:
"""Process context field for API payload."""
if context is None:

View File

@@ -566,8 +566,9 @@ class ExaUpdateWebsetBlock(Block):
yield "status", status_str
yield "external_id", sdk_webset.external_id
yield "metadata", sdk_webset.metadata or {}
yield "updated_at", (
sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
yield (
"updated_at",
(sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
)
@@ -706,11 +707,13 @@ class ExaGetWebsetBlock(Block):
yield "enrichments", enrichments_data
yield "monitors", monitors_data
yield "metadata", sdk_webset.metadata or {}
yield "created_at", (
sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
yield (
"created_at",
(sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""),
)
yield "updated_at", (
sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
yield (
"updated_at",
(sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
)

View File

@@ -523,16 +523,20 @@ class ExaWaitForEnrichmentBlock(Block):
items_enriched = 0
if input_data.sample_results and status == "completed":
sample_data, items_enriched = (
await self._get_sample_enrichments(
input_data.webset_id, input_data.enrichment_id, aexa
)
(
sample_data,
items_enriched,
) = await self._get_sample_enrichments(
input_data.webset_id, input_data.enrichment_id, aexa
)
yield "enrichment_id", input_data.enrichment_id
yield "final_status", status
yield "items_enriched", items_enriched
yield "enrichment_title", enrichment.title or enrichment.description or ""
yield (
"enrichment_title",
enrichment.title or enrichment.description or "",
)
yield "elapsed_time", elapsed
if input_data.sample_results:
yield "sample_data", sample_data

View File

@@ -127,7 +127,9 @@ class AIImageEditorBlock(Block):
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
"run_model": lambda *args, **kwargs: (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
),
},
test_credentials=TEST_CREDENTIALS,
)

View File

@@ -798,7 +798,9 @@ class GithubUnassignIssueBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Issue unassigned successfully")],
test_mock={
"unassign_issue": lambda *args, **kwargs: "Issue unassigned successfully"
"unassign_issue": lambda *args, **kwargs: (
"Issue unassigned successfully"
)
},
)

View File

@@ -261,7 +261,9 @@ class GithubReadPullRequestBlock(Block):
"This is the body of the pull request.",
"username",
),
"read_pr_changes": lambda *args, **kwargs: "List of changes made in the pull request.",
"read_pr_changes": lambda *args, **kwargs: (
"List of changes made in the pull request."
),
},
)
@@ -365,7 +367,9 @@ class GithubAssignPRReviewerBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Reviewer assigned successfully")],
test_mock={
"assign_reviewer": lambda *args, **kwargs: "Reviewer assigned successfully"
"assign_reviewer": lambda *args, **kwargs: (
"Reviewer assigned successfully"
)
},
)
@@ -432,7 +436,9 @@ class GithubUnassignPRReviewerBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Reviewer unassigned successfully")],
test_mock={
"unassign_reviewer": lambda *args, **kwargs: "Reviewer unassigned successfully"
"unassign_reviewer": lambda *args, **kwargs: (
"Reviewer unassigned successfully"
)
},
)

View File

@@ -341,14 +341,17 @@ class GoogleDocsCreateBlock(Block):
)
doc_id = result["document_id"]
doc_url = result["document_url"]
yield "document", GoogleDriveFile(
id=doc_id,
name=input_data.title,
mimeType="application/vnd.google-apps.document",
url=doc_url,
iconUrl="https://www.gstatic.com/images/branding/product/1x/docs_48dp.png",
isFolder=False,
_credentials_id=input_data.credentials.id,
yield (
"document",
GoogleDriveFile(
id=doc_id,
name=input_data.title,
mimeType="application/vnd.google-apps.document",
url=doc_url,
iconUrl="https://www.gstatic.com/images/branding/product/1x/docs_48dp.png",
isFolder=False,
_credentials_id=input_data.credentials.id,
),
)
yield "document_id", doc_id
yield "document_url", doc_url
@@ -815,7 +818,10 @@ class GoogleDocsGetMetadataBlock(Block):
yield "title", result["title"]
yield "document_id", input_data.document.id
yield "revision_id", result["revision_id"]
yield "document_url", f"https://docs.google.com/document/d/{input_data.document.id}/edit"
yield (
"document_url",
f"https://docs.google.com/document/d/{input_data.document.id}/edit",
)
yield "document", _make_document_output(input_data.document)
except Exception as e:
yield "error", f"Failed to get metadata: {str(e)}"

View File

@@ -278,11 +278,13 @@ class GmailBase(Block, ABC):
"""Download attachment content when email body is stored as attachment."""
try:
attachment = await asyncio.to_thread(
lambda: service.users()
.messages()
.attachments()
.get(userId="me", messageId=msg_id, id=attachment_id)
.execute()
lambda: (
service.users()
.messages()
.attachments()
.get(userId="me", messageId=msg_id, id=attachment_id)
.execute()
)
)
return attachment.get("data")
except Exception:
@@ -304,11 +306,13 @@ class GmailBase(Block, ABC):
async def download_attachment(self, service, message_id: str, attachment_id: str):
attachment = await asyncio.to_thread(
lambda: service.users()
.messages()
.attachments()
.get(userId="me", messageId=message_id, id=attachment_id)
.execute()
lambda: (
service.users()
.messages()
.attachments()
.get(userId="me", messageId=message_id, id=attachment_id)
.execute()
)
)
file_data = base64.urlsafe_b64decode(attachment["data"].encode("UTF-8"))
return file_data
@@ -466,10 +470,12 @@ class GmailReadBlock(GmailBase):
else "full"
)
msg = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(userId="me", id=message["id"], format=format_type)
.execute()
lambda: (
service.users()
.messages()
.get(userId="me", id=message["id"], format=format_type)
.execute()
)
)
headers = {
@@ -602,10 +608,12 @@ class GmailSendBlock(GmailBase):
)
raw_message = await create_mime_message(input_data, execution_context)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
.send(userId="me", body={"raw": raw_message})
.execute()
lambda: (
service.users()
.messages()
.send(userId="me", body={"raw": raw_message})
.execute()
)
)
return {"id": sent_message["id"], "status": "sent"}
@@ -699,8 +707,13 @@ class GmailCreateDraftBlock(GmailBase):
input_data,
execution_context,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
yield (
"result",
GmailDraftResult(
id=result["id"],
message_id=result["message"]["id"],
status="draft_created",
),
)
async def _create_draft(
@@ -713,10 +726,12 @@ class GmailCreateDraftBlock(GmailBase):
raw_message = await create_mime_message(input_data, execution_context)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw_message}})
.execute()
lambda: (
service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw_message}})
.execute()
)
)
return draft
@@ -840,10 +855,12 @@ class GmailAddLabelBlock(GmailBase):
async def _add_label(self, service, message_id: str, label_name: str) -> dict:
label_id = await self._get_or_create_label(service, label_name)
result = await asyncio.to_thread(
lambda: service.users()
.messages()
.modify(userId="me", id=message_id, body={"addLabelIds": [label_id]})
.execute()
lambda: (
service.users()
.messages()
.modify(userId="me", id=message_id, body={"addLabelIds": [label_id]})
.execute()
)
)
if not result.get("labelIds"):
return {
@@ -857,10 +874,12 @@ class GmailAddLabelBlock(GmailBase):
label_id = await self._get_label_id(service, label_name)
if not label_id:
label = await asyncio.to_thread(
lambda: service.users()
.labels()
.create(userId="me", body={"name": label_name})
.execute()
lambda: (
service.users()
.labels()
.create(userId="me", body={"name": label_name})
.execute()
)
)
label_id = label["id"]
return label_id
@@ -927,10 +946,14 @@ class GmailRemoveLabelBlock(GmailBase):
label_id = await self._get_label_id(service, label_name)
if label_id:
result = await asyncio.to_thread(
lambda: service.users()
.messages()
.modify(userId="me", id=message_id, body={"removeLabelIds": [label_id]})
.execute()
lambda: (
service.users()
.messages()
.modify(
userId="me", id=message_id, body={"removeLabelIds": [label_id]}
)
.execute()
)
)
if not result.get("labelIds"):
return {
@@ -1048,10 +1071,12 @@ class GmailGetThreadBlock(GmailBase):
else "full"
)
thread = await asyncio.to_thread(
lambda: service.users()
.threads()
.get(userId="me", id=thread_id, format=format_type)
.execute()
lambda: (
service.users()
.threads()
.get(userId="me", id=thread_id, format=format_type)
.execute()
)
)
parsed_messages = []
@@ -1106,23 +1131,25 @@ async def _build_reply_message(
"""
# Get parent message for reply context
parent = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(
userId="me",
id=input_data.parentMessageId,
format="metadata",
metadataHeaders=[
"Subject",
"References",
"Message-ID",
"From",
"To",
"Cc",
"Reply-To",
],
lambda: (
service.users()
.messages()
.get(
userId="me",
id=input_data.parentMessageId,
format="metadata",
metadataHeaders=[
"Subject",
"References",
"Message-ID",
"From",
"To",
"Cc",
"Reply-To",
],
)
.execute()
)
.execute()
)
# Build headers dictionary, preserving all values for duplicate headers
@@ -1346,10 +1373,12 @@ class GmailReplyBlock(GmailBase):
# Send the message
return await asyncio.to_thread(
lambda: service.users()
.messages()
.send(userId="me", body={"threadId": thread_id, "raw": raw})
.execute()
lambda: (
service.users()
.messages()
.send(userId="me", body={"threadId": thread_id, "raw": raw})
.execute()
)
)
@@ -1459,18 +1488,20 @@ class GmailDraftReplyBlock(GmailBase):
# Create draft with proper thread association
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
.create(
userId="me",
body={
"message": {
"threadId": thread_id,
"raw": raw,
}
},
lambda: (
service.users()
.drafts()
.create(
userId="me",
body={
"message": {
"threadId": thread_id,
"raw": raw,
}
},
)
.execute()
)
.execute()
)
return draft
@@ -1642,10 +1673,12 @@ class GmailForwardBlock(GmailBase):
# Get the original message
original = await asyncio.to_thread(
lambda: service.users()
.messages()
.get(userId="me", id=input_data.messageId, format="full")
.execute()
lambda: (
service.users()
.messages()
.get(userId="me", id=input_data.messageId, format="full")
.execute()
)
)
headers = {
@@ -1735,8 +1768,10 @@ To: {original_to}
# Send the forwarded message
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
return await asyncio.to_thread(
lambda: service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
lambda: (
service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
)
)

View File

@@ -345,14 +345,17 @@ class GoogleSheetsReadBlock(Block):
)
yield "result", data
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=spreadsheet_id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{spreadsheet_id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=spreadsheet_id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{spreadsheet_id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", _handle_sheets_api_error(str(e), "read")
@@ -466,9 +469,12 @@ class GoogleSheetsWriteBlock(Block):
if validation_error:
# Customize message for write operations on CSV files
if "CSV file" in validation_error:
yield "error", validation_error.replace(
"Please use a CSV reader block instead, or",
"CSV files are read-only through Google Drive. Please",
yield (
"error",
validation_error.replace(
"Please use a CSV reader block instead, or",
"CSV files are read-only through Google Drive. Please",
),
)
else:
yield "error", validation_error
@@ -485,14 +491,17 @@ class GoogleSheetsWriteBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", _handle_sheets_api_error(str(e), "write")
@@ -614,14 +623,17 @@ class GoogleSheetsAppendRowBlock(Block):
input_data.value_input_option,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to append row: {str(e)}"
@@ -744,14 +756,17 @@ class GoogleSheetsClearBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to clear Google Sheet range: {str(e)}"
@@ -854,14 +869,17 @@ class GoogleSheetsMetadataBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get spreadsheet metadata: {str(e)}"
@@ -984,14 +1002,17 @@ class GoogleSheetsManageSheetBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to manage sheet: {str(e)}"
@@ -1141,14 +1162,17 @@ class GoogleSheetsBatchOperationsBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to perform batch operations: {str(e)}"
@@ -1306,14 +1330,17 @@ class GoogleSheetsFindReplaceBlock(Block):
)
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to find/replace in Google Sheet: {str(e)}"
@@ -1488,14 +1515,17 @@ class GoogleSheetsFindBlock(Block):
yield "locations", result["locations"]
yield "result", {"success": True}
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to find text in Google Sheet: {str(e)}"
@@ -1754,14 +1784,17 @@ class GoogleSheetsFormatBlock(Block):
else:
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to format Google Sheet cells: {str(e)}"
@@ -1928,14 +1961,17 @@ class GoogleSheetsCreateSpreadsheetBlock(Block):
spreadsheet_id = result["spreadsheetId"]
spreadsheet_url = result["spreadsheetUrl"]
# Output the GoogleDriveFile for chaining (includes credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=spreadsheet_id,
name=result.get("title", input_data.title),
mimeType="application/vnd.google-apps.spreadsheet",
url=spreadsheet_url,
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.credentials.id, # Preserve credentials for chaining
yield (
"spreadsheet",
GoogleDriveFile(
id=spreadsheet_id,
name=result.get("title", input_data.title),
mimeType="application/vnd.google-apps.spreadsheet",
url=spreadsheet_url,
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.credentials.id, # Preserve credentials for chaining
),
)
yield "spreadsheet_id", spreadsheet_id
yield "spreadsheet_url", spreadsheet_url
@@ -2113,14 +2149,17 @@ class GoogleSheetsUpdateCellBlock(Block):
yield "result", result
# Output the GoogleDriveFile for chaining (preserves credentials_id)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", _handle_sheets_api_error(str(e), "update")
@@ -2379,14 +2418,17 @@ class GoogleSheetsFilterRowsBlock(Block):
yield "rows", result["rows"]
yield "row_indices", result["row_indices"]
yield "count", result["count"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to filter rows: {str(e)}"
@@ -2596,14 +2638,17 @@ class GoogleSheetsLookupRowBlock(Block):
yield "row_dict", result["row_dict"]
yield "row_index", result["row_index"]
yield "found", result["found"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to lookup row: {str(e)}"
@@ -2817,14 +2862,17 @@ class GoogleSheetsDeleteRowsBlock(Block):
)
yield "result", {"success": True}
yield "deleted_count", result["deleted_count"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to delete rows: {str(e)}"
@@ -2995,14 +3043,17 @@ class GoogleSheetsGetColumnBlock(Block):
yield "values", result["values"]
yield "count", result["count"]
yield "column_index", result["column_index"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get column: {str(e)}"
@@ -3176,14 +3227,17 @@ class GoogleSheetsSortBlock(Block):
input_data.has_header,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to sort sheet: {str(e)}"
@@ -3439,14 +3493,17 @@ class GoogleSheetsGetUniqueValuesBlock(Block):
yield "values", result["values"]
yield "counts", result["counts"]
yield "total_unique", result["total_unique"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get unique values: {str(e)}"
@@ -3620,14 +3677,17 @@ class GoogleSheetsInsertRowBlock(Block):
input_data.value_input_option,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to insert row: {str(e)}"
@@ -3793,14 +3853,17 @@ class GoogleSheetsAddColumnBlock(Block):
yield "result", {"success": True}
yield "column_letter", result["column_letter"]
yield "column_index", result["column_index"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to add column: {str(e)}"
@@ -3998,14 +4061,17 @@ class GoogleSheetsGetRowCountBlock(Block):
yield "data_rows", result["data_rows"]
yield "last_row", result["last_row"]
yield "column_count", result["column_count"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get row count: {str(e)}"
@@ -4176,14 +4242,17 @@ class GoogleSheetsRemoveDuplicatesBlock(Block):
yield "result", {"success": True}
yield "removed_count", result["removed_count"]
yield "remaining_rows", result["remaining_rows"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to remove duplicates: {str(e)}"
@@ -4426,14 +4495,17 @@ class GoogleSheetsUpdateRowBlock(Block):
input_data.dict_values,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to update row: {str(e)}"
@@ -4615,14 +4687,17 @@ class GoogleSheetsGetRowBlock(Block):
)
yield "row", result["row"]
yield "row_dict", result["row_dict"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get row: {str(e)}"
@@ -4753,14 +4828,17 @@ class GoogleSheetsDeleteColumnBlock(Block):
input_data.column,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to delete column: {str(e)}"
@@ -4931,14 +5009,17 @@ class GoogleSheetsCreateNamedRangeBlock(Block):
)
yield "result", {"success": True}
yield "named_range_id", result["named_range_id"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to create named range: {str(e)}"
@@ -5104,14 +5185,17 @@ class GoogleSheetsListNamedRangesBlock(Block):
)
yield "named_ranges", result["named_ranges"]
yield "count", result["count"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to list named ranges: {str(e)}"
@@ -5264,14 +5348,17 @@ class GoogleSheetsAddDropdownBlock(Block):
input_data.show_dropdown,
)
yield "result", result
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to add dropdown: {str(e)}"
@@ -5436,14 +5523,17 @@ class GoogleSheetsCopyToSpreadsheetBlock(Block):
yield "result", {"success": True}
yield "new_sheet_id", result["new_sheet_id"]
yield "new_sheet_name", result["new_sheet_name"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.source_spreadsheet.id,
name=input_data.source_spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.source_spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.source_spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.source_spreadsheet.id,
name=input_data.source_spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.source_spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.source_spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to copy sheet: {str(e)}"
@@ -5588,14 +5678,17 @@ class GoogleSheetsProtectRangeBlock(Block):
)
yield "result", {"success": True}
yield "protection_id", result["protection_id"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to protect range: {str(e)}"
@@ -5752,14 +5845,17 @@ class GoogleSheetsExportCsvBlock(Block):
)
yield "csv_data", result["csv_data"]
yield "row_count", result["row_count"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to export CSV: {str(e)}"
@@ -5895,14 +5991,17 @@ class GoogleSheetsImportCsvBlock(Block):
)
yield "result", {"success": True}
yield "rows_imported", result["rows_imported"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to import CSV: {str(e)}"
@@ -6032,14 +6131,17 @@ class GoogleSheetsAddNoteBlock(Block):
input_data.note,
)
yield "result", {"success": True}
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to add note: {str(e)}"
@@ -6185,14 +6287,17 @@ class GoogleSheetsGetNotesBlock(Block):
notes = result["notes"]
yield "notes", notes
yield "count", len(notes)
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to get notes: {str(e)}"
@@ -6347,14 +6452,17 @@ class GoogleSheetsShareSpreadsheetBlock(Block):
)
yield "result", {"success": True}
yield "share_link", result["share_link"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to share spreadsheet: {str(e)}"
@@ -6491,14 +6599,17 @@ class GoogleSheetsSetPublicAccessBlock(Block):
)
yield "result", {"success": True, "is_public": result["is_public"]}
yield "share_link", result["share_link"]
yield "spreadsheet", GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
yield (
"spreadsheet",
GoogleDriveFile(
id=input_data.spreadsheet.id,
name=input_data.spreadsheet.name,
mimeType="application/vnd.google-apps.spreadsheet",
url=f"https://docs.google.com/spreadsheets/d/{input_data.spreadsheet.id}/edit",
iconUrl="https://www.gstatic.com/images/branding/product/1x/sheets_48dp.png",
isFolder=False,
_credentials_id=input_data.spreadsheet.credentials_id,
),
)
except Exception as e:
yield "error", f"Failed to set public access: {str(e)}"

View File

@@ -195,8 +195,12 @@ class IdeogramModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: (
"https://ideogram.ai/api/images/test-generated-image-url.png"
),
"upscale_image": lambda api_key, image_url: (
"https://ideogram.ai/api/images/test-upscaled-image-url.png"
),
},
test_credentials=TEST_CREDENTIALS,
)

View File

@@ -211,8 +211,11 @@ class AgentOutputBlock(Block):
if input_data.format:
try:
formatter = TextFormatter(autoescape=input_data.escape_html)
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
yield (
"output",
formatter.format_string(
input_data.format, {input_data.name: input_data.value}
),
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
@@ -475,10 +478,13 @@ class AgentFileInputBlock(AgentInputBlock):
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
file=input_data.value,
execution_context=execution_context,
return_format=return_format,
yield (
"result",
await store_media_file(
file=input_data.value,
execution_context=execution_context,
return_format=return_format,
),
)

View File

@@ -128,10 +128,16 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
try:
content = await self.get_request(url, json=False, headers=headers)
except HTTPClientError as e:
yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
yield (
"error",
f"Client error ({e.status_code}) fetching {input_data.url}: {e}",
)
return
except HTTPServerError as e:
yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
yield (
"error",
f"Server error ({e.status_code}) fetching {input_data.url}: {e}",
)
return
except Exception as e:
yield "error", f"Failed to fetch {input_data.url}: {e}"

View File

@@ -75,7 +75,6 @@ class LinearClient:
response_data = response.json()
if "errors" in response_data:
error_messages = [
error.get("message", "") for error in response_data["errors"]
]

View File

@@ -692,7 +692,6 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
system_messages = [p["content"] for p in prompt if p["role"] == "system"]

View File

@@ -267,9 +267,12 @@ class MCPToolBlock(Block):
if required:
missing = required - set(input_data.tool_arguments.keys())
if missing:
yield "error", (
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
yield (
"error",
(
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
),
)
return

View File

@@ -18,11 +18,7 @@ class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
sse = 'event: message\ndata: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n\n'
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1

View File

@@ -75,11 +75,14 @@ class PersistInformationBlock(Block):
storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)
# Store the data
yield "value", await self._store_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=storage_key,
data=input_data.value,
yield (
"value",
await self._store_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=storage_key,
data=input_data.value,
),
)
async def _store_data(

View File

@@ -160,10 +160,13 @@ class PineconeQueryBlock(Block):
combined_text = "\n\n".join(texts)
# Return both the raw matches and combined text
yield "results", {
"matches": results["matches"],
"combined_text": combined_text,
}
yield (
"results",
{
"matches": results["matches"],
"combined_text": combined_text,
},
)
yield "combined_results", combined_text
except Exception as e:

View File

@@ -309,10 +309,13 @@ class PostRedditCommentBlock(Block):
async def run(
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
) -> BlockOutput:
yield "comment_id", self.reply_post(
credentials,
post_id=input_data.post_id,
comment=input_data.comment,
yield (
"comment_id",
self.reply_post(
credentials,
post_id=input_data.post_id,
comment=input_data.comment,
),
)
yield "post_id", input_data.post_id

View File

@@ -141,7 +141,9 @@ class ReplicateFluxAdvancedModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, steps, guidance, interval, aspect_ratio, output_format, output_quality, safety_tolerance: "https://replicate.com/output/generated-image-url.jpg",
"run_model": lambda api_key, model_name, prompt, seed, steps, guidance, interval, aspect_ratio, output_format, output_quality, safety_tolerance: (
"https://replicate.com/output/generated-image-url.jpg"
),
},
test_credentials=TEST_CREDENTIALS,
)

View File

@@ -48,7 +48,7 @@ class Slant3DBlockBase(Block):
raise ValueError(
f"""Invalid color profile combination {color_tag}.
Valid colors for {profile.value} are:
{','.join([filament['colorTag'].replace(profile.value.lower(), '') for filament in response['filaments'] if filament['profile'] == profile.value])}
{",".join([filament["colorTag"].replace(profile.value.lower(), "") for filament in response["filaments"] if filament["profile"] == profile.value])}
"""
)
return color_tag

View File

@@ -933,7 +933,10 @@ class SmartDecisionMakerBlock(Block):
credentials, input_data, iteration_prompt, tool_functions
)
except Exception as e:
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
yield (
"error",
f"LLM call failed in agent mode iteration {iteration}: {str(e)}",
)
return
# Process tool calls
@@ -973,7 +976,10 @@ class SmartDecisionMakerBlock(Block):
if max_iterations < 0:
yield "finished", f"Agent mode completed after {iteration} iterations"
else:
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
yield (
"finished",
f"Agent mode completed after {max_iterations} iterations (limit reached)",
)
yield "conversations", current_prompt
async def run(

View File

@@ -180,20 +180,22 @@ class AddLeadToCampaignBlock(Block):
),
],
test_mock={
"add_leads_to_campaign": lambda campaign_id, lead_list, credentials: AddLeadsToCampaignResponse(
ok=True,
upload_count=1,
already_added_to_campaign=0,
duplicate_count=0,
invalid_email_count=0,
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
error="",
total_leads=1,
block_count=0,
invalid_emails=[],
unsubscribed_leads=[],
bounce_count=0,
"add_leads_to_campaign": lambda campaign_id, lead_list, credentials: (
AddLeadsToCampaignResponse(
ok=True,
upload_count=1,
already_added_to_campaign=0,
duplicate_count=0,
invalid_email_count=0,
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
error="",
total_leads=1,
block_count=0,
invalid_emails=[],
unsubscribed_leads=[],
bounce_count=0,
)
)
},
)
@@ -295,9 +297,11 @@ class SaveCampaignSequencesBlock(Block):
),
],
test_mock={
"save_campaign_sequences": lambda campaign_id, sequences, credentials: SaveSequencesResponse(
ok=True,
message="Sequences saved successfully",
"save_campaign_sequences": lambda campaign_id, sequences, credentials: (
SaveSequencesResponse(
ok=True,
message="Sequences saved successfully",
)
)
},
)

View File

@@ -219,17 +219,19 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Mock the _create_tool_node_signatures method to avoid database calls
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
):
# Create test input
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
@@ -322,17 +324,19 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_with_typo,
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_with_typo,
) as mock_llm_call,
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
@@ -389,17 +393,19 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_missing_required,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_missing_required,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
@@ -449,17 +455,19 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None}
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_valid,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_valid,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
@@ -513,17 +521,19 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_all_params,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_all_params,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
@@ -634,13 +644,14 @@ async def test_smart_decision_maker_raw_response_conversion():
# Mock llm_call to return different responses on different calls
with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
with (
patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm_call,
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
),
):
# First call returns response that will trigger retry due to validation error
# Second call returns successful response
@@ -710,15 +721,18 @@ async def test_smart_decision_maker_raw_response_conversion():
"I'll help you with that." # Ollama returns string
)
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_ollama,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[], # No tools for this test
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_ollama,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[], # No tools for this test
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
@@ -766,15 +780,18 @@ async def test_smart_decision_maker_raw_response_conversion():
"content": "Test response",
} # Dict format
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_dict,
), patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_dict,
),
patch.object(
SmartDecisionMakerBlock,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
@@ -890,18 +907,21 @@ async def test_smart_decision_maker_agent_mode():
# No longer need mock_execute_node since we use execution_processor.on_node_execution
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
), patch(
"backend.integrations.creds_manager.IntegrationCredentialsManager"
with (
patch("backend.blocks.llm.llm_call", llm_call_mock),
patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
),
patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
),
patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
),
patch("backend.integrations.creds_manager.IntegrationCredentialsManager"),
):
# Create a mock execution context
mock_execution_context = ExecutionContext(
@@ -1009,14 +1029,16 @@ async def test_smart_decision_maker_traditional_mode_default():
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
with (
patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
),
patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
),
):
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",

View File

@@ -41,7 +41,8 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
# Generate function signature
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
mock_node,
mock_links, # type: ignore
)
# Verify the signature was created successfully
@@ -98,7 +99,8 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
# Generate function signature
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
mock_node,
mock_links, # type: ignore
)
# Verify dynamic list fields are handled properly

View File

@@ -314,11 +314,14 @@ async def test_output_yielding_with_dynamic_fields():
mock_llm.return_value = mock_response
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
with (
patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager,
patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig,
):
# Set up the mock database manager
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client

View File

@@ -275,24 +275,30 @@ class GetCurrentDateBlock(Block):
test_output=[
(
"date",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
)
<= timedelta(days=8), # 7 days difference + 1 day error margin.
lambda t: (
abs(
datetime.now().date()
- datetime.strptime(t, "%Y-%m-%d").date()
)
<= timedelta(days=8)
), # 7 days difference + 1 day error margin.
),
(
"date",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
)
<= timedelta(days=8),
lambda t: (
abs(
datetime.now().date()
- datetime.strptime(t, "%m/%d/%Y").date()
)
<= timedelta(days=8)
),
# 7 days difference + 1 day error margin.
),
(
"date",
lambda t: len(t) == 10
and t[4] == "-"
and t[7] == "-", # ISO date format YYYY-MM-DD
lambda t: (
len(t) == 10 and t[4] == "-" and t[7] == "-"
), # ISO date format YYYY-MM-DD
),
],
)
@@ -380,25 +386,32 @@ class GetCurrentDateAndTimeBlock(Block):
test_output=[
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC"))
- datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
)
< timedelta(seconds=10), # 10 seconds error margin.
lambda t: (
abs(
datetime.now(tz=ZoneInfo("UTC"))
- datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
)
< timedelta(seconds=10)
), # 10 seconds error margin.
),
(
"date_time",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
)
<= timedelta(days=1), # Date format only, no time component
lambda t: (
abs(
datetime.now().date()
- datetime.strptime(t, "%Y/%m/%d").date()
)
<= timedelta(days=1)
), # Date format only, no time component
),
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
)
< timedelta(seconds=10), # 10 seconds error margin for ISO format.
lambda t: (
abs(
datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
)
< timedelta(seconds=10)
), # 10 seconds error margin for ISO format.
),
],
)

View File

@@ -160,7 +160,7 @@ class TodoistCreateProjectBlock(Block):
test_input={"credentials": TEST_CREDENTIALS_INPUT, "name": "Test Project"},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"create_project": lambda *args, **kwargs: (True)},
test_mock={"create_project": lambda *args, **kwargs: True},
)
@staticmethod
@@ -346,7 +346,7 @@ class TodoistUpdateProjectBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"update_project": lambda *args, **kwargs: (True)},
test_mock={"update_project": lambda *args, **kwargs: True},
)
@staticmethod
@@ -426,7 +426,7 @@ class TodoistDeleteProjectBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"delete_project": lambda *args, **kwargs: (True)},
test_mock={"delete_project": lambda *args, **kwargs: True},
)
@staticmethod

View File

@@ -285,7 +285,7 @@ class TodoistDeleteSectionBlock(Block):
test_input={"credentials": TEST_CREDENTIALS_INPUT, "section_id": "7025"},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"delete_section": lambda *args, **kwargs: (True)},
test_mock={"delete_section": lambda *args, **kwargs: True},
)
@staticmethod

View File

@@ -580,7 +580,7 @@ class TodoistReopenTaskBlock(Block):
test_output=[
("success", True),
],
test_mock={"reopen_task": lambda *args, **kwargs: (True)},
test_mock={"reopen_task": lambda *args, **kwargs: True},
)
@staticmethod
@@ -632,7 +632,7 @@ class TodoistDeleteTaskBlock(Block):
test_output=[
("success", True),
],
test_mock={"delete_task": lambda *args, **kwargs: (True)},
test_mock={"delete_task": lambda *args, **kwargs: True},
)
@staticmethod

View File

@@ -256,7 +256,6 @@ class ListFieldsFilter(BaseModel):
# --------- [Input Types] -------------
class TweetExpansionInputs(BlockSchemaInput):
expansions: ExpansionFilter | None = SchemaField(
description="Choose what extra information you want to get with your tweets. For example:\n- Select 'Media_Keys' to get media details\n- Select 'Author_User_ID' to get user information\n- Select 'Place_ID' to get location details",
placeholder="Pick the extra information you want to see",

View File

@@ -232,7 +232,7 @@ class TwitterCreateListBlock(Block):
("list_id", "1234567890"),
("url", "https://twitter.com/i/lists/1234567890"),
],
test_mock={"create_list": lambda *args, **kwargs: ("1234567890")},
test_mock={"create_list": lambda *args, **kwargs: "1234567890"},
)
@staticmethod

View File

@@ -159,7 +159,6 @@ class TwitterGetTweetBlock(Block):
**kwargs,
) -> BlockOutput:
try:
tweet_data, included, meta, user_id, user_name = self.get_tweet(
credentials,
input_data.tweet_id,

View File

@@ -44,7 +44,8 @@ class VideoNarrationBlock(Block):
)
script: str = SchemaField(description="Narration script text")
voice_id: str = SchemaField(
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
description="ElevenLabs voice ID",
default="21m00Tcm4TlvDq8ikWAM", # Rachel
)
model_id: Literal[
"eleven_multilingual_v2",

View File

@@ -94,7 +94,9 @@ class TranscribeYoutubeVideoBlock(Block):
{"text": "Never gonna give you up"},
{"text": "Never gonna let you down"},
],
"format_transcript": lambda transcript: "Never gonna give you up\nNever gonna let you down",
"format_transcript": lambda transcript: (
"Never gonna give you up\nNever gonna let you down"
),
},
)

View File

@@ -140,20 +140,22 @@ class ValidateEmailsBlock(Block):
)
],
test_mock={
"validate_email": lambda email, ip_address, credentials: ZBValidateResponse(
data={
"address": email,
"status": ZBValidateStatus.valid,
"sub_status": ZBValidateSubStatus.allowed,
"account": "test",
"domain": "test.com",
"did_you_mean": None,
"domain_age_days": None,
"free_email": False,
"mx_found": False,
"mx_record": None,
"smtp_provider": None,
}
"validate_email": lambda email, ip_address, credentials: (
ZBValidateResponse(
data={
"address": email,
"status": ZBValidateStatus.valid,
"sub_status": ZBValidateSubStatus.allowed,
"account": "test",
"domain": "test.com",
"did_you_mean": None,
"domain_age_days": None,
"free_email": False,
"mx_found": False,
"mx_record": None,
"smtp_provider": None,
}
)
)
},
)

View File

@@ -172,7 +172,7 @@ async def add_test_data(db):
"storeListingId": listing.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": f"Test Agent {i+1}",
"name": f"Test Agent {i + 1}",
"subHeading": faker.catch_phrase(),
"description": faker.paragraph(nb_sentences=5),
"imageUrls": [faker.image_url()],
@@ -245,9 +245,7 @@ async def compare_counts(before, after):
print("🔍 Agent run changes:")
before_runs = before["agent_runs"].get("total_runs") or 0
after_runs = after["agent_runs"].get("total_runs") or 0
print(
f" Total runs: {before_runs}{after_runs} " f"(+{after_runs - before_runs})"
)
print(f" Total runs: {before_runs}{after_runs} (+{after_runs - before_runs})")
# Compare reviews
print("\n🔍 Review changes:")

View File

@@ -147,7 +147,7 @@ def format_sql_insert(creds: dict) -> str:
sql = f"""
-- ============================================================
-- OAuth Application: {creds['name']}
-- OAuth Application: {creds["name"]}
-- Generated: {now_iso} UTC
-- ============================================================
@@ -167,14 +167,14 @@ INSERT INTO "OAuthApplication" (
"isActive"
)
VALUES (
'{creds['id']}',
'{creds["id"]}',
NOW(),
NOW(),
'{creds['name']}',
{f"'{creds['description']}'" if creds['description'] else 'NULL'},
'{creds['client_id']}',
'{creds['client_secret_hash']}',
'{creds['client_secret_salt']}',
'{creds["name"]}',
{f"'{creds['description']}'" if creds["description"] else "NULL"},
'{creds["client_id"]}',
'{creds["client_secret_hash"]}',
'{creds["client_secret_salt"]}',
ARRAY{redirect_uris_pg}::TEXT[],
ARRAY{grant_types_pg}::TEXT[],
ARRAY{scopes_pg}::"APIKeyPermission"[],
@@ -186,8 +186,8 @@ VALUES (
-- ⚠️ IMPORTANT: Save these credentials securely!
-- ============================================================
--
-- Client ID: {creds['client_id']}
-- Client Secret: {creds['client_secret_plaintext']}
-- Client ID: {creds["client_id"]}
-- Client Secret: {creds["client_secret_plaintext"]}
--
-- ⚠️ The client secret is shown ONLY ONCE!
-- ⚠️ Store it securely and share only with the application developer.
@@ -200,7 +200,7 @@ VALUES (
-- To verify the application was created:
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
-- FROM "OAuthApplication"
-- WHERE "clientId" = '{creds['client_id']}';
-- WHERE "clientId" = '{creds["client_id"]}';
"""
return sql

View File

@@ -261,15 +261,15 @@ class ChatCompletionConsumer:
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"[COMPLETION] Found task: task_id={task.session_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
if not task.session_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"task_id={task.session_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return

View File

@@ -223,7 +223,7 @@ async def process_operation_success(
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
task.session_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
@@ -244,14 +244,14 @@ async def process_operation_success(
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
f"[COMPLETION] DB update failed for task {task.session_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
task.session_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
await stream_registry.mark_task_completed(task.session_id, status="failed")
raise
# Generate LLM continuation with streaming
@@ -259,7 +259,6 @@ async def process_operation_success(
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
@@ -268,14 +267,14 @@ async def process_operation_success(
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await stream_registry.mark_task_completed(task.session_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
f"[COMPLETION] Successfully processed completion for task {task.session_id}"
)
@@ -296,7 +295,7 @@ async def process_operation_failure(
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
task.session_id,
StreamError(errorText=error_msg),
)
@@ -315,15 +314,17 @@ async def process_operation_failure(
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
f"[COMPLETION] DB update failed while processing failure for task {task.session_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await stream_registry.mark_task_completed(task.session_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
logger.info(
f"[COMPLETION] Processed failure for task {task.session_id}: {error_msg}"
)

View File

@@ -151,7 +151,7 @@ class CoPilotProcessor:
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
task_id=entry.session_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
@@ -240,14 +240,16 @@ class CoPilotProcessor:
if cancel.is_set():
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.task_id, StreamError(errorText="Operation cancelled")
entry.session_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
entry.session_id, StreamFinishStep()
)
await stream_registry.publish_chunk(
entry.session_id, StreamFinish()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
entry.session_id, status="failed"
)
return
@@ -258,16 +260,18 @@ class CoPilotProcessor:
last_refresh = current_time
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
await stream_registry.publish_chunk(entry.session_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
await stream_registry.mark_task_completed(
entry.session_id, status="completed"
)
log.info("Task completed successfully")
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(
entry.task_id,
entry.session_id,
status="failed",
error_message="Task was cancelled",
)
@@ -275,17 +279,17 @@ class CoPilotProcessor:
except Exception as e:
log.error(f"Task failed: {e}")
await self._mark_task_failed(entry.task_id, str(e))
await self._mark_task_failed(entry.session_id, str(e))
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
async def _mark_task_failed(self, session_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
session_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
await stream_registry.publish_chunk(session_id, StreamFinishStep())
await stream_registry.publish_chunk(session_id, StreamFinish())
await stream_registry.mark_task_completed(session_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")
logger.error(f"Failed to mark task {session_id} as failed: {e}")

View File

@@ -294,7 +294,7 @@ class SDKResponseAdapter:
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.info(
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
"[SDK] [%s] Flushed stashed output for %s (call %s, %d chars)",
sid,
tool_name,
tool_id[:12],

View File

@@ -13,7 +13,6 @@ from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
@@ -213,7 +212,6 @@ def _build_long_running_callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
session_id = session.session_id
# CRITICAL: Find the tool_use_id from the latest assistant message.
@@ -230,23 +228,13 @@ def _build_long_running_callback(
f"using generated ID: {tool_call_id}"
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Execute tool synchronously and WAIT for completion ---
# The callback blocks here, waiting for agent generation to complete.
# Meanwhile, the frontend mini-game shows via SSE events from stream_registry.
# Results are published to stream_registry via session_id for SSE reconnection.
# Claude only receives the FINAL result after generation is done.
logger.info(
f"[SDK] Executing {tool_name} synchronously and blocking until completion "
f"(operation_id={operation_id}, task_id={task_id})"
f"(operation_id={operation_id}, session_id={session_id})"
)
# Execute synchronously - this handles both sync and async (202) tool responses
@@ -255,7 +243,6 @@ def _build_long_running_callback(
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
@@ -762,8 +749,7 @@ async def stream_chat_completion_sdk(
session_id,
)
logger.info(
"[SDK] [%s] Sending query — resume=%s, "
"total_msgs=%d, query_len=%d",
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
@@ -812,8 +798,7 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally "
"(StopAsyncIteration)",
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
)
break
@@ -958,7 +943,7 @@ async def stream_chat_completion_sdk(
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
"[SDK] [%s] Incremental save failed: %s",
session_id[:12],
save_err,
)
@@ -983,7 +968,7 @@ async def stream_chat_completion_sdk(
session = await upsert_chat_session(session)
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
"[SDK] [%s] Incremental save failed: %s",
session_id[:12],
save_err,
)
@@ -996,8 +981,7 @@ async def stream_chat_completion_sdk(
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled "
"(asyncio.CancelledError)",
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
)
raise
@@ -1077,7 +1061,7 @@ async def stream_chat_completion_sdk(
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)

View File

@@ -43,7 +43,8 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete

View File

@@ -361,7 +361,6 @@ async def stream_chat_completion(
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
@@ -543,7 +542,7 @@ async def stream_chat_completion(
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
yield StreamStart(messageId=message_id, taskId=session.session_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
@@ -793,7 +792,6 @@ async def stream_chat_completion(
session=session,
context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
@@ -868,7 +866,6 @@ async def stream_chat_completion(
context=context,
tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
):
yield chunk
@@ -1450,19 +1447,8 @@ async def _yield_tool_call(
)
return
# Generate operation ID and task ID
# Generate operation ID for tracking
operation_id = str(uuid_module.uuid4())
task_id = str(uuid_module.uuid4())
# Create task in stream registry for SSE reconnection support
await stream_registry.create_task(
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Save tool_call to session before execution
async def _save_tool_call() -> None:
@@ -1472,7 +1458,7 @@ async def _yield_tool_call(
await _with_optional_lock(session_lock, _save_tool_call)
logger.info(
f"Starting synchronous execution of {tool_name} "
f"(operation_id={operation_id}, task_id={task_id})"
f"(operation_id={operation_id}, session_id={session.session_id})"
)
# Execute tool SYNCHRONOUSLY - blocks until complete
@@ -1485,7 +1471,6 @@ async def _yield_tool_call(
parameters=arguments,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
)
@@ -1508,7 +1493,7 @@ async def _yield_tool_call(
logger.error(f"Long-running tool {tool_name} failed: {e}", exc_info=True)
# Mark task as failed
await stream_registry.mark_task_completed(
task_id,
session.session_id,
status="failed",
error_message=str(e),
)
@@ -1535,7 +1520,9 @@ async def _yield_tool_call(
)
# Mark task as completed and clean up
await stream_registry.mark_task_completed(task_id, status="completed")
await stream_registry.mark_task_completed(
session.session_id, status="completed"
)
await _mark_operation_completed(tool_call_id)
yield StreamToolOutputAvailable(
@@ -1603,7 +1590,6 @@ async def _execute_long_running_tool_with_streaming(
parameters: dict[str, Any],
tool_call_id: str,
operation_id: str,
task_id: str,
session_id: str,
user_id: str | None,
) -> str | None:
@@ -1613,7 +1599,7 @@ async def _execute_long_running_tool_with_streaming(
polls Redis until the completion consumer processes the result, then returns it.
Progress is published to stream_registry for SSE reconnection - clients can
reconnect via GET /chat/tasks/{task_id}/stream if they disconnect.
reconnect via GET /chat/sessions/{session_id}/stream if they disconnect.
Returns:
The tool result as a JSON string, or None on error.
@@ -1624,14 +1610,14 @@ async def _execute_long_running_tool_with_streaming(
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(
task_id,
session_id,
status="failed",
error_message=f"Session {session_id} not found",
)
return
# Execute the tool synchronously (do NOT pass operation_id/task_id to force sync mode)
# The operation_id/task_id are only for our internal task tracking, not for the tool
# Execute the tool synchronously (do NOT pass operation_id/session_id to force sync mode)
# The operation_id/session_id are only for our internal task tracking, not for the tool
result = await execute_tool(
tool_name=tool_name,
parameters=parameters, # No enrichment - forces synchronous execution
@@ -1642,7 +1628,7 @@ async def _execute_long_running_tool_with_streaming(
# Tool executed synchronously (no async/webhook mode)
# Publish tool result to stream registry
await stream_registry.publish_chunk(task_id, result)
await stream_registry.publish_chunk(session_id, result)
# Serialize result
result_str = (
@@ -1653,11 +1639,10 @@ async def _execute_long_running_tool_with_streaming(
logger.info(
f"Tool {tool_name} completed synchronously for session {session_id} "
f"(task_id={task_id})"
f"(session_id={session_id})"
)
# Mark task as completed and clean up
await stream_registry.mark_task_completed(task_id, status="completed")
# Mark operation as completed (but don't complete the main task - that's done by processor)
await _mark_operation_completed(tool_call_id)
# Return the result to Claude
@@ -1671,11 +1656,11 @@ async def _execute_long_running_tool_with_streaming(
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
session_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.publish_chunk(session_id, StreamFinishStep())
await stream_registry.publish_chunk(session_id, StreamFinish())
await _update_pending_operation(
session_id=session_id,
@@ -1684,7 +1669,7 @@ async def _execute_long_running_tool_with_streaming(
)
# Mark task as failed in stream registry
await stream_registry.mark_task_completed(task_id, status="failed")
await stream_registry.mark_task_completed(session_id, status="failed")
finally:
# Clean up operation lock
await _mark_operation_completed(tool_call_id)
@@ -1982,7 +1967,6 @@ def _sanitize_error_body(
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,
task_id: str,
) -> None:
"""Generate an LLM response with streaming to the stream registry.
@@ -2040,9 +2024,13 @@ async def _generate_llm_continuation_with_streaming(
text_block_id = str(uuid_module.uuid4())
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
await stream_registry.publish_chunk(
session_id, StreamStart(messageId=message_id)
)
await stream_registry.publish_chunk(session_id, StreamStartStep())
await stream_registry.publish_chunk(
session_id, StreamTextStart(id=text_block_id)
)
# Stream the response
stream = await client.chat.completions.create(
@@ -2059,13 +2047,13 @@ async def _generate_llm_continuation_with_streaming(
assistant_content += delta
# Publish delta to stream registry
await stream_registry.publish_chunk(
task_id,
session_id,
StreamTextDelta(id=text_block_id, delta=delta),
)
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(session_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(session_id, StreamFinishStep())
if assistant_content:
# Reload session from DB to avoid race condition with user messages
@@ -2091,7 +2079,7 @@ async def _generate_llm_continuation_with_streaming(
logger.info(
f"Generated streaming LLM continuation for session {session_id} "
f"(task_id={task_id}), response length: {len(assistant_content)}"
f"(session_id={session_id}), response length: {len(assistant_content)}"
)
else:
logger.warning(
@@ -2104,8 +2092,8 @@ async def _generate_llm_continuation_with_streaming(
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
session_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.publish_chunk(session_id, StreamFinishStep())
await stream_registry.publish_chunk(session_id, StreamFinish())

View File

@@ -34,7 +34,7 @@ config = ChatConfig()
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
# Maps queue id() to (session_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
@@ -57,7 +57,7 @@ return 0
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
task_id: str # For backwards compatibility, equals session_id
session_id: str
user_id: str | None
tool_call_id: str
@@ -68,14 +68,14 @@ class ActiveTask:
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_meta_key(session_id: str) -> str:
"""Get Redis key for task metadata (now keyed by session_id)."""
return f"{config.task_meta_prefix}{session_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_task_stream_key(session_id: str) -> str:
"""Get Redis key for task message stream (now keyed by session_id)."""
return f"{config.task_stream_prefix}{session_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
@@ -84,18 +84,16 @@ def _get_operation_mapping_key(operation_id: str) -> str:
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
"""Create a new streaming task in Redis (keyed by session_id).
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
session_id: Chat session ID (used as task identifier)
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
@@ -111,19 +109,19 @@ async def create_task(
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
f"[TIMING] create_task STARTED, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
# Use session_id as task_id (no separate task_id needed)
task = ActiveTask(
task_id=task_id,
task_id=session_id, # task_id = session_id
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
@@ -140,14 +138,13 @@ async def create_task(
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id)
meta_key = _get_task_meta_key(session_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
@@ -165,12 +162,12 @@ async def create_task(
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
# Create operation_id -> session_id mapping for webhook lookups
await redis.set(op_key, session_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
@@ -178,7 +175,7 @@ async def create_task(
async def publish_chunk(
task_id: str,
session_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
@@ -186,7 +183,7 @@ async def publish_chunk(
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
session_id: Session ID to publish to
chunk: The stream response chunk to publish
Returns:
@@ -202,13 +199,13 @@ async def publish_chunk(
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
stream_key = _get_task_stream_key(session_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
@@ -260,7 +257,7 @@ async def publish_chunk(
async def subscribe_to_task(
task_id: str,
session_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
@@ -269,7 +266,7 @@ async def subscribe_to_task(
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
session_id: Session ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
@@ -282,18 +279,18 @@ async def subscribe_to_task(
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
log_meta = {"component": "StreamRegistry", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
f"[TIMING] subscribe_to_task STARTED, task={session_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta_key = _get_task_meta_key(session_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
@@ -336,7 +333,7 @@ async def subscribe_to_task(
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
stream_key = _get_task_stream_key(session_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
@@ -388,10 +385,10 @@ async def subscribe_to_task(
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
_stream_listener(session_id, subscriber_queue, replay_last_id, log_meta)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
_listener_tasks[id(subscriber_queue)] = (session_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
@@ -402,7 +399,7 @@ async def subscribe_to_task(
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={session_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
@@ -416,7 +413,7 @@ async def subscribe_to_task(
async def _stream_listener(
task_id: str,
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
@@ -427,7 +424,7 @@ async def _stream_listener(
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
session_id: Session ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
@@ -438,10 +435,10 @@ async def _stream_listener(
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
log_meta = {"component": "StreamRegistry", "session_id": session_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
f"[TIMING] _stream_listener STARTED, task={session_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
@@ -454,7 +451,7 @@ async def _stream_listener(
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
stream_key = _get_task_stream_key(session_id)
current_id = last_replayed_id
while True:
@@ -496,7 +493,7 @@ async def _stream_listener(
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
meta_key = _get_task_meta_key(session_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
@@ -506,7 +503,7 @@ async def _stream_listener(
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
f"Timeout delivering finish event for task {session_id}"
)
break
continue
@@ -568,7 +565,7 @@ async def _stream_listener(
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
f"Cannot deliver overflow error for task {session_id}, "
"queue completely blocked"
)
@@ -627,7 +624,7 @@ async def _stream_listener(
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time / 1000:.1f}s; task={task_id}, "
f"[TIMING] _stream_listener FINISHED in {total_time / 1000:.1f}s; task={session_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
@@ -642,7 +639,7 @@ async def _stream_listener(
async def mark_task_completed(
task_id: str,
session_id: str,
status: Literal["completed", "failed"] = "completed",
*,
error_message: str | None = None,
@@ -654,7 +651,7 @@ async def mark_task_completed(
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
session_id: Session ID to mark as completed
status: Final status ("completed" or "failed")
error_message: If provided and status="failed", publish a StreamError
before StreamFinish so connected clients see why the task ended.
@@ -665,14 +662,14 @@ async def mark_task_completed(
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta_key = _get_task_meta_key(session_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
logger.debug(f"Task {session_id} already completed/failed, skipping")
return False
# Publish error event before finish so connected clients know WHY the
@@ -682,21 +679,21 @@ async def mark_task_completed(
# listeners clean up.
if status == "failed" and error_message:
try:
await publish_chunk(task_id, StreamError(errorText=error_message))
await publish_chunk(session_id, StreamError(errorText=error_message))
except Exception as e:
logger.warning(f"Failed to publish error event for task {task_id}: {e}")
logger.warning(f"Failed to publish error event for task {session_id}: {e}")
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
await publish_chunk(session_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
f"Failed to publish finish event for task {session_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
_local_tasks.pop(session_id, None)
return True
@@ -713,26 +710,28 @@ async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
session_id = await redis.get(op_key)
if not task_id:
if not session_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
session_id_str = (
session_id.decode() if isinstance(session_id, bytes) else session_id
)
return await get_task(session_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
async def get_task(session_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
session_id: Session ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta_key = _get_task_meta_key(session_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
@@ -751,7 +750,7 @@ async def get_task(task_id: str) -> ActiveTask | None:
async def get_task_with_expiry_info(
task_id: str,
session_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
@@ -761,14 +760,14 @@ async def get_task_with_expiry_info(
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
session_id: Session ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta_key = _get_task_meta_key(session_id)
stream_key = _get_task_stream_key(session_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
@@ -831,7 +830,6 @@ async def get_active_task_for_session(
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
@@ -849,11 +847,11 @@ async def get_active_task_for_session(
).total_seconds()
if age_seconds > 600: # 10 minutes
logger.warning(
f"[STALE_TASK] Auto-completing stale task {task_id[:8]}... "
f"[STALE_TASK] Auto-completing stale session {session_id[:8]}... "
f"(running for {age_seconds:.0f}s)"
)
await mark_task_completed(
task_id,
session_id,
status="failed",
error_message=f"Task timed out after {age_seconds:.0f}s",
)
@@ -861,12 +859,10 @@ async def get_active_task_for_session(
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse created_at: {e}")
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
logger.info(f"[TASK_LOOKUP] Found running session {session_id[:8]}...")
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
stream_key = _get_task_stream_key(session_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
@@ -878,8 +874,8 @@ async def get_active_task_for_session(
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
task_id=session_id,
session_id=session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
@@ -952,20 +948,20 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
async def set_task_asyncio_task(session_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
session_id: Session ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
_local_tasks[session_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
@@ -974,7 +970,7 @@ async def unsubscribe_from_task(
to prevent resource leaks.
Args:
task_id: Task ID
session_id: Session ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
@@ -982,21 +978,21 @@ async def unsubscribe_from_task(
if listener_entry is None:
logger.debug(
f"No listener task found for task {task_id} queue {queue_id} "
f"No listener task found for task {session_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_task_id, listener_task = listener_entry
stored_session_id, listener_task = listener_entry
if stored_task_id != task_id:
if stored_session_id != session_id:
logger.warning(
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
f"Session ID mismatch in unsubscribe: expected {session_id}, "
f"found {stored_session_id}"
)
if listener_task.done():
logger.debug(f"Listener task for task {task_id} already completed")
logger.debug(f"Listener task for task {session_id} already completed")
return
# Cancel the listener task
@@ -1010,9 +1006,11 @@ async def unsubscribe_from_task(
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for task {task_id}"
f"Timeout waiting for listener task cancellation for task {session_id}"
)
except Exception as e:
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.error(
f"Error during listener task cancellation for task {session_id}: {e}"
)
logger.debug(f"Successfully unsubscribed from task {task_id}")
logger.debug(f"Successfully unsubscribed from task {session_id}")

View File

@@ -10,7 +10,6 @@ from .agent_generator import (
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)

View File

@@ -9,7 +9,6 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)

View File

@@ -366,12 +366,15 @@ class TestFindBlockFiltering:
return_value=(search_results, len(search_results))
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
), patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
),
):
tool = FindBlockTool()
response = await tool._execute(

View File

@@ -160,9 +160,10 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:

View File

@@ -431,7 +431,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -570,7 +570,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -581,7 +581,6 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -733,7 +732,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
)
balance, _ = await self._add_transaction(
@@ -787,12 +786,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.

View File

@@ -507,7 +507,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
@@ -546,7 +546,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
@@ -707,7 +707,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level

View File

@@ -581,7 +581,6 @@ class GraphModel(Graph, GraphMeta):
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))

View File

@@ -472,7 +472,6 @@ class UserMetadataRaw(TypedDict, total=False):
class UserIntegrations(BaseModel):
class ManagedCredentials(BaseModel):
"""Integration credentials managed by us, rather than by the user"""

View File

@@ -156,8 +156,7 @@ async def create_workspace_file(
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
f"Created workspace file {file.id} at path {path} in workspace {workspace_id}"
)
return WorkspaceFile.from_db(file)

View File

@@ -379,8 +379,9 @@ class TestLLMCall:
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
from backend.data.model import APIKeyCredentials
with patch("backend.blocks.llm.llm_call") as mock_llm_call, patch(
"backend.blocks.llm.secrets.token_hex", return_value="test123"
with (
patch("backend.blocks.llm.llm_call") as mock_llm_call,
patch("backend.blocks.llm.secrets.token_hex", return_value="test123"),
):
mock_llm_call.return_value = LLMResponse(
raw_response={},
@@ -442,8 +443,9 @@ class TestLLMCall:
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
from backend.data.model import APIKeyCredentials
with patch("backend.blocks.llm.llm_call") as mock_llm_call, patch(
"backend.blocks.llm.secrets.token_hex", return_value="test123"
with (
patch("backend.blocks.llm.llm_call") as mock_llm_call,
patch("backend.blocks.llm.secrets.token_hex", return_value="test123"),
):
# Return invalid JSON that will fail validation (missing required field)
mock_llm_call.return_value = LLMResponse(
@@ -515,17 +517,21 @@ class TestGenerateActivityStatusForExecution:
mock_graph.links = []
mock_db_client.get_graph.return_value = mock_graph
with patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block, patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings, patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block, patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
with (
patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block,
patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings,
patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block,
patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
),
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
@@ -533,10 +539,13 @@ class TestGenerateActivityStatusForExecution:
mock_instance = mock_structured_block.return_value
async def mock_run(*args, **kwargs):
yield "response", {
"activity_status": "I analyzed your data and provided the requested insights.",
"correctness_score": 0.85,
}
yield (
"response",
{
"activity_status": "I analyzed your data and provided the requested insights.",
"correctness_score": 0.85,
},
)
mock_instance.run = mock_run
@@ -586,11 +595,14 @@ class TestGenerateActivityStatusForExecution:
"""Test activity status generation with no API key."""
mock_db_client = AsyncMock()
with patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings, patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
with (
patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings,
patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
),
):
mock_settings.return_value.secrets.openai_internal_api_key = ""
@@ -612,11 +624,14 @@ class TestGenerateActivityStatusForExecution:
mock_db_client = AsyncMock()
mock_db_client.get_node_executions.side_effect = Exception("Database error")
with patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings, patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
with (
patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings,
patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
),
):
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
@@ -641,17 +656,21 @@ class TestGenerateActivityStatusForExecution:
mock_db_client.get_graph_metadata.return_value = None # No metadata
mock_db_client.get_graph.return_value = None # No graph
with patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block, patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings, patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block, patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
with (
patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block,
patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings,
patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block,
patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
),
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
@@ -659,10 +678,13 @@ class TestGenerateActivityStatusForExecution:
mock_instance = mock_structured_block.return_value
async def mock_run(*args, **kwargs):
yield "response", {
"activity_status": "Agent completed execution.",
"correctness_score": 0.8,
}
yield (
"response",
{
"activity_status": "Agent completed execution.",
"correctness_score": 0.8,
},
)
mock_instance.run = mock_run
@@ -704,17 +726,21 @@ class TestIntegration:
expected_activity = "I processed user input but failed during final output generation due to system error."
with patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block, patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings, patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block, patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
with (
patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block,
patch(
"backend.executor.activity_status_generator.Settings"
) as mock_settings,
patch(
"backend.executor.activity_status_generator.AIStructuredResponseGeneratorBlock"
) as mock_structured_block,
patch(
"backend.executor.activity_status_generator.is_feature_enabled",
return_value=True,
),
):
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
@@ -722,10 +748,13 @@ class TestIntegration:
mock_instance = mock_structured_block.return_value
async def mock_run(*args, **kwargs):
yield "response", {
"activity_status": expected_activity,
"correctness_score": 0.3, # Low score since there was a failure
}
yield (
"response",
{
"activity_status": expected_activity,
"correctness_score": 0.3, # Low score since there was a failure
},
)
mock_instance.run = mock_run

View File

@@ -20,7 +20,6 @@ logger = logging.getLogger(__name__)
class AutoModManager:
def __init__(self):
self.config = self._load_config()

View File

@@ -35,16 +35,14 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
amount=-714, # Attempting to spend $7.14
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
@@ -109,16 +107,14 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
@@ -166,14 +162,14 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.executor.manager.queue_notification"),
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
patch("backend.executor.manager.redis") as mock_redis_module,
):
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
@@ -228,7 +224,6 @@ async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
@@ -264,7 +259,6 @@ async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestSe
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
@@ -291,7 +285,6 @@ async def test_clear_insufficient_funds_notifications_handles_redis_error(
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
side_effect=Exception("Redis connection failed")
@@ -320,16 +313,14 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
patch("backend.executor.manager.redis") as mock_redis_module,
):
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
@@ -369,10 +360,10 @@ async def test_add_transaction_clears_notifications_on_grant(server: SpinTestSer
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.data.credit.query_raw_with_schema") as mock_query,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
@@ -411,10 +402,10 @@ async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestSe
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.data.credit.query_raw_with_schema") as mock_query,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
@@ -449,10 +440,10 @@ async def test_add_transaction_skips_clearing_for_inactive_transaction(
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.data.credit.query_raw_with_schema") as mock_query,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
@@ -485,10 +476,10 @@ async def test_add_transaction_skips_clearing_for_usage_transaction(
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
with (
patch("backend.data.credit.query_raw_with_schema") as mock_query,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
@@ -519,10 +510,11 @@ async def test_enable_transaction_clears_notifications(server: SpinTestServer):
user_id = "test-user-enable"
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
with (
patch("backend.data.credit.CreditTransaction") as mock_credit_tx,
patch("backend.data.credit.query_raw_with_schema") as mock_query,
patch("backend.executor.manager.redis") as mock_redis_module,
):
# Mock finding the pending transaction
mock_transaction = MagicMock()
mock_transaction.amount = 1000

View File

@@ -18,14 +18,13 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
transaction_cost = 600 # $6 transaction
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
):
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
@@ -77,14 +76,13 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
)
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
):
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
@@ -120,14 +118,13 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
)
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings:
with (
patch("backend.executor.manager.queue_notification") as mock_queue_notif,
patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client,
patch("backend.executor.manager.settings") as mock_settings,
):
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client

View File

@@ -301,7 +301,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
assert len(graph_exec.node_executions) == 8
# The last 3 executions will be a+b=4+5=9
for i, exec_data in enumerate(graph_exec.node_executions[-3:]):
logger.info(f"Checking execution {i+1} of last 3: {exec_data}")
logger.info(f"Checking execution {i + 1} of last 3: {exec_data}")
assert exec_data.status == execution.ExecutionStatus.COMPLETED
assert exec_data.output_data == {"result": [9]}
logger.info("Completed test_static_input_link_on_graph")

View File

@@ -348,8 +348,7 @@ async def _validate_node_input_credentials(
and node.id not in credential_errors
):
logger.info(
f"Node #{node.id}: optional credentials not configured, "
"running without"
f"Node #{node.id}: optional credentials not configured, running without"
)
return credential_errors, nodes_to_skip
@@ -411,9 +410,10 @@ async def validate_graph_with_credentials(
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
)
(
node_credential_input_errors,
nodes_to_skip,
) = await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Merge credential errors with structural errors
for node_id, field_errors in node_credential_input_errors.items():
@@ -561,13 +561,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
(
starting_nodes_input,
nodes_to_skip,
) = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
@@ -858,16 +859,19 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=inputs or {},
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
is_sub_graph=parent_exec_id is not None,
)
(
graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
) = await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=inputs or {},
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
is_sub_graph=parent_exec_id is not None,
)
graph_exec = await edb.create_graph_execution(

View File

@@ -146,8 +146,7 @@ class IntegrationCredentialsManager:
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' "
f"credentials #{credentials.id}"
f"Refreshing '{credentials.provider}' credentials #{credentials.id}"
)
_lock = None
if lock:

View File

@@ -77,18 +77,23 @@ class TestNotificationErrorHandling:
self, notification_manager, sample_batch_event
):
"""Test that 406 inactive recipient error stops ALL processing for that user."""
with patch("backend.notifications.notifications.logger"), patch(
"backend.notifications.notifications.set_user_email_verification",
new_callable=AsyncMock,
) as mock_set_verification, patch(
"backend.notifications.notifications.disable_all_user_notifications",
new_callable=AsyncMock,
) as mock_disable_all, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
with (
patch("backend.notifications.notifications.logger"),
patch(
"backend.notifications.notifications.set_user_email_verification",
new_callable=AsyncMock,
) as mock_set_verification,
patch(
"backend.notifications.notifications.disable_all_user_notifications",
new_callable=AsyncMock,
) as mock_disable_all,
patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client,
patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link,
):
# Create batch of 5 notifications
notifications = []
for i in range(5):
@@ -169,12 +174,15 @@ class TestNotificationErrorHandling:
self, notification_manager, sample_batch_event
):
"""Test that 422 error permanently removes the malformed notification from batch and continues with others."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
with (
patch("backend.notifications.notifications.logger") as mock_logger,
patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client,
patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link,
):
# Create batch of 5 notifications
notifications = []
for i in range(5):
@@ -272,12 +280,15 @@ class TestNotificationErrorHandling:
self, notification_manager, sample_batch_event
):
"""Test that oversized notifications are permanently removed from batch but others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
with (
patch("backend.notifications.notifications.logger") as mock_logger,
patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client,
patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link,
):
# Create batch of 5 notifications
notifications = []
for i in range(5):
@@ -382,12 +393,15 @@ class TestNotificationErrorHandling:
self, notification_manager, sample_batch_event
):
"""Test that generic API errors keep notifications in batch for retry while others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
with (
patch("backend.notifications.notifications.logger") as mock_logger,
patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client,
patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link,
):
# Create batch of 5 notifications
notifications = []
for i in range(5):
@@ -499,12 +513,15 @@ class TestNotificationErrorHandling:
self, notification_manager, sample_batch_event
):
"""Test successful batch processing where all notifications are sent without errors."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
with (
patch("backend.notifications.notifications.logger") as mock_logger,
patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client,
patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link,
):
# Create batch of 5 notifications
notifications = []
for i in range(5):

View File

@@ -6,7 +6,7 @@ Usage: from backend.sdk import *
This module provides:
- All block base classes and types
- All credential and authentication components
- All credential and authentication components
- All cost tracking components
- All webhook components
- All utility functions

View File

@@ -1,7 +1,7 @@
"""
Integration between SDK provider costs and the execution cost system.
This module provides the glue between provider-defined base costs and the
This module provides the glue between provider-defined base costs and the
BLOCK_COSTS configuration used by the execution system.
"""

View File

@@ -91,7 +91,6 @@ class AutoRegistry:
not hasattr(provider.webhook_manager, "PROVIDER_NAME")
or provider.webhook_manager.PROVIDER_NAME is None
):
# This works because ProviderName has _missing_ method
provider.webhook_manager.PROVIDER_NAME = ProviderName(provider.name)
cls._webhook_managers[provider.name] = provider.webhook_manager

View File

@@ -3,7 +3,7 @@ Utilities for handling dynamic field names and delimiters in the AutoGPT Platfor
Dynamic fields allow graphs to connect complex data structures using special delimiters:
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
- _$_ for list indices (e.g., "items_$_0" → items[0])
- _$_ for list indices (e.g., "items_$_0" → items[0])
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
This module provides utilities for:

View File

@@ -33,14 +33,11 @@ class TestFileCloudIntegration:
cloud_path = "gcs://test-bucket/uploads/456/source.txt"
cloud_content = b"cloud file content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
with (
patch("backend.util.file.get_cloud_storage_handler") as mock_handler_getter,
patch("backend.util.file.scan_content_safe") as mock_scan,
patch("backend.util.file.Path") as mock_path_class,
):
# Mock cloud storage handler
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = True
@@ -110,18 +107,13 @@ class TestFileCloudIntegration:
cloud_path = "gcs://test-bucket/uploads/456/image.png"
cloud_content = b"\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR" # PNG header
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.get_mime_type"
) as mock_mime, patch(
"backend.util.file.base64.b64encode"
) as mock_b64, patch(
"backend.util.file.Path"
) as mock_path_class:
with (
patch("backend.util.file.get_cloud_storage_handler") as mock_handler_getter,
patch("backend.util.file.scan_content_safe") as mock_scan,
patch("backend.util.file.get_mime_type") as mock_mime,
patch("backend.util.file.base64.b64encode") as mock_b64,
patch("backend.util.file.Path") as mock_path_class,
):
# Mock cloud storage handler
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = True
@@ -169,18 +161,13 @@ class TestFileCloudIntegration:
graph_exec_id = "test-exec-123"
data_uri = "data:text/plain;base64,SGVsbG8gd29ybGQ="
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.base64.b64decode"
) as mock_b64decode, patch(
"backend.util.file.uuid.uuid4"
) as mock_uuid, patch(
"backend.util.file.Path"
) as mock_path_class:
with (
patch("backend.util.file.get_cloud_storage_handler") as mock_handler_getter,
patch("backend.util.file.scan_content_safe") as mock_scan,
patch("backend.util.file.base64.b64decode") as mock_b64decode,
patch("backend.util.file.uuid.uuid4") as mock_uuid,
patch("backend.util.file.Path") as mock_path_class,
):
# Mock cloud storage handler
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
@@ -230,7 +217,6 @@ class TestFileCloudIntegration:
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter:
# Mock cloud storage handler to raise error
mock_handler = AsyncMock()
mock_handler.is_cloud_path.return_value = True
@@ -255,14 +241,11 @@ class TestFileCloudIntegration:
local_file = "test_video.mp4"
file_content = b"fake video content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
with (
patch("backend.util.file.get_cloud_storage_handler") as mock_handler_getter,
patch("backend.util.file.scan_content_safe") as mock_scan,
patch("backend.util.file.Path") as mock_path_class,
):
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
@@ -307,14 +290,11 @@ class TestFileCloudIntegration:
local_file = "infected.exe"
file_content = b"malicious content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
with (
patch("backend.util.file.get_cloud_storage_handler") as mock_handler_getter,
patch("backend.util.file.scan_content_safe") as mock_scan,
patch("backend.util.file.Path") as mock_path_class,
):
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False

View File

@@ -500,7 +500,6 @@ class Requests:
json=json,
**kwargs,
) as response:
if self.raise_for_status:
try:
response.raise_for_status()

View File

@@ -563,7 +563,6 @@ def get_service_client(
self._connection_failure_count >= 3
and current_time - self._last_client_reset > 30
):
logger.warning(
f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
)

View File

@@ -155,7 +155,6 @@ class TestDynamicClientConnectionHealing:
self._connection_failure_count >= 3
and current_time - self._last_client_reset > 30
):
# Clear cached clients to force recreation on next access
if hasattr(self, "sync_client"):
delattr(self, "sync_client")

View File

@@ -222,9 +222,9 @@ class TestSafeJson:
problematic_data = {
"null_byte": "data with \x00 null",
"bell_char": "data with \x07 bell",
"form_feed": "data with \x0C feed",
"escape_char": "data with \x1B escape",
"delete_char": "data with \x7F delete",
"form_feed": "data with \x0c feed",
"escape_char": "data with \x1b escape",
"delete_char": "data with \x7f delete",
}
# SafeJson should successfully process data with control characters
@@ -235,9 +235,9 @@ class TestSafeJson:
result_data = result.data
assert "\x00" not in str(result_data) # null byte removed
assert "\x07" not in str(result_data) # bell removed
assert "\x0C" not in str(result_data) # form feed removed
assert "\x1B" not in str(result_data) # escape removed
assert "\x7F" not in str(result_data) # delete removed
assert "\x0c" not in str(result_data) # form feed removed
assert "\x1b" not in str(result_data) # escape removed
assert "\x7f" not in str(result_data) # delete removed
# Test that safe whitespace characters are preserved
safe_data = {
@@ -263,7 +263,7 @@ class TestSafeJson:
def test_web_scraping_content_sanitization(self):
"""Test sanitization of typical web scraping content with null characters."""
# Simulate web content that might contain null bytes from SearchTheWebBlock
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0CForm feed content\x1FUnit separator\x7FDelete char"
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0cForm feed content\x1fUnit separator\x7fDelete char"
result = SafeJson(web_content)
assert isinstance(result, Json)
@@ -273,9 +273,9 @@ class TestSafeJson:
assert "\x00" not in sanitized_content
assert "\x01" not in sanitized_content
assert "\x08" not in sanitized_content
assert "\x0C" not in sanitized_content
assert "\x1F" not in sanitized_content
assert "\x7F" not in sanitized_content
assert "\x0c" not in sanitized_content
assert "\x1f" not in sanitized_content
assert "\x7f" not in sanitized_content
# Verify the content is still readable
assert "Article title" in sanitized_content
@@ -391,7 +391,7 @@ class TestSafeJson:
mixed_content = {
"safe_and_unsafe": "Good text\twith tab\x00NULL BYTE\nand newline\x08BACKSPACE",
"file_path_with_null": "C:\\temp\\file\x00.txt",
"json_with_controls": '{"text": "data\x01\x0C\x1F"}',
"json_with_controls": '{"text": "data\x01\x0c\x1f"}',
}
result = SafeJson(mixed_content)
@@ -419,13 +419,13 @@ class TestSafeJson:
# Create data with various problematic escape sequences that could cause JSON parsing errors
problematic_output_data = {
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
"web_content": "Article text\x00with null\x01and control\x08chars\x0c\x1f\x7f",
"file_path": "C:\\Users\\test\\file\x00.txt",
"json_like_string": '{"text": "data\x00\x08\x1F"}',
"json_like_string": '{"text": "data\x00\x08\x1f"}',
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1fmixed",
"large_text": "A" * 35000
+ "\x00\x08\x1F"
+ "\x00\x08\x1f"
+ "B" * 5000, # Large text like in the error
}
@@ -446,9 +446,9 @@ class TestSafeJson:
assert "\x00" not in str(web_content)
assert "\x01" not in str(web_content)
assert "\x08" not in str(web_content)
assert "\x0C" not in str(web_content)
assert "\x1F" not in str(web_content)
assert "\x7F" not in str(web_content)
assert "\x0c" not in str(web_content)
assert "\x1f" not in str(web_content)
assert "\x7f" not in str(web_content)
# Check that legitimate content is preserved
assert "Article text" in str(web_content)
@@ -467,7 +467,7 @@ class TestSafeJson:
assert "B" * 1000 in str(large_text) # B's preserved
assert "\x00" not in str(large_text) # Control chars removed
assert "\x08" not in str(large_text)
assert "\x1F" not in str(large_text)
assert "\x1f" not in str(large_text)
# Most importantly: ensure the result can be JSON-serialized without errors
# This would have failed with the old approach
@@ -602,7 +602,7 @@ class TestSafeJson:
model = SamplePydanticModel(
name="Test\x00User", # Has null byte
age=30,
metadata={"info": "data\x08with\x0Ccontrols"},
metadata={"info": "data\x08with\x0ccontrols"},
)
data = {"credential": model}
@@ -616,7 +616,7 @@ class TestSafeJson:
json_string = json.dumps(result.data)
assert "\x00" not in json_string
assert "\x08" not in json_string
assert "\x0C" not in json_string
assert "\x0c" not in json_string
assert "TestUser" in json_string # Name preserved minus null byte
def test_deeply_nested_pydantic_models_control_char_sanitization(self):
@@ -639,16 +639,16 @@ class TestSafeJson:
# Create test data with control characters at every nesting level
inner = InnerModel(
deep_string="Deepest\x00Level\x08Control\x0CChars", # Multiple control chars at deepest level
deep_string="Deepest\x00Level\x08Control\x0cChars", # Multiple control chars at deepest level
metadata={
"nested_key": "Nested\x1FValue\x7FDelete"
"nested_key": "Nested\x1fValue\x7fDelete"
}, # Control chars in nested dict
)
middle = MiddleModel(
middle_string="Middle\x01StartOfHeading\x1FUnitSeparator",
middle_string="Middle\x01StartOfHeading\x1fUnitSeparator",
inner=inner,
data="Some\x0BVerticalTab\x0EShiftOut",
data="Some\x0bVerticalTab\x0eShiftOut",
)
outer = OuterModel(outer_string="Outer\x00Null\x07Bell", middle=middle)
@@ -659,7 +659,7 @@ class TestSafeJson:
"nested_model": outer,
"list_with_strings": [
"List\x00Item1",
"List\x0CItem2\x1F",
"List\x0cItem2\x1f",
{"dict_in_list": "Dict\x08Value"},
],
}
@@ -684,10 +684,10 @@ class TestSafeJson:
"\x06",
"\x07",
"\x08",
"\x0B",
"\x0C",
"\x0E",
"\x0F",
"\x0b",
"\x0c",
"\x0e",
"\x0f",
"\x10",
"\x11",
"\x12",
@@ -698,13 +698,13 @@ class TestSafeJson:
"\x17",
"\x18",
"\x19",
"\x1A",
"\x1B",
"\x1C",
"\x1D",
"\x1E",
"\x1F",
"\x7F",
"\x1a",
"\x1b",
"\x1c",
"\x1d",
"\x1e",
"\x1f",
"\x7f",
]
for char in control_chars:

View File

@@ -10,9 +10,8 @@ import {
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { toast } from "@/components/molecules/Toast/use-toast";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef, useState } from "react";
import { useEffect, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import {
@@ -129,7 +128,6 @@ export const ChatMessagesContainer = ({
headerSlot,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
const lastToastTimeRef = useRef(0);
useEffect(() => {
if (status === "submitted") {

View File

@@ -152,7 +152,7 @@ export function CreateAgentTool({ part }: Props) {
return (
<div className="py-2">
{!hasExpandableContent && (
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation

View File

@@ -4,7 +4,6 @@ import { WarningDiamondIcon } from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import {
ContentCardDescription,
ContentCodeBlock,
@@ -57,7 +56,7 @@ function getAccordionMeta(output: EditAgentToolOutput | null): {
if (!output) {
return {
icon: <OrbitLoader size={32} />,
icon,
title: "Editing agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
@@ -83,7 +82,7 @@ function getAccordionMeta(output: EditAgentToolOutput | null): {
}
if (isOperationInProgressOutput(output)) {
return {
icon: <OrbitLoader size={32} />,
icon,
title: "Editing agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
@@ -133,7 +132,7 @@ export function EditAgentTool({ part }: Props) {
return (
<div className="py-2">
{!hasExpandableContent && (
{isOperating && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation