Compare commits

..

5 Commits

Author SHA1 Message Date
dependabot[bot]
179df7a726 chore(backend/deps): bump e2b-code-interpreter
Bumps [e2b-code-interpreter](https://github.com/e2b-dev/code-interpreter) from 1.5.2 to 2.4.1.
- [Release notes](https://github.com/e2b-dev/code-interpreter/releases)
- [Commits](https://github.com/e2b-dev/code-interpreter/compare/@e2b/code-interpreter-python@1.5.2...@e2b/code-interpreter-python@2.4.1)

---
updated-dependencies:
- dependency-name: e2b-code-interpreter
  dependency-version: 2.4.1
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-09 21:12:23 +00:00
Reinier van der Leer
6467f6734f debug(backend/chat): Add timing logging to chat stream generation mechanism (#12019)
[SECRT-1912: Investigate & eliminate chat session start
latency](https://linear.app/autogpt/issue/SECRT-1912)

### Changes 🏗️

- Add timing logs to `backend.api.features.chat` in `routes.py`,
`service.py`, and `stream_registry.py`
- Remove unneeded DB join in `create_chat_session`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI checks
2026-02-09 14:05:29 +00:00
Otto
5a30d11416 refactor(copilot): Code cleanup and deduplication (#11950)
## Summary

Code cleanup of the AI Copilot codebase - rebased onto latest dev.

## Changes

### New Files
- `backend/util/validation.py` - UUID validation helpers
- `backend/api/features/chat/tools/helpers.py` - Shared tool utilities

### Credential Matching Consolidation  
- Added shared utilities to `utils.py`
- Refactored `run_block._check_block_credentials()` with discriminator
support
- Extracted `_resolve_discriminated_credentials()` for multi-provider
handling

### Routes Cleanup
- Extracted `_create_stream_generator()` and `SSE_RESPONSE_HEADERS`

### Tool Files Cleanup
- Updated `run_agent.py` and `run_block.py` to use shared helpers

**WIP** - This PR will be updated incrementally.
2026-02-09 13:43:55 +00:00
Bently
1f4105e8f9 fix(frontend): Handle object values in FileInput component (#11948)
Fixes
[#11800](https://github.com/Significant-Gravitas/AutoGPT/issues/11800)

## Problem
The FileInput component crashed with `TypeError: e.startsWith is not a
function` when the value was an object (from external API) instead of a
string.

## Example Input Object
When using the external API
(`/external-api/v1/graphs/{id}/execute/{version}`), file inputs can be
passed as objects:

```json
{
  "node_input": {
    "input_image": {
      "name": "image.jpeg",
      "type": "image/jpeg",
      "size": 131147,
      "data": "/9j/4QAW..."
    }
  }
}
```

## Changes
- Updated `getFileLabelFromValue()` to handle object format: `{ name,
type, size, data }`
- Added type guards for string vs object values
- Graceful fallback for edge cases (null, undefined, empty object)

## Test cases verified
- Object with name: returns filename
- Object with type only: extracts and formats MIME type
- String data URI: parses correctly
- String file path: extracts extension
- Edge cases: returns "File" fallback
2026-02-09 10:25:08 +00:00
Bently
caf9ff34e6 fix(backend): Handle stale RabbitMQ channels on connection drop (#11929)
### Changes 🏗️

Fixes
[**AUTOGPT-SERVER-1TN**](https://autoagpt.sentry.io/issues/?query=AUTOGPT-SERVER-1TN)
(~39K events since Feb 2025) and related connection issues
**6JC/6JD/6JE/6JF** (~6K combined).

#### Problem

When the RabbitMQ TCP connection drops (network blip, server restart,
etc.):

1. `connect_robust` (aio_pika) automatically reconnects the underlying
AMQP connection
2. But `AsyncRabbitMQ._channel` still references the **old dead
channel**
3. `is_ready` checks `not self._channel.is_closed` — but the channel
object doesn't know the transport is gone
4. `publish_message` tries to use the stale channel →
`ChannelInvalidStateError: No active transport in channel`
5. `@func_retry` retries 5 times, but each retry hits the same stale
channel (it passes `is_ready`)

This means every connection drop generates errors until the process is
restarted.

#### Fix

**New `_ensure_channel()` helper** that resets stale channels before
reconnecting, so `connect()` creates a fresh one instead of
short-circuiting on `is_connected`.

**Explicit `ChannelInvalidStateError` handling in `publish_message`:**
1. First attempt uses `_ensure_channel()` (handles normal staleness)
2. If publish throws `ChannelInvalidStateError`, does a full reconnect
(resets both `_channel` and `_connection`) and retries once
3. `@func_retry` provides additional retry resilience on top

**Simplified `get_channel()`** to use the same resilient helper.

**1 file changed, 62 insertions, 24 deletions.**

#### Impact
- Eliminates ~39K `ChannelInvalidStateError` Sentry events
- RabbitMQ operations self-heal after connection drops without process
restart
- Related transport EOF errors (6JC/6JD/6JE/6JF) should also reduce
2026-02-09 10:24:08 +00:00
14 changed files with 992 additions and 991 deletions

View File

@@ -45,10 +45,7 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
return await PrismaChatSession.prisma().create(data=data)
async def update_chat_session(

View File

@@ -266,12 +266,38 @@ async def stream_chat_post(
"""
import asyncio
import time
stream_start_time = time.perf_counter()
# Base log metadata (task_id added after creation)
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
@@ -280,14 +306,46 @@ async def stream_chat_post(
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
logger.info(
"[TIMING] Calling stream_chat_completion",
extra={"json_fields": log_meta},
)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
@@ -296,54 +354,202 @@ async def stream_chat_post(
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscribe_start = time_module.perf_counter()
logger.info(
"[TIMING] Calling subscribe_to_task",
extra={"json_fields": log_meta},
)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
subscribe_time = (time_module.perf_counter() - subscribe_start) * 1000
logger.info(
f"[TIMING] subscribe_to_task completed in {subscribe_time:.1f}ms, "
f"queue_ok={subscriber_queue is not None}",
extra={
"json_fields": {
**log_meta,
"duration_ms": subscribe_time,
"queue_obtained": subscriber_queue is not None,
}
},
)
if subscriber_queue is None:
logger.info(
"[TIMING] subscriber_queue is None, yielding finish",
extra={"json_fields": log_meta},
)
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
queue_wait_start = time_module.perf_counter()
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
queue_wait_time = (
time_module.perf_counter() - queue_wait_start
) * 1000
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}, "
f"wait={queue_wait_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
"queue_wait_ms": queue_wait_time,
}
},
)
elif chunks_yielded % 50 == 0:
logger.info(
f"[TIMING] Chunk #{chunks_yielded}, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_number": chunks_yielded,
"chunk_type": type(chunk).__name__,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
logger.info(
f"[TIMING] Heartbeat timeout, chunks_so_far={chunks_yielded}",
extra={
"json_fields": {**log_meta, "chunks_so_far": chunks_yielded}
},
)
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
@@ -357,6 +563,18 @@ async def stream_chat_post(
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
return StreamingResponse(
@@ -425,7 +643,7 @@ async def stream_chat_get(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)

View File

@@ -371,21 +371,45 @@ async def stream_chat_completion(
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
)
# Only fetch from Redis if session not provided (initial call)
if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info(
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"message_count={len(session.messages) if session else 0}"
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
f"n_messages={len(session.messages) if session else 0}",
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
)
else:
logger.info(
f"Using provided session object: {session.session_id}, "
f"message_count={len(session.messages)}"
f"[TIMING] Using provided session, messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
if not session:
@@ -406,17 +430,25 @@ async def stream_chat_completion(
# Track user message in PostHog
if is_user_message:
posthog_start = time.monotonic()
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info(
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
)
logger.info(
f"Upserting session: {session.session_id} with user id {session.user_id}, "
f"message_count={len(session.messages)}"
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
@@ -454,7 +486,13 @@ async def stream_chat_completion(
asyncio.create_task(_update_title())
# Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming
assistant_response = ChatMessage(
@@ -483,9 +521,18 @@ async def stream_chat_completion(
text_block_id = str(uuid_module.uuid4())
# Yield message start
setup_time = (time.monotonic() - completion_start) * 1000
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
yield StreamStart(messageId=message_id)
try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
@@ -893,9 +940,21 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects
"""
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model
logger.info("Starting pure chat stream")
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
messages = session.to_openai_messages()
if system_prompt:
@@ -906,12 +965,18 @@ async def _stream_chat_chunks(
messages = [system_message] + messages
# Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error:
if "System prompt dropped" in context_result.error:
@@ -946,9 +1011,19 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES:
try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info(
f"Creating OpenAI chat completion stream..."
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
)
# Build extra_body for OpenRouter tracing and PostHog analytics
@@ -965,6 +1040,7 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -974,6 +1050,11 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body,
)
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
@@ -984,10 +1065,13 @@ async def _stream_chat_chunks(
# Track if we've started the text block
text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream
chunk: ChatCompletionChunk
async for chunk in stream:
chunk_count += 1
if chunk.usage:
yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens,
@@ -1010,6 +1094,23 @@ async def _stream_chat_chunks(
if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id)
text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta
text_response = StreamTextDelta(
id=text_block_id or "",
@@ -1066,7 +1167,21 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"],
)
emitted_start_for_idx.add(idx)
logger.info(f"Stream complete. Finish reason: {finish_reason}")
stream_duration = time_module.perf_counter() - api_call_start
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
# Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received
@@ -1086,6 +1201,12 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function
raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish()
return
except Exception as e:

View File

@@ -104,6 +104,24 @@ async def create_task(
Returns:
The created ActiveTask instance (metadata only)
"""
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask(
task_id=task_id,
session_id=session_id,
@@ -114,10 +132,18 @@ async def create_task(
)
# Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
@@ -131,12 +157,22 @@ async def create_task(
"created_at": task.created_at.isoformat(),
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
return task
@@ -156,26 +192,60 @@ async def publish_chunk(
Returns:
The Redis Stream message ID
"""
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
exc_info=True,
)
@@ -200,24 +270,61 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
extra={
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
)
return None
@@ -225,7 +332,19 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0
replay_last_id = last_message_id
@@ -244,19 +363,48 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue
@@ -264,6 +412,7 @@ async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -274,10 +423,27 @@ async def _stream_listener(
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
"""
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try:
redis = await get_redis_async()
@@ -287,9 +453,39 @@ async def _stream_listener(
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages:
# Timeout - check if task is still running
@@ -326,10 +522,30 @@ async def _stream_listener(
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
)
# Send overflow error with recovery info
try:
@@ -351,15 +567,44 @@ async def _stream_listener(
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
@@ -368,10 +613,24 @@ async def _stream_listener(
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
"Could not deliver finish event after error",
extra={"json_fields": log_meta},
)
finally:
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None)

View File

@@ -1,251 +0,0 @@
"""
Detect and save embedded binary data in block outputs.
Scans stdout_logs and other string outputs for embedded base64 patterns,
saves detected binary content to workspace, and replaces the base64 with
workspace:// references. This reduces LLM output token usage by ~97% for
file generation tasks.
Primary use case: ExecuteCodeBlock prints base64 to stdout, which appears
in stdout_logs. Without this processor, the LLM would re-type the entire
base64 string when saving files.
"""
import base64
import binascii
import hashlib
import logging
import re
import uuid
from typing import Any, Optional
from backend.util.file import sanitize_filename
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
logger = logging.getLogger(__name__)
# Minimum decoded size to process (filters out small base64 strings)
MIN_DECODED_SIZE = 1024 # 1KB
# Pattern to find base64 chunks in text (at least 100 chars to be worth checking)
# Matches continuous base64 characters (with optional whitespace for line wrapping),
# optionally ending with = padding
EMBEDDED_BASE64_PATTERN = re.compile(r"[A-Za-z0-9+/\s]{100,}={0,2}")
# Magic numbers for binary file detection
MAGIC_SIGNATURES = [
(b"\x89PNG\r\n\x1a\n", "png"),
(b"\xff\xd8\xff", "jpg"),
(b"%PDF-", "pdf"),
(b"GIF87a", "gif"),
(b"GIF89a", "gif"),
(b"RIFF", "webp"), # Also check content[8:12] == b'WEBP'
]
async def process_binary_outputs(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
block_name: str,
) -> dict[str, list[Any]]:
"""
Scan all string values in outputs for embedded base64 binary content.
Save detected binaries to workspace and replace with references.
Args:
outputs: Block execution outputs (dict of output_name -> list of values)
workspace_manager: WorkspaceManager instance with session scoping
block_name: Name of the block (used in generated filenames)
Returns:
Processed outputs with embedded base64 replaced by workspace references
"""
cache: dict[str, str] = {} # content_hash -> workspace_ref
processed: dict[str, list[Any]] = {}
for name, items in outputs.items():
processed_items = []
for item in items:
processed_items.append(
await _process_value(item, workspace_manager, block_name, cache)
)
processed[name] = processed_items
return processed
async def _process_value(
value: Any,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> Any:
"""Recursively process a value, detecting embedded base64 in strings."""
if isinstance(value, dict):
result = {}
for k, v in value.items():
result[k] = await _process_value(v, wm, block, cache)
return result
if isinstance(value, list):
return [await _process_value(v, wm, block, cache) for v in value]
if isinstance(value, str) and len(value) > MIN_DECODED_SIZE:
return await _extract_and_replace_base64(value, wm, block, cache)
return value
async def _extract_and_replace_base64(
text: str,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> str:
"""
Find embedded base64 in text, save binaries, replace with references.
Scans for base64 patterns, validates each as binary via magic numbers,
saves valid binaries to workspace, and replaces the base64 portion
(plus any surrounding markers) with the workspace reference.
"""
result = text
offset = 0
for match in EMBEDDED_BASE64_PATTERN.finditer(text):
b64_str = match.group(0)
# Try to decode and validate
detection = _decode_and_validate(b64_str)
if detection is None:
continue
content, ext = detection
# Save to workspace
ref = await _save_binary(content, ext, wm, block, cache)
if ref is None:
continue
# Calculate replacement bounds (include surrounding markers if present)
start, end = match.start(), match.end()
start, end = _expand_to_markers(text, start, end)
# Apply replacement with offset adjustment
adj_start = start + offset
adj_end = end + offset
result = result[:adj_start] + ref + result[adj_end:]
offset += len(ref) - (end - start)
return result
def _decode_and_validate(b64_str: str) -> Optional[tuple[bytes, str]]:
"""
Decode base64 and validate it's a known binary format.
Tries multiple 4-byte aligned offsets to handle cases where marker text
(e.g., "START" from "PDF_BASE64_START") bleeds into the regex match.
Base64 works in 4-char chunks, so we only check aligned offsets.
Returns (content, extension) if valid binary, None otherwise.
"""
# Strip whitespace for RFC 2045 line-wrapped base64
normalized = re.sub(r"\s+", "", b64_str)
# Try offsets 0, 4, 8, ... up to 32 chars (handles markers up to ~24 chars)
# This handles cases like "STARTJVBERi0..." where "START" bleeds into match
for char_offset in range(0, min(33, len(normalized)), 4):
candidate = normalized[char_offset:]
try:
content = base64.b64decode(candidate, validate=True)
except (ValueError, binascii.Error):
continue
# Must meet minimum size
if len(content) < MIN_DECODED_SIZE:
continue
# Check magic numbers
for magic, ext in MAGIC_SIGNATURES:
if content.startswith(magic):
# Special case for WebP: RIFF container, verify "WEBP" at offset 8
if magic == b"RIFF":
if len(content) < 12 or content[8:12] != b"WEBP":
continue
return content, ext
return None
def _expand_to_markers(text: str, start: int, end: int) -> tuple[int, int]:
"""
Expand replacement bounds to include surrounding markers if present.
Handles patterns like:
- ---BASE64_START---\\n{base64}\\n---BASE64_END---
- [BASE64]{base64}[/BASE64]
- Or just the raw base64
"""
# Common marker patterns to strip (order matters - check longer patterns first)
start_markers = [
"PDF_BASE64_START",
"---BASE64_START---\n",
"---BASE64_START---",
"[BASE64]\n",
"[BASE64]",
]
end_markers = [
"PDF_BASE64_END",
"\n---BASE64_END---",
"---BASE64_END---",
"\n[/BASE64]",
"[/BASE64]",
]
# Check for start markers
for marker in start_markers:
marker_start = start - len(marker)
if marker_start >= 0 and text[marker_start:start] == marker:
start = marker_start
break
# Check for end markers
for marker in end_markers:
marker_end = end + len(marker)
if marker_end <= len(text) and text[end:marker_end] == marker:
end = marker_end
break
return start, end
async def _save_binary(
content: bytes,
ext: str,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> Optional[str]:
"""
Save binary content to workspace with deduplication.
Returns workspace://file-id reference, or None on failure.
"""
content_hash = hashlib.sha256(content).hexdigest()
if content_hash in cache:
return cache[content_hash]
try:
safe_block = sanitize_filename(block)[:20].lower()
filename = f"{safe_block}_{uuid.uuid4().hex[:12]}.{ext}"
# Scan for viruses before saving
await scan_content_safe(content, filename=filename)
file = await wm.write_file(content, filename)
ref = f"workspace://{file.id}"
cache[content_hash] = ref
return ref
except Exception as e:
logger.warning("Failed to save binary output: %s", e)
return None

View File

@@ -0,0 +1,29 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]

View File

@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": self._get_inputs_list(graph.input_schema),
"inputs": get_inputs_from_schema(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = self._get_inputs_list(graph.input_schema)
inputs_list = get_inputs_from_schema(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -12,16 +12,15 @@ from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import get_block
from backend.data.block import AnyBlockSchema, get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .binary_output_processor import process_binary_outputs
from .helpers import get_inputs_from_schema
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -30,7 +29,10 @@ from .models import (
ToolResponseBase,
UserReadiness,
)
from .utils import build_missing_credentials_from_field_info
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__)
@@ -79,91 +81,6 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -234,8 +151,8 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
if missing_credentials:
@@ -340,16 +257,6 @@ class RunBlockTool(BaseTool):
):
outputs[output_name].append(output_data)
# Post-process outputs to save binary content to workspace
workspace_manager = WorkspaceManager(
user_id=user_id,
workspace_id=workspace.id,
session_id=session.session_id,
)
outputs = await process_binary_outputs(
dict(outputs), workspace_manager, block.name
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -374,29 +281,75 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
resolved: dict[str, CredentialsFieldInfo] = {}
return inputs_list
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -1,518 +0,0 @@
"""Tests for embedded binary detection in block outputs."""
import base64
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .binary_output_processor import (
_decode_and_validate,
_expand_to_markers,
process_binary_outputs,
)
@pytest.fixture
def mock_workspace_manager():
"""Create a mock workspace manager that returns predictable file IDs."""
wm = MagicMock()
async def mock_write_file(content, filename):
file = MagicMock()
file.id = f"file-{filename[:10]}"
return file
wm.write_file = AsyncMock(side_effect=mock_write_file)
return wm
def _make_pdf_base64(size: int = 2000) -> str:
"""Create a valid PDF base64 string of specified size."""
pdf_content = b"%PDF-1.4 " + b"x" * size
return base64.b64encode(pdf_content).decode()
def _make_png_base64(size: int = 2000) -> str:
"""Create a valid PNG base64 string of specified size."""
png_content = b"\x89PNG\r\n\x1a\n" + b"\x00" * size
return base64.b64encode(png_content).decode()
# =============================================================================
# Decode and Validate Tests
# =============================================================================
class TestDecodeAndValidate:
"""Tests for _decode_and_validate function."""
def test_detects_pdf_magic_number(self):
"""Should detect valid PDF by magic number."""
pdf_b64 = _make_pdf_base64()
result = _decode_and_validate(pdf_b64)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content.startswith(b"%PDF-")
def test_detects_png_magic_number(self):
"""Should detect valid PNG by magic number."""
png_b64 = _make_png_base64()
result = _decode_and_validate(png_b64)
assert result is not None
content, ext = result
assert ext == "png"
def test_detects_jpeg_magic_number(self):
"""Should detect valid JPEG by magic number."""
jpeg_content = b"\xff\xd8\xff\xe0" + b"\x00" * 2000
jpeg_b64 = base64.b64encode(jpeg_content).decode()
result = _decode_and_validate(jpeg_b64)
assert result is not None
_, ext = result
assert ext == "jpg"
def test_detects_gif_magic_number(self):
"""Should detect valid GIF by magic number."""
gif_content = b"GIF89a" + b"\x00" * 2000
gif_b64 = base64.b64encode(gif_content).decode()
result = _decode_and_validate(gif_b64)
assert result is not None
_, ext = result
assert ext == "gif"
def test_detects_webp_magic_number(self):
"""Should detect valid WebP by magic number."""
webp_content = b"RIFF\x00\x00\x00\x00WEBP" + b"\x00" * 2000
webp_b64 = base64.b64encode(webp_content).decode()
result = _decode_and_validate(webp_b64)
assert result is not None
_, ext = result
assert ext == "webp"
def test_rejects_small_content(self):
"""Should reject content smaller than threshold."""
small_pdf = b"%PDF-1.4 small"
small_b64 = base64.b64encode(small_pdf).decode()
result = _decode_and_validate(small_b64)
assert result is None
def test_rejects_no_magic_number(self):
"""Should reject content without recognized magic number."""
random_content = b"This is just random text" * 100
random_b64 = base64.b64encode(random_content).decode()
result = _decode_and_validate(random_b64)
assert result is None
def test_rejects_invalid_base64(self):
"""Should reject invalid base64."""
result = _decode_and_validate("not-valid-base64!!!")
assert result is None
def test_rejects_riff_without_webp(self):
"""Should reject RIFF files that aren't WebP (e.g., WAV)."""
wav_content = b"RIFF\x00\x00\x00\x00WAVE" + b"\x00" * 2000
wav_b64 = base64.b64encode(wav_content).decode()
result = _decode_and_validate(wav_b64)
assert result is None
def test_handles_line_wrapped_base64(self):
"""Should handle RFC 2045 line-wrapped base64."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# Simulate line wrapping at 76 chars
wrapped = "\n".join(pdf_b64[i : i + 76] for i in range(0, len(pdf_b64), 76))
result = _decode_and_validate(wrapped)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
# =============================================================================
# Marker Expansion Tests
# =============================================================================
class TestExpandToMarkers:
"""Tests for _expand_to_markers function."""
def test_expands_base64_start_end_markers(self):
"""Should expand to include ---BASE64_START--- and ---BASE64_END---."""
text = "prefix\n---BASE64_START---\nABCDEF\n---BASE64_END---\nsuffix"
# Base64 "ABCDEF" is at position 26-32
start, end = _expand_to_markers(text, 26, 32)
assert text[start:end] == "---BASE64_START---\nABCDEF\n---BASE64_END---"
def test_expands_bracket_markers(self):
"""Should expand to include [BASE64] and [/BASE64] markers."""
text = "prefix[BASE64]ABCDEF[/BASE64]suffix"
# Base64 is at position 14-20
start, end = _expand_to_markers(text, 14, 20)
assert text[start:end] == "[BASE64]ABCDEF[/BASE64]"
def test_no_expansion_without_markers(self):
"""Should not expand if no markers present."""
text = "prefix ABCDEF suffix"
start, end = _expand_to_markers(text, 7, 13)
assert start == 7
assert end == 13
# =============================================================================
# Process Binary Outputs Tests
# =============================================================================
@pytest.fixture
def mock_scan():
"""Patch virus scanner for tests."""
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
) as mock:
yield mock
class TestProcessBinaryOutputs:
"""Tests for process_binary_outputs function."""
@pytest.mark.asyncio
async def test_detects_embedded_pdf_in_stdout_logs(
self, mock_workspace_manager, mock_scan
):
"""Should detect and replace embedded PDF in stdout_logs."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF generated!\n---BASE64_START---\n{pdf_b64}\n---BASE64_END---\n"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "ExecuteCodeBlock"
)
# Should contain workspace reference, not base64
assert "workspace://" in result["stdout_logs"][0]
assert pdf_b64 not in result["stdout_logs"][0]
assert "PDF generated!" in result["stdout_logs"][0]
mock_workspace_manager.write_file.assert_called_once()
@pytest.mark.asyncio
async def test_detects_embedded_png_without_markers(
self, mock_workspace_manager, mock_scan
):
"""Should detect embedded PNG even without markers."""
png_b64 = _make_png_base64()
stdout = f"Image created: {png_b64} done"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "ExecuteCodeBlock"
)
assert "workspace://" in result["stdout_logs"][0]
assert "Image created:" in result["stdout_logs"][0]
assert "done" in result["stdout_logs"][0]
@pytest.mark.asyncio
async def test_preserves_small_strings(self, mock_workspace_manager, mock_scan):
"""Should not process small strings."""
outputs = {"stdout_logs": ["small output"]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert result["stdout_logs"][0] == "small output"
mock_workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_preserves_non_binary_large_strings(
self, mock_workspace_manager, mock_scan
):
"""Should preserve large strings that don't contain valid binary."""
large_text = "A" * 5000 # Large string - decodes to nulls, no magic number
outputs = {"stdout_logs": [large_text]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert result["stdout_logs"][0] == large_text
mock_workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_deduplicates_identical_content(
self, mock_workspace_manager, mock_scan
):
"""Should save identical content only once."""
pdf_b64 = _make_pdf_base64()
stdout1 = f"First: {pdf_b64}"
stdout2 = f"Second: {pdf_b64}"
outputs = {"stdout_logs": [stdout1, stdout2]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Both should have references
assert "workspace://" in result["stdout_logs"][0]
assert "workspace://" in result["stdout_logs"][1]
# But only one write
assert mock_workspace_manager.write_file.call_count == 1
@pytest.mark.asyncio
async def test_handles_multiple_binaries_in_one_string(
self, mock_workspace_manager, mock_scan
):
"""Should handle multiple embedded binaries in a single string."""
pdf_b64 = _make_pdf_base64()
png_b64 = _make_png_base64()
stdout = f"PDF: {pdf_b64}\nPNG: {png_b64}"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should have two workspace references
assert result["stdout_logs"][0].count("workspace://") == 2
assert mock_workspace_manager.write_file.call_count == 2
@pytest.mark.asyncio
async def test_processes_nested_structures(self, mock_workspace_manager, mock_scan):
"""Should recursively process nested dicts and lists."""
pdf_b64 = _make_pdf_base64()
outputs = {"result": [{"nested": {"deep": f"data: {pdf_b64}"}}]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert "workspace://" in result["result"][0]["nested"]["deep"]
@pytest.mark.asyncio
async def test_graceful_degradation_on_save_failure(
self, mock_workspace_manager, mock_scan
):
"""Should preserve original on save failure."""
mock_workspace_manager.write_file = AsyncMock(
side_effect=Exception("Storage error")
)
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should keep original since save failed
assert pdf_b64 in result["stdout_logs"][0]
# =============================================================================
# Offset Loop Tests (handling marker bleed-in)
# =============================================================================
class TestOffsetLoopHandling:
"""Tests for the offset-aligned decoding that handles marker bleed-in."""
def test_handles_4char_aligned_prefix(self):
"""Should detect base64 when a 4-char aligned prefix bleeds into match.
When 'TEST' (4 chars, aligned) bleeds in, offset 4 finds valid base64.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 4-char prefix (aligned)
with_prefix = f"TEST{pdf_b64}"
result = _decode_and_validate(with_prefix)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
def test_handles_8char_aligned_prefix(self):
"""Should detect base64 when an 8-char prefix bleeds into match."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 8-char prefix (aligned)
with_prefix = f"TESTTEST{pdf_b64}"
result = _decode_and_validate(with_prefix)
assert result is not None
content, ext = result
assert ext == "pdf"
def test_handles_misaligned_prefix(self):
"""Should handle misaligned prefix by finding a valid aligned offset.
'START' is 5 chars (misaligned). The loop tries offsets 0, 4, 8...
Since characters 0-4 include 'START' which is invalid base64 on its own,
we need the full PDF base64 to eventually decode correctly at some offset.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 5-char prefix - misaligned, but offset 4 should start mid-'START'
# and offset 8 will be past the prefix
with_prefix = f"START{pdf_b64}"
result = _decode_and_validate(with_prefix)
# Should find valid PDF at some offset (8 in this case)
assert result is not None
_, ext = result
assert ext == "pdf"
def test_handles_pdf_base64_start_marker_bleed(self):
"""Should handle PDF_BASE64_START marker bleeding into regex match.
This is the real-world case: regex matches 'STARTJVBERi0...' because
'START' chars are in the base64 alphabet. Offset loop skips past it.
PDF_BASE64_START is 16 chars (4-aligned), so offset 16 finds valid base64.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# Simulate regex capturing 'PDF_BASE64_START' + base64 together
# This happens when there's no delimiter between marker and content
with_full_marker = f"PDF_BASE64_START{pdf_b64}"
result = _decode_and_validate(with_full_marker)
assert result is not None
_, ext = result
assert ext == "pdf"
def test_clean_base64_works_at_offset_zero(self):
"""Should detect clean base64 at offset 0 without issues."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
result = _decode_and_validate(pdf_b64)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
# =============================================================================
# PDF Marker Tests
# =============================================================================
class TestPdfMarkerExpansion:
"""Tests for PDF_BASE64_START/END marker handling."""
def test_expands_pdf_base64_start_marker(self):
"""Should expand to include PDF_BASE64_START marker."""
text = "prefixPDF_BASE64_STARTABCDEF"
# Base64 'ABCDEF' is at position 22-28
start, end = _expand_to_markers(text, 22, 28)
assert text[start:end] == "PDF_BASE64_STARTABCDEF"
def test_expands_pdf_base64_end_marker(self):
"""Should expand to include PDF_BASE64_END marker."""
text = "ABCDEFPDF_BASE64_ENDsuffix"
# Base64 'ABCDEF' is at position 0-6
start, end = _expand_to_markers(text, 0, 6)
assert text[start:end] == "ABCDEFPDF_BASE64_END"
def test_expands_both_pdf_markers(self):
"""Should expand to include both PDF_BASE64_START and END."""
text = "xPDF_BASE64_STARTABCDEFPDF_BASE64_ENDy"
# Base64 'ABCDEF' is at position 17-23
start, end = _expand_to_markers(text, 17, 23)
assert text[start:end] == "PDF_BASE64_STARTABCDEFPDF_BASE64_END"
def test_partial_marker_not_expanded(self):
"""Should not expand if only partial marker present."""
text = "BASE64_STARTABCDEF" # Missing 'PDF_' prefix
start, end = _expand_to_markers(text, 12, 18)
# Should not expand since it's not the full marker
assert start == 12
assert end == 18
@pytest.mark.asyncio
async def test_full_pipeline_with_pdf_markers(self, mock_workspace_manager):
"""Test full pipeline with PDF_BASE64_START/END markers."""
pdf_b64 = _make_pdf_base64()
stdout = f"Output: PDF_BASE64_START{pdf_b64}PDF_BASE64_END done"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
):
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should have workspace reference
assert "workspace://" in result["stdout_logs"][0]
# Markers should be consumed along with base64
assert "PDF_BASE64_START" not in result["stdout_logs"][0]
assert "PDF_BASE64_END" not in result["stdout_logs"][0]
# Surrounding text preserved
assert "Output:" in result["stdout_logs"][0]
assert "done" in result["stdout_logs"][0]
# =============================================================================
# Virus Scanning Tests
# =============================================================================
class TestVirusScanning:
"""Tests for virus scanning integration."""
@pytest.mark.asyncio
async def test_calls_virus_scanner_before_save(self, mock_workspace_manager):
"""Should call scan_content_safe before writing file."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
) as mock_scan:
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Verify scanner was called
mock_scan.assert_called_once()
# Verify file was written after scan
mock_workspace_manager.write_file.assert_called_once()
# Verify result has workspace reference
assert "workspace://" in result["stdout_logs"][0]
@pytest.mark.asyncio
async def test_virus_scan_failure_preserves_original(self, mock_workspace_manager):
"""Should preserve original if virus scan fails."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
side_effect=Exception("Virus detected"),
):
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should keep original since scan failed
assert pdf_b64 in result["stdout_logs"][0]
# File should not have been written
mock_workspace_manager.write_file.assert_not_called()

View File

@@ -8,6 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
@@ -223,6 +224,99 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -331,8 +425,6 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
class AsyncRabbitMQ(RabbitMQBase):
"""Asynchronous RabbitMQ client"""
def __init__(self, config: RabbitMQConfig):
super().__init__(config)
self._reconnect_lock: asyncio.Lock | None = None
@property
def is_connected(self) -> bool:
return bool(self._connection and not self._connection.is_closed)
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
async def connect(self):
if self.is_connected:
if self.is_connected and self._channel and not self._channel.is_closed:
return
if (
self.is_connected
and self._connection
and (self._channel is None or self._channel.is_closed)
):
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
await self.declare_infrastructure()
return
self._connection = await aio_pika.connect_robust(
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@func_retry
async def publish_message(
@property
def _lock(self) -> asyncio.Lock:
if self._reconnect_lock is None:
self._reconnect_lock = asyncio.Lock()
return self._reconnect_lock
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
"""Get a valid channel, reconnecting if the current one is stale.
Uses a lock to prevent concurrent reconnection attempts from racing.
"""
if self.is_ready:
return self._channel # type: ignore # is_ready guarantees non-None
async with self._lock:
# Double-check after acquiring lock
if self.is_ready:
return self._channel # type: ignore
self._channel = None
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
async def _publish_once(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
channel = await self._ensure_channel()
if exchange:
exchange_obj = await self._channel.get_exchange(exchange.name)
exchange_obj = await channel.get_exchange(exchange.name)
else:
exchange_obj = self._channel.default_exchange
exchange_obj = channel.default_exchange
await exchange_obj.publish(
aio_pika.Message(
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
routing_key=routing_key,
)
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
try:
await self._publish_once(routing_key, message, exchange, persistent)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning(
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
)
async with self._lock:
self._channel = None
await self._publish_once(routing_key, message, exchange, persistent)
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
return await self._ensure_channel()

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]]
name = "aio-pika"
@@ -374,7 +374,7 @@ description = "LTS Port of Python audioop"
optional = false
python-versions = ">=3.13"
groups = ["main"]
markers = "python_version >= \"3.13\""
markers = "python_version == \"3.13\""
files = [
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd3d4602dc64914d462924a08c1a9816435a2155d74f325853c1f1ac3b2d9800"},
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:550c114a8df0aafe9a05442a1162dfc8fec37e9af1d625ae6060fed6e756f303"},
@@ -474,7 +474,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
optional = false
python-versions = "<3.11,>=3.8"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
@@ -487,7 +487,7 @@ description = "Backport of CPython tarfile module"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version <= \"3.11\""
markers = "python_version < \"3.12\""
files = [
{file = "backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34"},
{file = "backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991"},
@@ -563,6 +563,18 @@ webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.5)"]
[[package]]
name = "bracex"
version = "2.6"
description = "Bash style brace expander."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952"},
{file = "bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7"},
]
[[package]]
name = "browserbase"
version = "1.4.0"
@@ -659,7 +671,6 @@ description = "Foreign Function Interface for Python calling C code."
optional = false
python-versions = ">=3.9"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\""
files = [
{file = "cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44"},
{file = "cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49"},
@@ -1149,6 +1160,18 @@ idna = ["idna (>=3.10)"]
trio = ["trio (>=0.30)"]
wmi = ["wmi (>=1.5.1) ; platform_system == \"Windows\""]
[[package]]
name = "dockerfile-parse"
version = "2.0.1"
description = "Python library for Dockerfile manipulation"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "dockerfile-parse-2.0.1.tar.gz", hash = "sha256:3184ccdc513221983e503ac00e1aa504a2aa8f84e5de673c46b0b6eee99ec7bc"},
{file = "dockerfile_parse-2.0.1-py2.py3-none-any.whl", hash = "sha256:bdffd126d2eb26acf1066acb54cb2e336682e1d72b974a40894fac76a4df17f6"},
]
[[package]]
name = "docstring-parser"
version = "0.17.0"
@@ -1235,40 +1258,43 @@ pgp = ["gpg"]
[[package]]
name = "e2b"
version = "1.11.1"
version = "2.13.2"
description = "E2B SDK that give agents cloud environments"
optional = false
python-versions = "<4.0,>=3.9"
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
{file = "e2b-1.11.1-py3-none-any.whl", hash = "sha256:1ecb123873788472731c101939a494ab852cbcce0f913df6f7ecb194ae932130"},
{file = "e2b-1.11.1.tar.gz", hash = "sha256:7f7b6f238208d0a23353bb0da01f91a924321b57c61b176506862cbc1493ce8c"},
{file = "e2b-2.13.2-py3-none-any.whl", hash = "sha256:d91d5293bc0dd1917c72a6e6b35e86513607be2666a14ae18c57b921e7864de4"},
{file = "e2b-2.13.2.tar.gz", hash = "sha256:c0e81a3920091874fdf73c0b8f376b28766212db9f1cea5d8bd56a2e95d2436c"},
]
[package.dependencies]
attrs = ">=23.2.0"
dockerfile-parse = ">=2.0.1,<3.0.0"
httpcore = ">=1.0.5,<2.0.0"
httpx = ">=0.27.0,<1.0.0"
packaging = ">=24.1"
protobuf = ">=4.21.0"
python-dateutil = ">=2.8.2"
rich = ">=14.0.0"
typing-extensions = ">=4.1.0"
wcmatch = ">=10.1,<11.0"
[[package]]
name = "e2b-code-interpreter"
version = "1.5.2"
version = "2.4.1"
description = "E2B Code Interpreter - Stateful code execution"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3"},
{file = "e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547"},
{file = "e2b_code_interpreter-2.4.1-py3-none-any.whl", hash = "sha256:15d35f025b4a15033e119f2e12e7ac65657ad2b5a013fa9149e74581fbee778a"},
{file = "e2b_code_interpreter-2.4.1.tar.gz", hash = "sha256:4b15014ee0d0dfcdc3072e1f409cbb87ca48f48d53d75629b7257e5513b9e7dd"},
]
[package.dependencies]
attrs = ">=21.3.0"
e2b = ">=1.5.4,<2.0.0"
e2b = ">=2.7.0,<3.0.0"
httpx = ">=0.20.0,<1.0.0"
[[package]]
@@ -1338,7 +1364,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"},
{file = "exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219"},
@@ -1820,16 +1846,16 @@ files = [
google-auth = ">=2.14.1,<3.0.0"
googleapis-common-protos = ">=1.56.2,<2.0.0"
grpcio = [
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
]
proto-plus = [
{version = ">=1.22.3,<2.0.0"},
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0"},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
requests = ">=2.18.0,<3.0.0"
@@ -1940,8 +1966,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
grpcio = ">=1.33.2,<2.0.0"
proto-plus = [
{version = ">=1.22.3,<2.0.0"},
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0"},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
@@ -2001,9 +2027,9 @@ google-cloud-core = ">=2.0.0,<3.0.0"
grpc-google-iam-v1 = ">=0.12.4,<1.0.0"
opentelemetry-api = ">=1.9.0"
proto-plus = [
{version = ">=1.22.0,<2.0.0"},
{version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\""},
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
{version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\" and python_version < \"3.13\""},
{version = ">=1.22.0,<2.0.0", markers = "python_version < \"3.11\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
@@ -3802,7 +3828,7 @@ description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "numpy-2.2.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb"},
{file = "numpy-2.2.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90"},
@@ -4287,9 +4313,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -4532,8 +4558,8 @@ pinecone-plugin-interface = ">=0.0.7,<0.0.8"
python-dateutil = ">=2.5.3"
typing-extensions = ">=3.7.4"
urllib3 = [
{version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
{version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
{version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
]
[package.extras]
@@ -5361,7 +5387,7 @@ description = "C parser in Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
markers = "(platform_python_implementation != \"PyPy\" or sys_platform == \"darwin\") and implementation_name != \"PyPy\""
markers = "implementation_name != \"PyPy\""
files = [
{file = "pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992"},
{file = "pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29"},
@@ -6130,10 +6156,10 @@ files = [
grpcio = ">=1.41.0"
httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = [
{version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
{version = ">=1.21", markers = "python_version == \"3.11\""},
{version = ">=2.1.0", markers = "python_version == \"3.13\""},
{version = ">=1.21", markers = "python_version == \"3.11\""},
{version = ">=1.26", markers = "python_version == \"3.12\""},
{version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
]
portalocker = ">=2.7.0,<4.0"
protobuf = ">=3.20.0"
@@ -7317,7 +7343,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867"},
{file = "tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9"},
@@ -7841,6 +7867,21 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "wcmatch"
version = "10.1"
description = "Wildcard/glob file name matcher."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a"},
{file = "wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af"},
]
[package.dependencies]
bracex = ">=2.1.1"
[[package]]
name = "webencodings"
version = "0.5.1"
@@ -8440,4 +8481,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "14686ee0e2dc446a75d0db145b08dc410dc31c357e25085bb0f9b0174711c4b1"
content-hash = "f1f229017d133bab1cb5b787a93f8d6d652c2712e07d4966358e725a57e35e80"

View File

@@ -19,7 +19,7 @@ bleach = { extras = ["css"], version = "^6.2.0" }
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
e2b-code-interpreter = "^2.4.1"
elevenlabs = "^1.50.0"
fastapi = "^0.128.5"
feedparser = "^6.0.11"

View File

@@ -104,7 +104,31 @@ export function FileInput(props: Props) {
return false;
}
const getFileLabelFromValue = (val: string) => {
const getFileLabelFromValue = (val: unknown): string => {
// Handle object format from external API: { name, type, size, data }
if (val && typeof val === "object") {
const obj = val as Record<string, unknown>;
if (typeof obj.name === "string") {
return getFileLabel(
obj.name,
typeof obj.type === "string" ? obj.type : "",
);
}
if (typeof obj.type === "string") {
const mimeParts = obj.type.split("/");
if (mimeParts.length > 1) {
return `${mimeParts[1].toUpperCase()} file`;
}
return `${obj.type} file`;
}
return "File";
}
// Handle string values (data URIs or file paths)
if (typeof val !== "string") {
return "File";
}
if (val.startsWith("data:")) {
const matches = val.match(/^data:([^;]+);/);
if (matches?.[1]) {