Compare commits

..

13 Commits

Author SHA1 Message Date
Otto
ed07f02738 fix(copilot): edit_agent updates existing agent instead of creating duplicate (#11981)
## Summary

When editing an agent via CoPilot's `edit_agent` tool, the code was
always creating a new `LibraryAgent` entry instead of updating the
existing one to point to the new graph version. This caused duplicate
agents to appear in the user's library.

## Changes

In `save_agent_to_library()`:
- When `is_update=True`, now checks if there's an existing library agent
for the graph using `get_library_agent_by_graph_id()`
- If found, uses `update_agent_version_in_library()` to update the
existing library agent to point to the new version
- Falls back to creating a new library agent if no existing one is found
(e.g., if editing a graph that wasn't added to library yet)

## Testing

- Verified lint/format checks pass
- Plan reviewed and approved by Staff Engineer Plan Reviewer agent

## Related

Fixes [SECRT-1857](https://linear.app/autogpt/issue/SECRT-1857)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-02-05 15:02:26 +00:00
Swifty
b121030c94 feat(frontend): Add progress indicator during agent generation [SECRT-1883] (#11974)
## Summary
- Add asymptotic progress bar that appears during long-running chat
tasks
- Progress bar shows after 10 seconds with "Working on it..." label and
percentage
- Uses half-life formula: ~50% at 30s, ~75% at 60s, ~87.5% at 90s, etc.
- Creates the classic "game loading bar" effect that never reaches 100%



https://github.com/user-attachments/assets/3c59289e-793c-4a08-b3fc-69e1eef28b1f



## Test plan
- [x] Start a chat that triggers agent generation
- [x] Wait 10+ seconds for the progress bar to appear
- [x] Verify progress bar is centered with label and percentage
- [x] Verify progress follows expected timing (~50% at 30s)
- [x] Verify progress bar disappears when task completes

---------

Co-authored-by: Otto <otto@agpt.co>
2026-02-05 15:37:51 +01:00
Swifty
c22c18374d feat(frontend): Add ready-to-test prompt after agent creation [SECRT-1882] (#11975)
## Summary
- Add special UI prompt when agent is successfully created in chat
- Show "Agent Created Successfully" with agent name
- Provide two action buttons:
- **Run with example values**: Sends chat message asking AI to run with
placeholders
- **Run with my inputs**: Opens RunAgentModal for custom input
configuration
- After run/schedule, automatically send chat message with execution
details for AI monitoring



https://github.com/user-attachments/assets/b11e118c-de59-4b79-a629-8bd0d52d9161



## Test plan
- [x] Create an agent through chat
- [x] Verify "Agent Created Successfully" prompt appears
- [x] Click "Run with example values" - verify chat message is sent
- [x] Click "Run with my inputs" - verify RunAgentModal opens
- [x] Fill inputs and run - verify chat message with execution ID is
sent
- [x] Fill inputs and schedule - verify chat message with schedule
details is sent

---------

Co-authored-by: Otto <otto@agpt.co>
2026-02-05 15:37:31 +01:00
Swifty
e40233a3ac fix(backend/chat): Guide find_agent users toward action with CTAs (#11976)
When users search for agents, guide them toward creating custom agents
if no results are found or after showing results. This improves user
engagement by offering a clear next step.

### Changes 🏗️

- Updated `agent_search.py` to add CTAs in search responses
- Added messaging to inform users they can create custom agents based on
their needs
- Applied to both "no results found" and "agents found" scenarios

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Search for agents in marketplace with matching results
  - [x] Search for agents in marketplace with no results
  - [x] Search for agents in library with matching results  
  - [x] Search for agents in library with no results
  - [x] Verify CTA message appears in all cases

---------

Co-authored-by: Otto <otto@agpt.co>
2026-02-05 15:36:55 +01:00
Swifty
3ae5eabf9d fix(backend/chat): Use latest prompt label in non-production environments (#11977)
In non-production environments, the chat service now fetches prompts
with the `latest` label instead of the default production-labeled
prompt. This makes it easier to test and iterate on prompt changes in
dev/staging without needing to promote them to production first.

### Changes 🏗️

- Updated `_get_system_prompt_template()` in chat service to pass
`label="latest"` when `app_env` is not `PRODUCTION`
- Production environments continue using the default behavior
(production-labeled prompts)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified that in non-production environments, prompts with
`latest` label are fetched
- [x] Verified that production environments still use the default
(production) labeled prompts

Co-authored-by: Otto <otto@agpt.co>
2026-02-05 14:54:39 +01:00
Otto
a077ba9f03 fix(platform): YouTube block yields only error on failure (#11980)
## Summary

Fixes [SECRT-1889](https://linear.app/autogpt/issue/SECRT-1889): The
YouTube transcription block was yielding both `video_id` and `error`
when the transcript fetch failed.

## Problem

The block yielded `video_id` immediately upon extracting it from the
URL, before attempting to fetch the transcript. If the transcript fetch
failed, both outputs were present.

```python
# Before
video_id = self.extract_video_id(input_data.youtube_url)
yield "video_id", video_id  # ← Yielded before transcript attempt

transcript = self.get_transcript(video_id, credentials)  # ← Could fail here
```

## Solution

Wrap the entire operation in try/except and only yield outputs after all
operations succeed:

```python
# After
try:
    video_id = self.extract_video_id(input_data.youtube_url)
    transcript = self.get_transcript(video_id, credentials)
    transcript_text = self.format_transcript(transcript=transcript)

    # Only yield after all operations succeed
    yield "video_id", video_id
    yield "transcript", transcript_text
except Exception as e:
    yield "error", str(e)
```

This follows the established pattern in other blocks (e.g.,
`ai_image_generator_block.py`).

## Testing

- All 10 unit tests pass (`test/blocks/test_youtube.py`)
- Lint/format checks pass

Co-authored-by: Toran Bruce Richards <toran.richards@gmail.com>
2026-02-05 11:51:32 +00:00
Bently
5401d54eaa fix(backend): Handle StreamHeartbeat in CoPilot stream handler (#11928)
### Changes 🏗️

Fixes **AUTOGPT-SERVER-7JA** (123 events since Jan 27, 2026).

#### Problem

`StreamHeartbeat` was added to keep SSE connections alive during
long-running tool executions (yielded every 15s while waiting). However,
the main `stream_chat_completion` handler's `elif` chain didn't have a
case for it:

```
StreamTextStart →  handled
StreamTextDelta →  handled
StreamTextEnd →  handled
StreamToolInputStart →  handled
StreamToolInputAvailable →  handled
StreamToolOutputAvailable →  handled
StreamFinish →  handled
StreamError →  handled
StreamUsage →  handled
StreamHeartbeat →  fell through to 'Unknown chunk type' error
```

This meant every heartbeat during tool execution generated a Sentry
error instead of keeping the connection alive.

#### Fix

Add `StreamHeartbeat` to the `elif` chain and yield it through. The
route handler already calls `to_sse()` on all yielded chunks, and
`StreamHeartbeat.to_sse()` correctly returns `: heartbeat\n\n` (SSE
comment format, ignored by clients but keeps proxies/load balancers
happy).

**1 file changed, 3 insertions.**
2026-02-05 12:04:46 +01:00
Otto
5ac89d7c0b fix(test): fix timing bug in test_block_credit_reset (#11978)
## Summary
Fixes the flaky `test_block_credit_reset` test that was failing on
multiple PRs with `assert 0 == 1000`.

## Root Cause
The test calls `disable_test_user_transactions()` which sets `updatedAt`
to 35 days ago from the **actual current time**. It then mocks
`time_now` to January 1st.

**The bug**: If the test runs in early February, 35 days ago is January
— the **same month** as the mocked `time_now`. The credit refill logic
only triggers when the balance snapshot is from a *different* month, so
no refill happens and the balance stays at 0.

## Fix
After calling `disable_test_user_transactions()`, explicitly set
`updatedAt` to December of the previous year. This ensures it's always
in a different month than the mocked `month1` (January), regardless of
when the test runs.

## Testing
CI will verify the fix.
2026-02-05 11:56:26 +01:00
Otto
4f908d5cb3 fix(platform): Improve Linear Search Block [SECRT-1880] (#11967)
## Summary

Implements [SECRT-1880](https://linear.app/autogpt/issue/SECRT-1880) -
Improve Linear Search Block

## Changes

### Models (`models.py`)
- Added `State` model with `id`, `name`, and `type` fields for workflow
state information
- Added `state: State | None` field to `Issue` model

### API Client (`_api.py`)
- Updated `try_search_issues()` to:
- Add `max_results` parameter (default 10, was ~50) to reduce token
usage
  - Add `team_id` parameter for team filtering
- Return `createdAt`, `state`, `project`, and `assignee` fields in
results
- Fixed `try_get_team_by_name()` to return descriptive error message
when team not found instead of crashing with `IndexError`

### Block (`issues.py`)
- Added `max_results` input parameter (1-100, default 10)
- Added `team_name` input parameter for optional team filtering
- Added `error` output field for graceful error handling
- Added categories (`PRODUCTIVITY`, `ISSUE_TRACKING`)
- Updated test fixtures to include new fields

## Breaking Changes

| Change | Before | After | Mitigation |
|--------|--------|-------|------------|
| Default result count | ~50 | 10 | Users can set `max_results` up to
100 if needed |

## Non-Breaking Changes

- `state` field added to `Issue` (optional, defaults to `None`)
- `max_results` param added (has default value)
- `team_name` param added (optional, defaults to `None`)
- `error` output added (follows established pattern from GitHub blocks)

## Testing

- [x] Format/lint checks pass
- [x] Unit test fixtures updated

Resolves SECRT-1880

---------

Co-authored-by: Toran Bruce Richards <toran.richards@gmail.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
2026-02-04 22:54:46 +00:00
Reinier van der Leer
c1aa684743 fix(platform/chat): Filter host-scoped credentials for run_agent tool (#11905)
- Fixes [SECRT-1851: \[Copilot\] `run_agent` tool doesn't filter
host-scoped credentials](https://linear.app/autogpt/issue/SECRT-1851)
- Follow-up to #11881

### Changes 🏗️

- Filter host-scoped credentials for `run_agent` tool
- Tighten validation on host input field in `HostScopedCredentialsModal`
- Use netloc (w/ port) rather than just hostname (w/o port) as host
scope

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Create graph that requires host-scoped credentials to work
  - Create host-scoped credentials with a *different* host
  - Try to have Copilot run the graph
  - [x] -> no matching credentials available
  - Create new credentials
  - [x] -> works

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-02-04 16:27:14 +00:00
Otto
7e5b84cc5c fix(copilot): update homepage copy to focus on problem discovery (#11956)
## Summary
Update the CoPilot homepage to shift from "what do you want to
automate?" to "tell me about your problems." This lowers the barrier to
engagement by letting users describe their work frustrations instead of
requiring them to identify automations themselves.

## Changes
| Element | Before | After |
|---------|--------|-------|
| Headline | "What do you want to automate?" | "Tell me about your work
— I'll find what to automate." |
| Placeholder | "You can search or just ask - e.g. 'create a blog post
outline'" | "What's your role and what eats up most of your day? e.g.
'I'm a real estate agent and I hate...'" |
| Button 1 | "Show me what I can automate" | "I don't know where to
start, just ask me stuff" |
| Button 2 | "Design a custom workflow" | "I do the same thing every
week and it's killing me" |
| Button 3 | "Help me with content creation" | "Help me find where I'm
wasting my time" |
| Container | max-w-2xl | max-w-3xl |

> **Note on container width:** The `max-w-2xl` → `max-w-3xl` change is
just to keep the longer headline on one line. This works but may not be
the ideal solution — @lluis-xai should advise on the proper approach.

## Why This Matters
The current UX assumes users know what they want to automate. In
reality, most users know what frustrates them but can't identify
automations. The current screen blocks Otto from starting the discovery
conversation that leads to useful recommendations.

## Files Changed
- `autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx` —
headline, placeholder, container width
- `autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts` —
quick action button text

Resolves: [SECRT-1876](https://linear.app/autogpt/issue/SECRT-1876)

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-02-04 17:38:58 +07:00
Swifty
09cb313211 fix(frontend): Prevent reflected XSS in OAuth callback route (#11963)
## Summary

Fixes a reflected cross-site scripting (XSS) vulnerability in the OAuth
callback route.

**Security Issue:**
https://github.com/Significant-Gravitas/AutoGPT/security/code-scanning/202

### Vulnerability

The OAuth callback route at
`frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts`
was writing user-controlled data directly into an HTML response without
proper sanitization. This allowed potential attackers to inject
malicious scripts via OAuth callback parameters.

### Fix

Added a `safeJsonStringify()` function that escapes characters that
could break out of the script context:
- `<` → `\u003c`
- `>` → `\u003e`  
- `&` → `\u0026`

This prevents any user-provided values from being interpreted as
HTML/script content when embedded in the response.

### References

- [OWASP XSS Prevention Cheat
Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html)
- [CWE-79: Improper Neutralization of Input During Web Page
Generation](https://cwe.mitre.org/data/definitions/79.html)

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified the OAuth callback still functions correctly
- [x] Confirmed special characters in OAuth responses are properly
escaped
2026-02-04 10:53:17 +01:00
Krzysztof Czerwinski
c026485023 feat(frontend): Disable auto-opening wallet (#11961)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️

- Disable auto-opening Wallet for first time user and on credit increase
- Remove no longer needed `lastSeenCredits` state and storage

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Wallet doesn't open automatically
2026-02-04 06:11:41 +00:00
44 changed files with 885 additions and 689 deletions

View File

@@ -19,3 +19,6 @@ load-tests/*.json
load-tests/*.log load-tests/*.log
load-tests/node_modules/* load-tests/node_modules/*
migrations/*/rollback*.sql migrations/*/rollback*.sql
# Workspace files
workspaces/

View File

@@ -1,15 +1,12 @@
import asyncio import asyncio
import logging import logging
import time import time
import uuid as uuid_module
from asyncio import CancelledError from asyncio import CancelledError
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import openai import openai
from backend.util.prompt import compress_context
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.util.prompt import CompressResult from backend.util.prompt import CompressResult
@@ -36,7 +33,7 @@ from backend.data.understanding import (
get_business_understanding, get_business_understanding,
) )
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings from backend.util.settings import AppEnvironment, Settings
from . import db as chat_db from . import db as chat_db
from . import stream_registry from . import stream_registry
@@ -225,8 +222,18 @@ async def _get_system_prompt_template(context: str) -> str:
try: try:
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
# Use asyncio.to_thread to avoid blocking the event loop # Use asyncio.to_thread to avoid blocking the event loop
# In non-production environments, fetch the latest prompt version
# instead of the production-labeled version for easier testing
label = (
None
if settings.config.app_env == AppEnvironment.PRODUCTION
else "latest"
)
prompt = await asyncio.to_thread( prompt = await asyncio.to_thread(
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0 langfuse.get_prompt,
config.langfuse_prompt_name,
label=label,
cache_ttl_seconds=0,
) )
return prompt.compile(users_information=context) return prompt.compile(users_information=context)
except Exception as e: except Exception as e:
@@ -470,6 +477,8 @@ async def stream_chat_completion(
should_retry = False should_retry = False
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4()) message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())
@@ -619,6 +628,9 @@ async def stream_chat_completion(
total_tokens=chunk.totalTokens, total_tokens=chunk.totalTokens,
) )
) )
elif isinstance(chunk, StreamHeartbeat):
# Pass through heartbeat to keep SSE connection alive
yield chunk
else: else:
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True) logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
@@ -827,6 +839,10 @@ async def _manage_context_window(
Returns: Returns:
CompressResult with compacted messages and metadata CompressResult with compacted messages and metadata
""" """
import openai
from backend.util.prompt import compress_context
# Convert messages to dict format # Convert messages to dict format
messages_dict = [] messages_dict = []
for msg in messages: for msg in messages:
@@ -1137,6 +1153,8 @@ async def _yield_tool_call(
KeyError: If expected tool call fields are missing KeyError: If expected tool call fields are missing
TypeError: If tool call structure is invalid TypeError: If tool call structure is invalid
""" """
import uuid as uuid_module
tool_name = tool_calls[yield_idx]["function"]["name"] tool_name = tool_calls[yield_idx]["function"]["name"]
tool_call_id = tool_calls[yield_idx]["id"] tool_call_id = tool_calls[yield_idx]["id"]
@@ -1757,6 +1775,8 @@ async def _generate_llm_continuation_with_streaming(
after a tool result is saved. Chunks are published to the stream registry after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them. so reconnecting clients can receive them.
""" """
import uuid as uuid_module
try: try:
# Load fresh session from DB (bypass cache to get the updated tool result) # Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id) await invalidate_session_cache(session_id)
@@ -1792,6 +1812,10 @@ async def _generate_llm_continuation_with_streaming(
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response) # Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4()) message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())

View File

@@ -7,15 +7,7 @@ from typing import Any, NotRequired, TypedDict
from backend.api.features.library import db as library_db from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.data.graph import ( from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError from backend.util.exceptions import DatabaseError, NotFoundError
from .service import ( from .service import (
@@ -28,8 +20,6 @@ from .service import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
class ExecutionSummary(TypedDict): class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment.""" """Summary of a single execution for quality assessment."""
@@ -669,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
) )
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.id = str(uuid.uuid4())
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.
Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)
async def save_agent_to_library( async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]: ) -> tuple[Graph, Any]:
@@ -721,35 +672,10 @@ async def save_agent_to_library(
Returns: Returns:
Tuple of (created Graph, LibraryAgent) Tuple of (created Graph, LibraryAgent)
""" """
# Populate user_id in AgentExecutorBlock nodes before conversion
_populate_agent_executor_user_ids(agent_json, user_id)
graph = json_to_graph(agent_json) graph = json_to_graph(agent_json)
if is_update: if is_update:
if graph.id: return await library_db.update_graph_in_library(graph, user_id)
existing_versions = await get_graph_all_versions(graph.id, user_id) return await library_db.create_graph_in_library(graph, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
graph.id = str(uuid.uuid4())
graph.version = 1
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
created_graph = await create_graph(graph, user_id)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
def graph_to_json(graph: Graph) -> dict[str, Any]: def graph_to_json(graph: Graph) -> dict[str, Any]:

View File

@@ -206,9 +206,9 @@ async def search_agents(
] ]
) )
no_results_msg = ( no_results_msg = (
f"No agents found matching '{query}'. Try different keywords or browse the marketplace." f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
if source == "marketplace" if source == "marketplace"
else f"No agents matching '{query}' found in your library." else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
) )
return NoResultsResponse( return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +224,10 @@ async def search_agents(
message = ( message = (
"Now you have found some options for the user to choose from. " "Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id " "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents." "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
if source == "marketplace" if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: " else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute." "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
) )
return AgentsFoundResponse( return AgentsFoundResponse(

View File

@@ -3,8 +3,6 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
@@ -30,26 +28,6 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateAgentInput(BaseModel):
"""Input parameters for the create_agent tool."""
description: str = ""
context: str = ""
save: bool = True
# Internal async processing params (passed by long-running tool handler)
_operation_id: str | None = None
_task_id: str | None = None
@field_validator("description", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class Config:
extra = "allow" # Allow _operation_id, _task_id from kwargs
class CreateAgentTool(BaseTool): class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions.""" """Tool for creating agents from natural language descriptions."""
@@ -107,7 +85,7 @@ class CreateAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the create_agent tool. """Execute the create_agent tool.
@@ -116,14 +94,16 @@ class CreateAgentTool(BaseTool):
2. Generate agent JSON (external service handles fixing and validation) 2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter 3. Preview or save based on the save parameter
""" """
params = CreateAgentInput(**kwargs) description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
# Extract async processing params # Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id") operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id") task_id = kwargs.get("_task_id")
if not params.description: if not description:
return ErrorResponse( return ErrorResponse(
message="Please provide a description of what the agent should do.", message="Please provide a description of what the agent should do.",
error="Missing description parameter", error="Missing description parameter",
@@ -135,7 +115,7 @@ class CreateAgentTool(BaseTool):
try: try:
library_agents = await get_all_relevant_agents_for_generation( library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id, user_id=user_id,
search_query=params.description, search_query=description,
include_marketplace=True, include_marketplace=True,
) )
logger.debug( logger.debug(
@@ -146,7 +126,7 @@ class CreateAgentTool(BaseTool):
try: try:
decomposition_result = await decompose_goal( decomposition_result = await decompose_goal(
params.description, params.context, library_agents description, context, library_agents
) )
except AgentGeneratorNotConfiguredError: except AgentGeneratorNotConfiguredError:
return ErrorResponse( return ErrorResponse(
@@ -162,7 +142,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.", message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed", error="decomposition_failed",
details={"description": params.description[:100]}, details={"description": description[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -178,7 +158,7 @@ class CreateAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"decomposition_failed:{error_type}", error=f"decomposition_failed:{error_type}",
details={ details={
"description": params.description[:100], "description": description[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -264,7 +244,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.", message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed", error="generation_failed",
details={"description": params.description[:100]}, details={"description": description[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -286,7 +266,7 @@ class CreateAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"generation_failed:{error_type}", error=f"generation_failed:{error_type}",
details={ details={
"description": params.description[:100], "description": description[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -311,7 +291,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", [])) node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", [])) link_count = len(agent_json.get("links", []))
if not params.save: if not save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. " f"I've generated an agent called '{agent_name}' with {node_count} blocks. "

View File

@@ -3,8 +3,6 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError from backend.api.features.store.exceptions import AgentNotFoundError
@@ -29,23 +27,6 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
class CustomizeAgentTool(BaseTool): class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language.""" """Tool for customizing marketplace/template agents using natural language."""
@@ -111,7 +92,7 @@ class CustomizeAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the customize_agent tool. """Execute the customize_agent tool.
@@ -121,17 +102,20 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request 3. Call customize_template with the modification request
4. Preview or save based on the save parameter 4. Preview or save based on the save parameter
""" """
params = CustomizeAgentInput(**kwargs) agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not params.agent_id: if not agent_id:
return ErrorResponse( return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id", error="missing_agent_id",
session_id=session_id, session_id=session_id,
) )
if not params.modifications: if not modifications:
return ErrorResponse( return ErrorResponse(
message="Please describe how you want to customize this agent.", message="Please describe how you want to customize this agent.",
error="missing_modifications", error="missing_modifications",
@@ -139,11 +123,11 @@ class CustomizeAgentTool(BaseTool):
) )
# Parse agent_id in format "creator/slug" # Parse agent_id in format "creator/slug"
parts = [p.strip() for p in params.agent_id.split("/")] parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]: if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Invalid agent ID format: '{params.agent_id}'. " f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' " "Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')." "(e.g., 'autogpt/newsletter-writer')."
), ),
@@ -161,14 +145,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError: except AgentNotFoundError:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Could not find marketplace agent '{params.agent_id}'. " f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again." "Please check the agent ID and try again."
), ),
error="agent_not_found", error="agent_not_found",
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}") logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.", message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error", error="fetch_error",
@@ -178,7 +162,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id: if not agent_details.store_listing_version_id:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"The agent '{params.agent_id}' does not have an available version. " f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent." "Please try a different agent."
), ),
error="no_version_available", error="no_version_available",
@@ -190,7 +174,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id) graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph) template_agent = graph_to_json(graph)
except Exception as e: except Exception as e:
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}") logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.", message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error", error="graph_fetch_error",
@@ -201,8 +185,8 @@ class CustomizeAgentTool(BaseTool):
try: try:
result = await customize_template( result = await customize_template(
template_agent=template_agent, template_agent=template_agent,
modification_request=params.modifications, modification_request=modifications,
context=params.context, context=context,
) )
except AgentGeneratorNotConfiguredError: except AgentGeneratorNotConfiguredError:
return ErrorResponse( return ErrorResponse(
@@ -214,7 +198,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error calling customize_template for {params.agent_id}: {e}") logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message=( message=(
"Failed to customize the agent due to a service error. " "Failed to customize the agent due to a service error. "
@@ -235,25 +219,55 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Handle response using match/case for cleaner pattern matching # Handle error response
return await self._handle_customization_result( if isinstance(result, dict) and result.get("type") == "error":
result=result, error_msg = result.get("error", "Unknown error")
params=params, error_type = result.get("error_type", "unknown")
agent_details=agent_details, user_message = get_user_message_for_error(
user_id=user_id, error_type,
session_id=session_id, operation="customize the agent",
) llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
async def _handle_customization_result( # Handle clarifying questions
self, if isinstance(result, dict) and result.get("type") == "clarifying_questions":
result: dict[str, Any], questions = result.get("questions") or []
params: CustomizeAgentInput, if not isinstance(questions, list):
agent_details: Any, logger.error(
user_id: str | None, f"Unexpected clarifying questions format: {type(questions)}"
session_id: str | None, )
) -> ToolResponseBase: questions = []
"""Handle the result from customize_template using pattern matching.""" return ClarificationNeededResponse(
# Ensure result is a dict message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict): if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}") logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse( return ErrorResponse(
@@ -262,77 +276,8 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
result_type = result.get("type") customized_agent = result
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get( agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}" "name", f"Customized {agent_details.agent_name}"
) )
@@ -342,7 +287,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0 node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0 link_count = len(links) if isinstance(links, list) else 0
if not params.save: if not save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've customized the agent '{agent_details.agent_name}'. " f"I've customized the agent '{agent_details.agent_name}'. "

View File

@@ -3,8 +3,6 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
@@ -29,20 +27,6 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EditAgentInput(BaseModel):
model_config = ConfigDict(extra="allow")
agent_id: str = ""
changes: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "changes", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class EditAgentTool(BaseTool): class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language.""" """Tool for editing existing agents using natural language."""
@@ -106,7 +90,7 @@ class EditAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the edit_agent tool. """Execute the edit_agent tool.
@@ -115,32 +99,35 @@ class EditAgentTool(BaseTool):
2. Generate updated agent (external service handles fixing and validation) 2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter 3. Preview or save based on the save parameter
""" """
params = EditAgentInput(**kwargs) agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler) # Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id") operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id") task_id = kwargs.get("_task_id")
if not params.agent_id: if not agent_id:
return ErrorResponse( return ErrorResponse(
message="Please provide the agent ID to edit.", message="Please provide the agent ID to edit.",
error="Missing agent_id parameter", error="Missing agent_id parameter",
session_id=session_id, session_id=session_id,
) )
if not params.changes: if not changes:
return ErrorResponse( return ErrorResponse(
message="Please describe what changes you want to make.", message="Please describe what changes you want to make.",
error="Missing changes parameter", error="Missing changes parameter",
session_id=session_id, session_id=session_id,
) )
current_agent = await get_agent_as_json(params.agent_id, user_id) current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None: if current_agent is None:
return ErrorResponse( return ErrorResponse(
message=f"Could not find agent '{params.agent_id}' in your library.", message=f"Could not find agent with ID '{agent_id}' in your library.",
error="agent_not_found", error="agent_not_found",
session_id=session_id, session_id=session_id,
) )
@@ -151,7 +138,7 @@ class EditAgentTool(BaseTool):
graph_id = current_agent.get("id") graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation( library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id, user_id=user_id,
search_query=params.changes, search_query=changes,
exclude_graph_id=graph_id, exclude_graph_id=graph_id,
include_marketplace=True, include_marketplace=True,
) )
@@ -161,11 +148,9 @@ class EditAgentTool(BaseTool):
except Exception as e: except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}") logger.warning(f"Failed to fetch library agents: {e}")
update_request = params.changes update_request = changes
if params.context: if context:
update_request = ( update_request = f"{changes}\n\nAdditional context:\n{context}"
f"{params.changes}\n\nAdditional context:\n{params.context}"
)
try: try:
result = await generate_agent_patch( result = await generate_agent_patch(
@@ -189,7 +174,7 @@ class EditAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.", message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed", error="update_generation_failed",
details={"agent_id": params.agent_id, "changes": params.changes[:100]}, details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -221,8 +206,8 @@ class EditAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"update_generation_failed:{error_type}", error=f"update_generation_failed:{error_type}",
details={ details={
"agent_id": params.agent_id, "agent_id": agent_id,
"changes": params.changes[:100], "changes": changes[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -254,7 +239,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", [])) node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", [])) link_count = len(updated_agent.get("links", []))
if not params.save: if not save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've updated the agent. " f"I've updated the agent. "

View File

@@ -2,8 +2,6 @@
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
@@ -11,18 +9,6 @@ from .base import BaseTool
from .models import ToolResponseBase from .models import ToolResponseBase
class FindAgentInput(BaseModel):
"""Input parameters for the find_agent tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindAgentTool(BaseTool): class FindAgentTool(BaseTool):
"""Tool for discovering agents from the marketplace.""" """Tool for discovering agents from the marketplace."""
@@ -50,11 +36,10 @@ class FindAgentTool(BaseTool):
} }
async def _execute( async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs: Any self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase: ) -> ToolResponseBase:
params = FindAgentInput(**kwargs)
return await search_agents( return await search_agents(
query=params.query, query=kwargs.get("query", "").strip(),
source="marketplace", source="marketplace",
session_id=session.session_id, session_id=session.session_id,
user_id=user_id, user_id=user_id,

View File

@@ -2,7 +2,6 @@ import logging
from typing import Any from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
@@ -19,18 +18,6 @@ from backend.data.block import get_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FindBlockInput(BaseModel):
"""Input parameters for the find_block tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindBlockTool(BaseTool): class FindBlockTool(BaseTool):
"""Tool for searching available blocks.""" """Tool for searching available blocks."""
@@ -72,24 +59,24 @@ class FindBlockTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Search for blocks matching the query. """Search for blocks matching the query.
Args: Args:
user_id: User ID (required) user_id: User ID (required)
session: Chat session session: Chat session
**kwargs: Tool parameters query: Search query
Returns: Returns:
BlockListResponse: List of matching blocks BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found NoResultsResponse: No blocks found
ErrorResponse: Error message ErrorResponse: Error message
""" """
params = FindBlockInput(**kwargs) query = kwargs.get("query", "").strip()
session_id = session.session_id session_id = session.session_id
if not params.query: if not query:
return ErrorResponse( return ErrorResponse(
message="Please provide a search query", message="Please provide a search query",
session_id=session_id, session_id=session_id,
@@ -98,7 +85,7 @@ class FindBlockTool(BaseTool):
try: try:
# Search for blocks using hybrid search # Search for blocks using hybrid search
results, total = await unified_hybrid_search( results, total = await unified_hybrid_search(
query=params.query, query=query,
content_types=[ContentType.BLOCK], content_types=[ContentType.BLOCK],
page=1, page=1,
page_size=10, page_size=10,
@@ -106,7 +93,7 @@ class FindBlockTool(BaseTool):
if not results: if not results:
return NoResultsResponse( return NoResultsResponse(
message=f"No blocks found for '{params.query}'", message=f"No blocks found for '{query}'",
suggestions=[ suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'", "Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms", "Check spelling of technical terms",
@@ -178,7 +165,7 @@ class FindBlockTool(BaseTool):
if not blocks: if not blocks:
return NoResultsResponse( return NoResultsResponse(
message=f"No blocks found for '{params.query}'", message=f"No blocks found for '{query}'",
suggestions=[ suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'", "Try broader keywords like 'email', 'http', 'text', 'ai'",
], ],
@@ -187,13 +174,13 @@ class FindBlockTool(BaseTool):
return BlockListResponse( return BlockListResponse(
message=( message=(
f"Found {len(blocks)} block(s) matching '{params.query}'. " f"Found {len(blocks)} block(s) matching '{query}'. "
"To execute a block, use run_block with the block's 'id' field " "To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema." "and provide 'input_data' matching the block's input_schema."
), ),
blocks=blocks, blocks=blocks,
count=len(blocks), count=len(blocks),
query=params.query, query=query,
session_id=session_id, session_id=session_id,
) )

View File

@@ -2,8 +2,6 @@
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
@@ -11,15 +9,6 @@ from .base import BaseTool
from .models import ToolResponseBase from .models import ToolResponseBase
class FindLibraryAgentInput(BaseModel):
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindLibraryAgentTool(BaseTool): class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library.""" """Tool for searching agents in the user's library."""
@@ -53,11 +42,10 @@ class FindLibraryAgentTool(BaseTool):
return True return True
async def _execute( async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs: Any self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase: ) -> ToolResponseBase:
params = FindLibraryAgentInput(**kwargs)
return await search_agents( return await search_agents(
query=params.query, query=kwargs.get("query", "").strip(),
source="library", source="library",
session_id=session.session_id, session_id=session.session_id,
user_id=user_id, user_id=user_id,

View File

@@ -4,8 +4,6 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import ( from backend.api.features.chat.tools.models import (
@@ -20,18 +18,6 @@ logger = logging.getLogger(__name__)
DOCS_BASE_URL = "https://docs.agpt.co" DOCS_BASE_URL = "https://docs.agpt.co"
class GetDocPageInput(BaseModel):
"""Input parameters for the get_doc_page tool."""
path: str = ""
@field_validator("path", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from path."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class GetDocPageTool(BaseTool): class GetDocPageTool(BaseTool):
"""Tool for fetching full content of a documentation page.""" """Tool for fetching full content of a documentation page."""
@@ -89,23 +75,23 @@ class GetDocPageTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Fetch full content of a documentation page. """Fetch full content of a documentation page.
Args: Args:
user_id: User ID (not required for docs) user_id: User ID (not required for docs)
session: Chat session session: Chat session
**kwargs: Tool parameters path: Path to the documentation file
Returns: Returns:
DocPageResponse: Full document content DocPageResponse: Full document content
ErrorResponse: Error message ErrorResponse: Error message
""" """
params = GetDocPageInput(**kwargs) path = kwargs.get("path", "").strip()
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not params.path: if not path:
return ErrorResponse( return ErrorResponse(
message="Please provide a documentation path.", message="Please provide a documentation path.",
error="Missing path parameter", error="Missing path parameter",
@@ -113,7 +99,7 @@ class GetDocPageTool(BaseTool):
) )
# Sanitize path to prevent directory traversal # Sanitize path to prevent directory traversal
if ".." in params.path or params.path.startswith("/"): if ".." in path or path.startswith("/"):
return ErrorResponse( return ErrorResponse(
message="Invalid documentation path.", message="Invalid documentation path.",
error="invalid_path", error="invalid_path",
@@ -121,11 +107,11 @@ class GetDocPageTool(BaseTool):
) )
docs_root = self._get_docs_root() docs_root = self._get_docs_root()
full_path = docs_root / params.path full_path = docs_root / path
if not full_path.exists(): if not full_path.exists():
return ErrorResponse( return ErrorResponse(
message=f"Documentation page not found: {params.path}", message=f"Documentation page not found: {path}",
error="not_found", error="not_found",
session_id=session_id, session_id=session_id,
) )
@@ -142,19 +128,19 @@ class GetDocPageTool(BaseTool):
try: try:
content = full_path.read_text(encoding="utf-8") content = full_path.read_text(encoding="utf-8")
title = self._extract_title(content, params.path) title = self._extract_title(content, path)
return DocPageResponse( return DocPageResponse(
message=f"Retrieved documentation page: {title}", message=f"Retrieved documentation page: {title}",
title=title, title=title,
path=params.path, path=path,
content=content, content=content,
doc_url=self._make_doc_url(params.path), doc_url=self._make_doc_url(path),
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to read documentation page {params.path}: {e}") logger.error(f"Failed to read documentation page {path}: {e}")
return ErrorResponse( return ErrorResponse(
message=f"Failed to read documentation page: {str(e)}", message=f"Failed to read documentation page: {str(e)}",
error="read_failed", error="read_failed",

View File

@@ -5,7 +5,6 @@ import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
@@ -30,25 +29,6 @@ from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RunBlockInput(BaseModel):
"""Input parameters for the run_block tool."""
block_id: str = ""
input_data: dict[str, Any] = {}
@field_validator("block_id", mode="before")
@classmethod
def strip_block_id(cls, v: Any) -> str:
"""Strip whitespace from block_id."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("input_data", mode="before")
@classmethod
def ensure_dict(cls, v: Any) -> dict[str, Any]:
"""Ensure input_data is a dict."""
return v if isinstance(v, dict) else {}
class RunBlockTool(BaseTool): class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs.""" """Tool for executing a block and returning its outputs."""
@@ -182,29 +162,37 @@ class RunBlockTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute a block with the given input data. """Execute a block with the given input data.
Args: Args:
user_id: User ID (required) user_id: User ID (required)
session: Chat session session: Chat session
**kwargs: Tool parameters block_id: Block UUID to execute
input_data: Input values for the block
Returns: Returns:
BlockOutputResponse: Block execution outputs BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message ErrorResponse: Error message
""" """
params = RunBlockInput(**kwargs) block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
session_id = session.session_id session_id = session.session_id
if not params.block_id: if not block_id:
return ErrorResponse( return ErrorResponse(
message="Please provide a block_id", message="Please provide a block_id",
session_id=session_id, session_id=session_id,
) )
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id: if not user_id:
return ErrorResponse( return ErrorResponse(
message="Authentication required", message="Authentication required",
@@ -212,25 +200,23 @@ class RunBlockTool(BaseTool):
) )
# Get the block # Get the block
block = get_block(params.block_id) block = get_block(block_id)
if not block: if not block:
return ErrorResponse( return ErrorResponse(
message=f"Block '{params.block_id}' not found", message=f"Block '{block_id}' not found",
session_id=session_id, session_id=session_id,
) )
if block.disabled: if block.disabled:
return ErrorResponse( return ErrorResponse(
message=f"Block '{params.block_id}' is disabled", message=f"Block '{block_id}' is disabled",
session_id=session_id, session_id=session_id,
) )
logger.info( logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
f"Executing block {block.name} ({params.block_id}) for user {user_id}"
)
creds_manager = IntegrationCredentialsManager() creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials( matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, params.input_data user_id, block, input_data
) )
if missing_credentials: if missing_credentials:
@@ -248,7 +234,7 @@ class RunBlockTool(BaseTool):
), ),
session_id=session_id, session_id=session_id,
setup_info=SetupInfo( setup_info=SetupInfo(
agent_id=params.block_id, agent_id=block_id,
agent_name=block.name, agent_name=block.name,
user_readiness=UserReadiness( user_readiness=UserReadiness(
has_all_credentials=False, has_all_credentials=False,
@@ -277,7 +263,7 @@ class RunBlockTool(BaseTool):
# - node_exec_id = unique per block execution # - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}" synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}" synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{params.block_id}" synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = ( synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}" f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
) )
@@ -312,8 +298,8 @@ class RunBlockTool(BaseTool):
for field_name, cred_meta in matched_credentials.items(): for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation) # Inject metadata into input_data (for validation)
if field_name not in params.input_data: if field_name not in input_data:
params.input_data[field_name] = cred_meta.model_dump() input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution) # Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get( actual_credentials = await creds_manager.get(
@@ -330,14 +316,14 @@ class RunBlockTool(BaseTool):
# Execute the block and collect outputs # Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list) outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute( async for output_name, output_data in block.execute(
params.input_data, input_data,
**exec_kwargs, **exec_kwargs,
): ):
outputs[output_name].append(output_data) outputs[output_name].append(output_data)
return BlockOutputResponse( return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully", message=f"Block '{block.name}' executed successfully",
block_id=params.block_id, block_id=block_id,
block_name=block.name, block_name=block.name,
outputs=dict(outputs), outputs=dict(outputs),
success=True, success=True,

View File

@@ -4,7 +4,6 @@ import logging
from typing import Any from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.base import BaseTool
@@ -29,18 +28,6 @@ MAX_RESULTS = 5
SNIPPET_LENGTH = 200 SNIPPET_LENGTH = 200
class SearchDocsInput(BaseModel):
"""Input parameters for the search_docs tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class SearchDocsTool(BaseTool): class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation.""" """Tool for searching AutoGPT platform documentation."""
@@ -104,24 +91,24 @@ class SearchDocsTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Search documentation and return relevant sections. """Search documentation and return relevant sections.
Args: Args:
user_id: User ID (not required for docs) user_id: User ID (not required for docs)
session: Chat session session: Chat session
**kwargs: Tool parameters query: Search query
Returns: Returns:
DocSearchResultsResponse: List of matching documentation sections DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found NoResultsResponse: No results found
ErrorResponse: Error message ErrorResponse: Error message
""" """
params = SearchDocsInput(**kwargs) query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not params.query: if not query:
return ErrorResponse( return ErrorResponse(
message="Please provide a search query.", message="Please provide a search query.",
error="Missing query parameter", error="Missing query parameter",
@@ -131,7 +118,7 @@ class SearchDocsTool(BaseTool):
try: try:
# Search using hybrid search for DOCUMENTATION content type only # Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search( results, total = await unified_hybrid_search(
query=params.query, query=query,
content_types=[ContentType.DOCUMENTATION], content_types=[ContentType.DOCUMENTATION],
page=1, page=1,
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
@@ -140,7 +127,7 @@ class SearchDocsTool(BaseTool):
if not results: if not results:
return NoResultsResponse( return NoResultsResponse(
message=f"No documentation found for '{params.query}'.", message=f"No documentation found for '{query}'.",
suggestions=[ suggestions=[
"Try different keywords", "Try different keywords",
"Use more general terms", "Use more general terms",
@@ -175,7 +162,7 @@ class SearchDocsTool(BaseTool):
if not deduplicated: if not deduplicated:
return NoResultsResponse( return NoResultsResponse(
message=f"No documentation found for '{params.query}'.", message=f"No documentation found for '{query}'.",
suggestions=[ suggestions=[
"Try different keywords", "Try different keywords",
"Use more general terms", "Use more general terms",
@@ -208,7 +195,7 @@ class SearchDocsTool(BaseTool):
message=f"Found {len(doc_results)} relevant documentation sections.", message=f"Found {len(doc_results)} relevant documentation sections.",
results=doc_results, results=doc_results,
count=len(doc_results), count=len(doc_results),
query=params.query, query=query,
session_id=session_id, session_id=session_id,
) )

View File

@@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data.graph import GraphModel from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput from backend.data.model import (
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
@@ -273,7 +278,14 @@ async def match_user_credentials_to_graph(
for cred in available_creds for cred in available_creds
if cred.provider in credential_requirements.provider if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements) and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
), ),
None, None,
) )
@@ -318,19 +330,10 @@ async def match_user_credentials_to_graph(
def _credential_has_required_scopes( def _credential_has_required_scopes(
credential: Credentials, credential: OAuth2Credentials,
requirements: CredentialsFieldInfo, requirements: CredentialsFieldInfo,
) -> bool: ) -> bool:
""" """Check if an OAuth2 credential has all the scopes required by the input."""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches # If no scopes are required, any credential matches
if not requirements.required_scopes: if not requirements.required_scopes:
return True return True
@@ -339,6 +342,22 @@ def _credential_has_required_scopes(
return set(credential.scopes).issuperset(requirements.required_scopes) return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
async def check_user_has_required_credentials( async def check_user_has_required_credentials(
user_id: str, user_id: str,
required_credentials: list[CredentialsMetaInput], required_credentials: list[CredentialsMetaInput],

View File

@@ -2,9 +2,9 @@
import base64 import base64
import logging import logging
from typing import Any from typing import Any, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace
@@ -78,65 +78,6 @@ class WorkspaceDeleteResponse(ToolResponseBase):
success: bool success: bool
# Input models for workspace tools
class ListWorkspaceFilesInput(BaseModel):
"""Input parameters for list_workspace_files tool."""
path_prefix: str | None = None
limit: int = 50
include_all_sessions: bool = False
@field_validator("path_prefix", mode="before")
@classmethod
def strip_path(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ReadWorkspaceFileInput(BaseModel):
"""Input parameters for read_workspace_file tool."""
file_id: str | None = None
path: str | None = None
force_download_url: bool = False
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class WriteWorkspaceFileInput(BaseModel):
"""Input parameters for write_workspace_file tool."""
filename: str = ""
content_base64: str = ""
path: str | None = None
mime_type: str | None = None
overwrite: bool = False
@field_validator("filename", "content_base64", mode="before")
@classmethod
def strip_required(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("path", "mime_type", mode="before")
@classmethod
def strip_optional(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class DeleteWorkspaceFileInput(BaseModel):
"""Input parameters for delete_workspace_file tool."""
file_id: str | None = None
path: str | None = None
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ListWorkspaceFilesTool(BaseTool): class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace.""" """Tool for listing files in user's workspace."""
@@ -190,9 +131,8 @@ class ListWorkspaceFilesTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = ListWorkspaceFilesInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -201,7 +141,9 @@ class ListWorkspaceFilesTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
limit = min(params.limit, 100) path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try: try:
workspace = await get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
@@ -209,13 +151,13 @@ class ListWorkspaceFilesTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files( files = await manager.list_files(
path=params.path_prefix, path=path_prefix,
limit=limit, limit=limit,
include_all_sessions=params.include_all_sessions, include_all_sessions=include_all_sessions,
) )
total = await manager.get_file_count( total = await manager.get_file_count(
path=params.path_prefix, path=path_prefix,
include_all_sessions=params.include_all_sessions, include_all_sessions=include_all_sessions,
) )
file_infos = [ file_infos = [
@@ -229,9 +171,7 @@ class ListWorkspaceFilesTool(BaseTool):
for f in files for f in files
] ]
scope_msg = ( scope_msg = "all sessions" if include_all_sessions else "current session"
"all sessions" if params.include_all_sessions else "current session"
)
return WorkspaceFileListResponse( return WorkspaceFileListResponse(
files=file_infos, files=file_infos,
total_count=total, total_count=total,
@@ -319,9 +259,8 @@ class ReadWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = ReadWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -330,7 +269,11 @@ class ReadWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
if not params.file_id and not params.path: file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse( return ErrorResponse(
message="Please provide either file_id or path", message="Please provide either file_id or path",
session_id=session_id, session_id=session_id,
@@ -342,21 +285,21 @@ class ReadWorkspaceFileTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info # Get file info
if params.file_id: if file_id:
file_info = await manager.get_file_info(params.file_id) file_info = await manager.get_file_info(file_id)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found: {params.file_id}", message=f"File not found: {file_id}",
session_id=session_id, session_id=session_id,
) )
target_file_id = params.file_id target_file_id = file_id
else: else:
# path is guaranteed to be non-None here due to the check above # path is guaranteed to be non-None here due to the check above
assert params.path is not None assert path is not None
file_info = await manager.get_file_info_by_path(params.path) file_info = await manager.get_file_info_by_path(path)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found at path: {params.path}", message=f"File not found at path: {path}",
session_id=session_id, session_id=session_id,
) )
target_file_id = file_info.id target_file_id = file_info.id
@@ -366,7 +309,7 @@ class ReadWorkspaceFileTool(BaseTool):
is_text_file = self._is_text_mime_type(file_info.mimeType) is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url) # Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not params.force_download_url: if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id) content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8") content_b64 = base64.b64encode(content).decode("utf-8")
@@ -486,9 +429,8 @@ class WriteWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = WriteWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -497,13 +439,19 @@ class WriteWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
if not params.filename: filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse( return ErrorResponse(
message="Please provide a filename", message="Please provide a filename",
session_id=session_id, session_id=session_id,
) )
if not params.content_base64: if not content_b64:
return ErrorResponse( return ErrorResponse(
message="Please provide content_base64", message="Please provide content_base64",
session_id=session_id, session_id=session_id,
@@ -511,7 +459,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Decode content # Decode content
try: try:
content = base64.b64decode(params.content_base64) content = base64.b64decode(content_b64)
except Exception: except Exception:
return ErrorResponse( return ErrorResponse(
message="Invalid base64-encoded content", message="Invalid base64-encoded content",
@@ -528,7 +476,7 @@ class WriteWorkspaceFileTool(BaseTool):
try: try:
# Virus scan # Virus scan
await scan_content_safe(content, filename=params.filename) await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
@@ -536,10 +484,10 @@ class WriteWorkspaceFileTool(BaseTool):
file_record = await manager.write_file( file_record = await manager.write_file(
content=content, content=content,
filename=params.filename, filename=filename,
path=params.path, path=path,
mime_type=params.mime_type, mime_type=mime_type,
overwrite=params.overwrite, overwrite=overwrite,
) )
return WorkspaceWriteResponse( return WorkspaceWriteResponse(
@@ -609,9 +557,8 @@ class DeleteWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs: Any, **kwargs,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = DeleteWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -620,7 +567,10 @@ class DeleteWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
if not params.file_id and not params.path: file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse( return ErrorResponse(
message="Please provide either file_id or path", message="Please provide either file_id or path",
session_id=session_id, session_id=session_id,
@@ -633,15 +583,15 @@ class DeleteWorkspaceFileTool(BaseTool):
# Determine the file_id to delete # Determine the file_id to delete
target_file_id: str target_file_id: str
if params.file_id: if file_id:
target_file_id = params.file_id target_file_id = file_id
else: else:
# path is guaranteed to be non-None here due to the check above # path is guaranteed to be non-None here due to the check above
assert params.path is not None assert path is not None
file_info = await manager.get_file_info_by_path(params.path) file_info = await manager.get_file_info_by_path(path)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found at path: {params.path}", message=f"File not found at path: {path}",
session_id=session_id, session_id=session_id,
) )
target_file_id = file_info.id target_file_id = file_info.id

View File

@@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
on_graph_deactivate,
)
from backend.util.clients import get_scheduler_client from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -537,6 +540,92 @@ async def update_agent_version_in_library(
return library_model.LibraryAgent.from_db(lib) return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agents = await create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
return created_graph, library_agents[0]
async def update_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new version of an existing graph and update the library entry."""
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
current_active_version = (
next((v for v in existing_versions if v.is_active), None)
if existing_versions
else None
)
graph.version = (
max(v.version for v in existing_versions) + 1 if existing_versions else 1
)
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
library_agent = await update_library_agent_version_and_settings(
user_id, created_graph
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
await graph_db.set_graph_active_version(
graph_id=created_graph.id,
version=created_graph.version,
user_id=user_id,
)
if current_active_version:
await on_graph_deactivate(current_active_version, user_id=user_id)
return created_graph, library_agent
async def update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
"""Update library agent to point to new graph version and sync settings."""
library = await update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
async def update_library_agent( async def update_library_agent(
library_agent_id: str, library_agent_id: str,
user_id: str, user_id: str,

View File

@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe from backend.util.virus_scanner import scan_content_safe
from .library import db as library_db from .library import db as library_db
from .library import model as library_model
from .store.model import StoreAgentDetails from .store.model import StoreAgentDetails
@@ -823,18 +822,16 @@ async def update_graph(
graph: graph_db.Graph, graph: graph_db.Graph,
user_id: Annotated[str, Security(get_user_id)], user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel: ) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id: if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI") raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id) existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not existing_versions: if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found") raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
graph.version = max(g.version for g in existing_versions) + 1
current_active_version = next((v for v in existing_versions if v.is_active), None) current_active_version = next((v for v in existing_versions if v.is_active), None)
graph = graph_db.make_graph_model(graph, user_id) graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False) graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False) graph.validate_graph(for_run=False)
@@ -842,27 +839,23 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id) new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active: if new_graph_version.is_active:
# Keep the library agent up to date with the new active version await library_db.update_library_agent_version_and_settings(
await _update_library_agent_version_and_settings(user_id, new_graph_version) user_id, new_graph_version
)
# Handle activation of the new graph first to ensure continuity
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id) new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version( await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id graph_id=graph_id, version=new_graph_version.version, user_id=user_id
) )
if current_active_version: if current_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_version, user_id=user_id) await on_graph_deactivate(current_active_version, user_id=user_id)
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
new_graph_version_with_subgraphs = await graph_db.get_graph( new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id, graph_id,
new_graph_version.version, new_graph_version.version,
user_id=user_id, user_id=user_id,
include_subgraphs=True, include_subgraphs=True,
) )
assert new_graph_version_with_subgraphs # make type checker happy assert new_graph_version_with_subgraphs
return new_graph_version_with_subgraphs return new_graph_version_with_subgraphs
@@ -900,33 +893,15 @@ async def set_graph_active_version(
) )
# Keep the library agent up to date with the new active version # Keep the library agent up to date with the new active version
await _update_library_agent_version_and_settings(user_id, new_active_graph) await library_db.update_library_agent_version_and_settings(
user_id, new_active_graph
)
if current_active_graph and current_active_graph.version != new_active_version: if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version # Handle deactivation of the previously active version
await on_graph_deactivate(current_active_graph, user_id=user_id) await on_graph_deactivate(current_active_graph, user_id=user_id)
async def _update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
library = await library_db.update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await library_db.update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
@v1_router.patch( @v1_router.patch(
path="/graphs/{graph_id}/settings", path="/graphs/{graph_id}/settings",
summary="Update graph settings", summary="Update graph settings",

View File

@@ -162,8 +162,16 @@ class LinearClient:
"searchTerm": team_name, "searchTerm": team_name,
} }
team_id = await self.query(query, variables) result = await self.query(query, variables)
return team_id["teams"]["nodes"][0]["id"] nodes = result["teams"]["nodes"]
if not nodes:
raise LinearAPIException(
f"Team '{team_name}' not found. Check the team name or key and try again.",
status_code=404,
)
return nodes[0]["id"]
except LinearAPIException as e: except LinearAPIException as e:
raise e raise e
@@ -240,17 +248,44 @@ class LinearClient:
except LinearAPIException as e: except LinearAPIException as e:
raise e raise e
async def try_search_issues(self, term: str) -> list[Issue]: async def try_search_issues(
self,
term: str,
max_results: int = 10,
team_id: str | None = None,
) -> list[Issue]:
try: try:
query = """ query = """
query SearchIssues($term: String!, $includeComments: Boolean!) { query SearchIssues(
searchIssues(term: $term, includeComments: $includeComments) { $term: String!,
$first: Int,
$teamId: String
) {
searchIssues(
term: $term,
first: $first,
teamId: $teamId
) {
nodes { nodes {
id id
identifier identifier
title title
description description
priority priority
createdAt
state {
id
name
type
}
project {
id
name
}
assignee {
id
name
}
} }
} }
} }
@@ -258,7 +293,8 @@ class LinearClient:
variables: dict[str, Any] = { variables: dict[str, Any] = {
"term": term, "term": term,
"includeComments": True, "first": max_results,
"teamId": team_id,
} }
issues = await self.query(query, variables) issues = await self.query(query, variables)

View File

@@ -17,7 +17,7 @@ from ._config import (
LinearScope, LinearScope,
linear, linear,
) )
from .models import CreateIssueResponse, Issue from .models import CreateIssueResponse, Issue, State
class LinearCreateIssueBlock(Block): class LinearCreateIssueBlock(Block):
@@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions", description="Linear credentials with read permissions",
required_scopes={LinearScope.READ}, required_scopes={LinearScope.READ},
) )
max_results: int = SchemaField(
description="Maximum number of results to return",
default=10,
ge=1,
le=100,
)
team_name: str | None = SchemaField(
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
default=None,
)
class Output(BlockSchemaOutput): class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues") issues: list[Issue] = SchemaField(description="List of issues")
error: str = SchemaField(description="Error message if the search failed")
def __init__(self): def __init__(self):
super().__init__( super().__init__(
@@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear", description="Searches for issues on Linear",
input_schema=self.Input, input_schema=self.Input,
output_schema=self.Output, output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={ test_input={
"term": "Test issue", "term": "Test issue",
"max_results": 10,
"team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH, "credentials": TEST_CREDENTIALS_INPUT_OAUTH,
}, },
test_credentials=TEST_CREDENTIALS_OAUTH, test_credentials=TEST_CREDENTIALS_OAUTH,
@@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block):
[ [
Issue( Issue(
id="abc123", id="abc123",
identifier="abc123", identifier="TST-123",
title="Test issue", title="Test issue",
description="Test description", description="Test description",
priority=1, priority=1,
state=State(
id="state1", name="In Progress", type="started"
),
createdAt="2026-01-15T10:00:00.000Z",
) )
], ],
) )
@@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [ "search_issues": lambda *args, **kwargs: [
Issue( Issue(
id="abc123", id="abc123",
identifier="abc123", identifier="TST-123",
title="Test issue", title="Test issue",
description="Test description", description="Test description",
priority=1, priority=1,
state=State(id="state1", name="In Progress", type="started"),
createdAt="2026-01-15T10:00:00.000Z",
) )
] ]
}, },
@@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block):
async def search_issues( async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials, credentials: OAuth2Credentials | APIKeyCredentials,
term: str, term: str,
max_results: int = 10,
team_name: str | None = None,
) -> list[Issue]: ) -> list[Issue]:
client = LinearClient(credentials=credentials) client = LinearClient(credentials=credentials)
response: list[Issue] = await client.try_search_issues(term=term)
return response # Resolve team name to ID if provided
# Raises LinearAPIException with descriptive message if team not found
team_id: str | None = None
if team_name:
team_id = await client.try_get_team_by_name(team_name=team_name)
return await client.try_search_issues(
term=term,
max_results=max_results,
team_id=team_id,
)
async def run( async def run(
self, self,
@@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search""" """Execute the issue search"""
try: try:
issues = await self.search_issues( issues = await self.search_issues(
credentials=credentials, term=input_data.term credentials=credentials,
term=input_data.term,
max_results=input_data.max_results,
team_name=input_data.team_name,
) )
yield "issues", issues yield "issues", issues
except LinearAPIException as e: except LinearAPIException as e:

View File

@@ -36,12 +36,21 @@ class Project(BaseModel):
content: str | None = None content: str | None = None
class State(BaseModel):
id: str
name: str
type: str | None = (
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
)
class Issue(BaseModel): class Issue(BaseModel):
id: str id: str
identifier: str identifier: str
title: str title: str
description: str | None description: str | None
priority: int priority: int
state: State | None = None
project: Project | None = None project: Project | None = None
createdAt: str | None = None createdAt: str | None = None
comments: list[Comment] | None = None comments: list[Comment] | None = None

View File

@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
credentials: WebshareProxyCredentials, credentials: WebshareProxyCredentials,
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
video_id = self.extract_video_id(input_data.youtube_url) try:
yield "video_id", video_id video_id = self.extract_video_id(input_data.youtube_url)
transcript = self.get_transcript(video_id, credentials)
transcript_text = self.format_transcript(transcript=transcript)
transcript = self.get_transcript(video_id, credentials) # Only yield after all operations succeed
transcript_text = self.format_transcript(transcript=transcript) yield "video_id", video_id
yield "transcript", transcript_text
yield "transcript", transcript_text except Exception as e:
yield "error", str(e)

View File

@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
month1 = datetime.now(timezone.utc).replace(month=1, day=1) month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1 user_credit.time_now = lambda: month1
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
# in a different month than month1 (January). This fixes a timing bug
# where if the test runs in early February, 35 days ago would be January,
# matching the mocked month1 and preventing the refill from triggering.
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID},
data={"updatedAt": dec_previous_year},
)
# First call in month 1 should trigger refill # First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID) balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits assert balance == REFILL_VALUE # Should get 1000 credits

View File

@@ -19,7 +19,6 @@ from typing import (
cast, cast,
get_args, get_args,
) )
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep from prisma.enums import CreditTransactionType, OnboardingStep
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads from backend.util.json import loads as json_loads
from backend.util.request import parse_url
from backend.util.settings import Secrets from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones) # Type alias for any provider name (including custom ones)
@@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials):
def matches_url(self, url: str) -> bool: def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL.""" """Check if this credential should be applied to the given URL."""
parsed_url = urlparse(url) request_host, request_port = _extract_host_from_url(url)
# Extract hostname without port cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
request_host = parsed_url.hostname
if not request_host: if not request_host:
return False return False
# Simple host matching - exact match or wildcard subdomain match # If a port is specified in credential host, the request host port must match
if self.host == request_host: if cred_scope_port is not None and request_port != cred_scope_port:
return False
# Non-standard ports are only allowed if explicitly specified in credential host
elif cred_scope_port is None and request_port not in (80, 443, None):
return False
# Simple host matching
if cred_scope_host == request_host:
return True return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com") # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if self.host.startswith("*."): if cred_scope_host.startswith("*."):
domain = self.host[2:] # Remove "*." domain = cred_scope_host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain return request_host.endswith(f".{domain}") or request_host == domain
return False return False
@@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
) )
def _extract_host_from_url(url: str) -> str: def _extract_host_from_url(url: str) -> tuple[str, int | None]:
"""Extract host from URL for grouping host-scoped credentials.""" """Extract host and port from URL for grouping host-scoped credentials."""
try: try:
parsed = urlparse(url) parsed = parse_url(url)
return parsed.hostname or url return parsed.hostname or url, parsed.port
except Exception: except Exception:
return "" return "", None
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
@@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
providers = frozenset( providers = frozenset(
[cast(CP, "http")] [cast(CP, "http")]
+ [ + [
cast(CP, _extract_host_from_url(str(value))) cast(CP, parse_url(str(value)).netloc)
for value in field.discriminator_values for value in field.discriminator_values
] ]
) )

View File

@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
headers={"Authorization": SecretStr("Bearer token")}, headers={"Authorization": SecretStr("Bearer token")},
) )
assert creds.matches_url("http://localhost:8080/api/v1") # Non-standard ports require explicit port in credential host
assert not creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint") assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple") assert creds.matches_url("http://localhost/simple")
def test_matches_url_with_explicit_port(self):
"""Test URL matching with explicit port in credential host."""
creds = HostScopedCredentials(
provider="custom",
host="localhost:8080",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert not creds.matches_url("http://localhost:3000/api/v1")
assert not creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self): def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers.""" """Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials( creds = HostScopedCredentials(
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
("*.example.com", "https://sub.api.example.com/test", True), ("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True), ("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False), ("*.example.com", "https://example.org/test", False),
("localhost", "http://localhost:3000/test", True), # Non-standard ports require explicit port in credential host
("localhost", "http://localhost:3000/test", False),
("localhost:3000", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False), ("localhost", "http://127.0.0.1:3000/test", False),
# IPv6 addresses (frontend stores with brackets via URL.hostname)
("[::1]", "http://[::1]/test", True),
("[::1]", "http://[::1]:80/test", True),
("[::1]", "https://[::1]:443/test", True),
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
("[::1]:8080", "http://[::1]:8080/test", True),
("[::1]:8080", "http://[::1]:9090/test", False),
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
], ],
) )
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool): def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):

View File

@@ -157,12 +157,7 @@ async def validate_url(
is_trusted: Boolean indicating if the hostname is in trusted_origins is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted ip_addresses: List of IP addresses for the host; empty if the host is trusted
""" """
# Canonicalize URL parsed = parse_url(url)
url = url.strip("/ ").replace("\\", "/")
parsed = urlparse(url)
if not parsed.scheme:
url = f"http://{url}"
parsed = urlparse(url)
# Check scheme # Check scheme
if parsed.scheme not in ALLOWED_SCHEMES: if parsed.scheme not in ALLOWED_SCHEMES:
@@ -220,6 +215,17 @@ async def validate_url(
) )
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
url = url.strip("/ ").replace("\\", "/")
# Ensure scheme is present for proper parsing
if not re.match(r"[a-z0-9+.\-]+://", url):
url = f"http://{url}"
return urlparse(url)
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL: def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
""" """
Pins a URL to a specific IP address to prevent DNS rebinding attacks. Pins a URL to a specific IP address to prevent DNS rebinding attacks.

View File

@@ -1,6 +1,17 @@
import { OAuthPopupResultMessage } from "./types"; import { OAuthPopupResultMessage } from "./types";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// This route is intended to be used as the callback for integration OAuth flows, // This route is intended to be used as the callback for integration OAuth flows,
// controlled by the CredentialsInput component. The CredentialsInput opens the login // controlled by the CredentialsInput component. The CredentialsInput opens the login
// page in a pop-up window, which then redirects to this route to close the loop. // page in a pop-up window, which then redirects to this route to close the loop.
@@ -23,12 +34,13 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message); console.debug("Sending message to opener:", message);
// Return a response with the message as JSON and a script to close the window // Return a response with the message as JSON and a script to close the window
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
return new NextResponse( return new NextResponse(
` `
<html> <html>
<body> <body>
<script> <script>
window.opener.postMessage(${JSON.stringify(message)}); window.opener.postMessage(${safeJsonStringify(message)});
window.close(); window.close();
</script> </script>
</body> </body>

View File

@@ -26,8 +26,20 @@ export function buildCopilotChatUrl(prompt: string): string {
export function getQuickActions(): string[] { export function getQuickActions(): string[] {
return [ return [
"Show me what I can automate", "I don't know where to start, just ask me stuff",
"Design a custom workflow", "I do the same thing every week and it's killing me",
"Help me with content creation", "Help me find where I'm wasting my time",
]; ];
} }
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

View File

@@ -6,7 +6,9 @@ import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat"; import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput"; import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useEffect, useState } from "react";
import { useCopilotStore } from "./copilot-page-store"; import { useCopilotStore } from "./copilot-page-store";
import { getInputPlaceholder } from "./helpers";
import { useCopilotPage } from "./useCopilotPage"; import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() { export default function CopilotPage() {
@@ -14,8 +16,25 @@ export default function CopilotPage() {
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen); const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt); const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt); const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
const handleResize = () => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
};
handleResize();
window.addEventListener("resize", handleResize);
return () => window.removeEventListener("resize", handleResize);
}, []);
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } = const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
state; state;
const { const {
handleQuickAction, handleQuickAction,
startChatWithPrompt, startChatWithPrompt,
@@ -73,7 +92,7 @@ export default function CopilotPage() {
} }
return ( return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10"> <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
<div className="w-full text-center"> <div className="w-full text-center">
{isLoading ? ( {isLoading ? (
<div className="mx-auto max-w-2xl"> <div className="mx-auto max-w-2xl">
@@ -90,25 +109,25 @@ export default function CopilotPage() {
</div> </div>
) : ( ) : (
<> <>
<div className="mx-auto max-w-2xl"> <div className="mx-auto max-w-3xl">
<Text <Text
variant="h3" variant="h3"
className="mb-3 !text-[1.375rem] text-zinc-700" className="mb-1 !text-[1.375rem] text-zinc-700"
> >
Hey, <span className="text-violet-600">{greetingName}</span> Hey, <span className="text-violet-600">{greetingName}</span>
</Text> </Text>
<Text variant="h3" className="mb-8 !font-normal"> <Text variant="h3" className="mb-8 !font-normal">
What do you want to automate? Tell me about your work I&apos;ll find what to automate.
</Text> </Text>
<div className="mb-6"> <div className="mb-6">
<ChatInput <ChatInput
onSend={startChatWithPrompt} onSend={startChatWithPrompt}
placeholder='You can search or just ask - e.g. "create a blog post outline"' placeholder={inputPlaceholder}
/> />
</div> </div>
</div> </div>
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden"> <div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => ( {quickActions.map((action) => (
<Button <Button
key={action} key={action}
@@ -116,7 +135,7 @@ export default function CopilotPage() {
variant="outline" variant="outline"
size="small" size="small"
onClick={() => handleQuickAction(action)} onClick={() => handleQuickAction(action)}
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600" className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
> >
{action} {action}
</Button> </Button>

View File

@@ -2,7 +2,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
import { Button } from "@/components/atoms/Button/Button"; import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text"; import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react"; import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
import { useEffect } from "react"; import { useEffect } from "react";
@@ -56,10 +55,6 @@ export function ChatContainer({
onStreamingChange?.(isStreaming); onStreamingChange?.(isStreaming);
}, [isStreaming, onStreamingChange]); }, [isStreaming, onStreamingChange]);
const breakpoint = useBreakpoint();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
return ( return (
<div <div
className={cn( className={cn(
@@ -127,11 +122,7 @@ export function ChatContainer({
disabled={isStreaming || !sessionId} disabled={isStreaming || !sessionId}
isStreaming={isStreaming} isStreaming={isStreaming}
onStop={stopStreaming} onStop={stopStreaming}
placeholder={ placeholder="What else can I help with?"
isMobile
? "You can search or just ask"
: 'You can search or just ask — e.g. "create a blog post outline"'
}
/> />
</div> </div>
</div> </div>

View File

@@ -74,19 +74,20 @@ export function ChatInput({
hasMultipleLines ? "rounded-xlarge" : "rounded-full", hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)} )}
> >
{!value && !isRecording && (
<div
className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
aria-hidden="true"
>
{isTranscribing ? "Transcribing..." : placeholder}
</div>
)}
<textarea <textarea
id={inputId} id={inputId}
aria-label="Chat message input" aria-label="Chat message input"
value={value} value={value}
onChange={handleChange} onChange={handleChange}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
placeholder={
isTranscribing
? "Transcribing..."
: isRecording
? ""
: placeholder
}
disabled={isInputDisabled} disabled={isInputDisabled}
rows={1} rows={1}
className={cn( className={cn(
@@ -122,13 +123,14 @@ export function ChatInput({
size="icon" size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"} aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording} onClick={toggleRecording}
disabled={disabled || isTranscribing} disabled={disabled || isTranscribing || isStreaming}
className={cn( className={cn(
isRecording isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600" ? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing : isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400" ? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700", : "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
isStreaming && "opacity-40",
)} )}
> >
{isTranscribing ? ( {isTranscribing ? (

View File

@@ -38,8 +38,8 @@ export function AudioWaveform({
// Create audio context and analyser // Create audio context and analyser
const audioContext = new AudioContext(); const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser(); const analyser = audioContext.createAnalyser();
analyser.fftSize = 512; analyser.fftSize = 256;
analyser.smoothingTimeConstant = 0.8; analyser.smoothingTimeConstant = 0.3;
// Connect the stream to the analyser // Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream); const source = audioContext.createMediaStreamSource(stream);
@@ -73,10 +73,11 @@ export function AudioWaveform({
maxAmplitude = Math.max(maxAmplitude, amplitude); maxAmplitude = Math.max(maxAmplitude, amplitude);
} }
// Map amplitude (0-128) to bar height // Normalize amplitude (0-128 range) to 0-1
const normalized = (maxAmplitude / 128) * 255; const normalized = maxAmplitude / 128;
const height = // Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight); const boosted = Math.min(1, Math.sqrt(normalized) * 4);
const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
newBars.push(height); newBars.push(height);
} }

View File

@@ -224,7 +224,7 @@ export function useVoiceRecording({
[value, isTranscribing, toggleRecording, baseHandleKeyDown], [value, isTranscribing, toggleRecording, baseHandleKeyDown],
); );
const showMicButton = isSupported && !isStreaming; const showMicButton = isSupported;
const isInputDisabled = disabled || isStreaming || isTranscribing; const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount // Cleanup on unmount

View File

@@ -346,6 +346,7 @@ export function ChatMessage({
toolId={message.toolId} toolId={message.toolId}
toolName={message.toolName} toolName={message.toolName}
result={message.result} result={message.result}
onSendMessage={onSendMessage}
/> />
</div> </div>
); );

View File

@@ -73,6 +73,7 @@ export function MessageList({
key={index} key={index}
message={message} message={message}
prevMessage={messages[index - 1]} prevMessage={messages[index - 1]}
onSendMessage={onSendMessage}
/> />
); );
} }

View File

@@ -5,11 +5,13 @@ import { shouldSkipAgentOutput } from "../../helpers";
export interface LastToolResponseProps { export interface LastToolResponseProps {
message: ChatMessageData; message: ChatMessageData;
prevMessage: ChatMessageData | undefined; prevMessage: ChatMessageData | undefined;
onSendMessage?: (content: string) => void;
} }
export function LastToolResponse({ export function LastToolResponse({
message, message,
prevMessage, prevMessage,
onSendMessage,
}: LastToolResponseProps) { }: LastToolResponseProps) {
if (message.type !== "tool_response") return null; if (message.type !== "tool_response") return null;
@@ -21,6 +23,7 @@ export function LastToolResponse({
toolId={message.toolId} toolId={message.toolId}
toolName={message.toolName} toolName={message.toolName}
result={message.result} result={message.result}
onSendMessage={onSendMessage}
/> />
</div> </div>
); );

View File

@@ -1,6 +1,8 @@
import { Progress } from "@/components/atoms/Progress/Progress";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import { AIChatBubble } from "../AIChatBubble/AIChatBubble"; import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
import { useAsymptoticProgress } from "../ToolCallMessage/useAsymptoticProgress";
export interface ThinkingMessageProps { export interface ThinkingMessageProps {
className?: string; className?: string;
@@ -11,18 +13,19 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
const [showCoffeeMessage, setShowCoffeeMessage] = useState(false); const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
const timerRef = useRef<NodeJS.Timeout | null>(null); const timerRef = useRef<NodeJS.Timeout | null>(null);
const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null); const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
const progress = useAsymptoticProgress(showCoffeeMessage);
useEffect(() => { useEffect(() => {
if (timerRef.current === null) { if (timerRef.current === null) {
timerRef.current = setTimeout(() => { timerRef.current = setTimeout(() => {
setShowSlowLoader(true); setShowSlowLoader(true);
}, 8000); }, 3000);
} }
if (coffeeTimerRef.current === null) { if (coffeeTimerRef.current === null) {
coffeeTimerRef.current = setTimeout(() => { coffeeTimerRef.current = setTimeout(() => {
setShowCoffeeMessage(true); setShowCoffeeMessage(true);
}, 10000); }, 8000);
} }
return () => { return () => {
@@ -49,9 +52,18 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
<AIChatBubble> <AIChatBubble>
<div className="transition-all duration-500 ease-in-out"> <div className="transition-all duration-500 ease-in-out">
{showCoffeeMessage ? ( {showCoffeeMessage ? (
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent"> <div className="flex flex-col items-center gap-3">
This could take a few minutes, grab a coffee <div className="flex w-full max-w-[280px] flex-col gap-1.5">
</span> <div className="flex items-center justify-between text-xs text-neutral-500">
<span>Working on it...</span>
<span>{Math.round(progress)}%</span>
</div>
<Progress value={progress} className="h-2 w-full" />
</div>
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
This could take a few minutes, grab a coffee
</span>
</div>
) : showSlowLoader ? ( ) : showSlowLoader ? (
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent"> <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
Taking a bit more time... Taking a bit more time...

View File

@@ -0,0 +1,50 @@
import { useEffect, useRef, useState } from "react";
/**
* Hook that returns a progress value that starts fast and slows down,
* asymptotically approaching but never reaching the max value.
*
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
* This creates the "game loading bar" effect where:
* - 50% is reached at halfLifeSeconds
* - 75% is reached at 2 * halfLifeSeconds
* - 87.5% is reached at 3 * halfLifeSeconds
* - and so on...
*
* @param isActive - Whether the progress should be animating
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
* @param maxProgress - Maximum progress value to approach (default: 100)
* @param intervalMs - Update interval in milliseconds (default: 100)
* @returns Current progress value (0-maxProgress)
*/
export function useAsymptoticProgress(
isActive: boolean,
halfLifeSeconds = 30,
maxProgress = 100,
intervalMs = 100,
) {
const [progress, setProgress] = useState(0);
const elapsedTimeRef = useRef(0);
useEffect(() => {
if (!isActive) {
setProgress(0);
elapsedTimeRef.current = 0;
return;
}
const interval = setInterval(() => {
elapsedTimeRef.current += intervalMs / 1000;
// Half-life approach: progress = max * (1 - 0.5^(time/halfLife))
// At t=halfLife: 50%, at t=2*halfLife: 75%, at t=3*halfLife: 87.5%, etc.
const newProgress =
maxProgress *
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
setProgress(newProgress);
}, intervalMs);
return () => clearInterval(interval);
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
return progress;
}

View File

@@ -0,0 +1,128 @@
"use client";
import { useGetV2GetLibraryAgent } from "@/app/api/__generated__/endpoints/library/library";
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import { RunAgentModal } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import {
CheckCircleIcon,
PencilLineIcon,
PlayIcon,
} from "@phosphor-icons/react";
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
interface Props {
agentName: string;
libraryAgentId: string;
onSendMessage?: (content: string) => void;
}
export function AgentCreatedPrompt({
agentName,
libraryAgentId,
onSendMessage,
}: Props) {
// Fetch library agent eagerly so modal is ready when user clicks
const { data: libraryAgentResponse, isLoading } = useGetV2GetLibraryAgent(
libraryAgentId,
{
query: {
enabled: !!libraryAgentId,
},
},
);
const libraryAgent =
libraryAgentResponse?.status === 200 ? libraryAgentResponse.data : null;
function handleRunWithPlaceholders() {
onSendMessage?.(
`Run the agent "${agentName}" with placeholder/example values so I can test it.`,
);
}
function handleRunCreated(execution: GraphExecutionMeta) {
onSendMessage?.(
`I've started the agent "${agentName}". The execution ID is ${execution.id}. Please monitor its progress and let me know when it completes.`,
);
}
function handleScheduleCreated(schedule: GraphExecutionJobInfo) {
const scheduleInfo = schedule.cron
? `with cron schedule "${schedule.cron}"`
: "to run on the specified schedule";
onSendMessage?.(
`I've scheduled the agent "${agentName}" ${scheduleInfo}. The schedule ID is ${schedule.id}.`,
);
}
return (
<AIChatBubble>
<div className="flex flex-col gap-4">
<div className="flex items-center gap-2">
<div className="flex h-8 w-8 items-center justify-center rounded-full bg-green-100">
<CheckCircleIcon
size={18}
weight="fill"
className="text-green-600"
/>
</div>
<div>
<Text variant="body-medium" className="text-neutral-900">
Agent Created Successfully
</Text>
<Text variant="small" className="text-neutral-500">
&quot;{agentName}&quot; is ready to test
</Text>
</div>
</div>
<div className="flex flex-col gap-2">
<Text variant="small-medium" className="text-neutral-700">
Ready to test?
</Text>
<div className="flex flex-wrap gap-2">
<Button
variant="outline"
size="small"
onClick={handleRunWithPlaceholders}
className="gap-2"
>
<PlayIcon size={16} />
Run with example values
</Button>
{libraryAgent ? (
<RunAgentModal
triggerSlot={
<Button variant="outline" size="small" className="gap-2">
<PencilLineIcon size={16} />
Run with my inputs
</Button>
}
agent={libraryAgent}
onRunCreated={handleRunCreated}
onScheduleCreated={handleScheduleCreated}
/>
) : (
<Button
variant="outline"
size="small"
loading={isLoading}
disabled
className="gap-2"
>
<PencilLineIcon size={16} />
Run with my inputs
</Button>
)}
</div>
<Text variant="small" className="text-neutral-500">
or just ask me
</Text>
</div>
</div>
</AIChatBubble>
);
}

View File

@@ -2,11 +2,13 @@ import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import type { ToolResult } from "@/types/chat"; import type { ToolResult } from "@/types/chat";
import { WarningCircleIcon } from "@phosphor-icons/react"; import { WarningCircleIcon } from "@phosphor-icons/react";
import { AgentCreatedPrompt } from "./AgentCreatedPrompt";
import { AIChatBubble } from "../AIChatBubble/AIChatBubble"; import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { import {
formatToolResponse, formatToolResponse,
getErrorMessage, getErrorMessage,
isAgentSavedResponse,
isErrorResponse, isErrorResponse,
} from "./helpers"; } from "./helpers";
@@ -16,6 +18,7 @@ export interface ToolResponseMessageProps {
result?: ToolResult; result?: ToolResult;
success?: boolean; success?: boolean;
className?: string; className?: string;
onSendMessage?: (content: string) => void;
} }
export function ToolResponseMessage({ export function ToolResponseMessage({
@@ -24,6 +27,7 @@ export function ToolResponseMessage({
result, result,
success: _success, success: _success,
className, className,
onSendMessage,
}: ToolResponseMessageProps) { }: ToolResponseMessageProps) {
if (isErrorResponse(result)) { if (isErrorResponse(result)) {
const errorMessage = getErrorMessage(result); const errorMessage = getErrorMessage(result);
@@ -43,6 +47,18 @@ export function ToolResponseMessage({
); );
} }
// Check for agent_saved response - show special prompt
const agentSavedData = isAgentSavedResponse(result);
if (agentSavedData.isSaved) {
return (
<AgentCreatedPrompt
agentName={agentSavedData.agentName}
libraryAgentId={agentSavedData.libraryAgentId}
onSendMessage={onSendMessage}
/>
);
}
const formattedText = formatToolResponse(result, toolName); const formattedText = formatToolResponse(result, toolName);
return ( return (

View File

@@ -6,6 +6,43 @@ function stripInternalReasoning(content: string): string {
.trim(); .trim();
} }
export interface AgentSavedData {
isSaved: boolean;
agentName: string;
agentId: string;
libraryAgentId: string;
libraryAgentLink: string;
}
export function isAgentSavedResponse(result: unknown): AgentSavedData {
if (typeof result !== "object" || result === null) {
return {
isSaved: false,
agentName: "",
agentId: "",
libraryAgentId: "",
libraryAgentLink: "",
};
}
const response = result as Record<string, unknown>;
if (response.type === "agent_saved") {
return {
isSaved: true,
agentName: (response.agent_name as string) || "Agent",
agentId: (response.agent_id as string) || "",
libraryAgentId: (response.library_agent_id as string) || "",
libraryAgentLink: (response.library_agent_link as string) || "",
};
}
return {
isSaved: false,
agentName: "",
agentId: "",
libraryAgentId: "",
libraryAgentLink: "",
};
}
export function isErrorResponse(result: unknown): boolean { export function isErrorResponse(result: unknown): boolean {
if (typeof result === "string") { if (typeof result === "string") {
const lower = result.toLowerCase(); const lower = result.toLowerCase();

View File

@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : ""; const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
const formSchema = z.object({ const formSchema = z.object({
host: z.string().min(1, "Host is required"), host: z
.string()
.min(1, "Host is required")
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
message: "Enter only the host (e.g. api.example.com), not a full URL",
})
.refine((val) => !val.includes("/"), {
message:
"Enter only the host (e.g. api.example.com), without a trailing path. " +
"You may specify a port (e.g. api.example.com:8080) if needed.",
}),
title: z.string().optional(), title: z.string().optional(),
headers: z.record(z.string()).optional(), headers: z.record(z.string()).optional(),
}); });

View File

@@ -15,7 +15,6 @@ import {
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider"; import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
import { WalletIcon } from "@phosphor-icons/react"; import { WalletIcon } from "@phosphor-icons/react";
import { PopoverClose } from "@radix-ui/react-popover"; import { PopoverClose } from "@radix-ui/react-popover";
import { X } from "lucide-react"; import { X } from "lucide-react";
@@ -175,7 +174,6 @@ export function Wallet() {
const [prevCredits, setPrevCredits] = useState<number | null>(credits); const [prevCredits, setPrevCredits] = useState<number | null>(credits);
const [flash, setFlash] = useState(false); const [flash, setFlash] = useState(false);
const [walletOpen, setWalletOpen] = useState(false); const [walletOpen, setWalletOpen] = useState(false);
const [lastSeenCredits, setLastSeenCredits] = useState<number | null>(null);
const totalCount = useMemo(() => { const totalCount = useMemo(() => {
return groups.reduce((acc, group) => acc + group.tasks.length, 0); return groups.reduce((acc, group) => acc + group.tasks.length, 0);
@@ -200,38 +198,6 @@ export function Wallet() {
setCompletedCount(completed); setCompletedCount(completed);
}, [groups, state?.completedSteps]); }, [groups, state?.completedSteps]);
// Load last seen credits from localStorage once on mount
useEffect(() => {
const stored = storage.get(StorageKey.WALLET_LAST_SEEN_CREDITS);
if (stored !== undefined && stored !== null) {
const parsed = parseFloat(stored);
if (!Number.isNaN(parsed)) setLastSeenCredits(parsed);
else setLastSeenCredits(0);
} else {
setLastSeenCredits(0);
}
}, []);
// Auto-open once if never shown, otherwise open only when credits increase beyond last seen
useEffect(() => {
if (typeof credits !== "number") return;
// Open once for first-time users
if (state && state.walletShown === false) {
requestAnimationFrame(() => setWalletOpen(true));
// Mark as shown so it won't reopen on every reload
updateState({ walletShown: true });
return;
}
// Open if user gained more credits than last acknowledged
if (
lastSeenCredits !== null &&
credits > lastSeenCredits &&
walletOpen === false
) {
requestAnimationFrame(() => setWalletOpen(true));
}
}, [credits, lastSeenCredits, state?.walletShown, updateState, walletOpen]);
const onWalletOpen = useCallback(async () => { const onWalletOpen = useCallback(async () => {
if (!state?.walletShown) { if (!state?.walletShown) {
updateState({ walletShown: true }); updateState({ walletShown: true });
@@ -324,19 +290,7 @@ export function Wallet() {
if (credits === null || !state) return null; if (credits === null || !state) return null;
return ( return (
<Popover <Popover open={walletOpen} onOpenChange={(open) => setWalletOpen(open)}>
open={walletOpen}
onOpenChange={(open) => {
setWalletOpen(open);
if (!open) {
// Persist the latest acknowledged credits so we only auto-open on future gains
if (typeof credits === "number") {
storage.set(StorageKey.WALLET_LAST_SEEN_CREDITS, String(credits));
setLastSeenCredits(credits);
}
}
}}
>
<PopoverTrigger asChild> <PopoverTrigger asChild>
<div className="relative inline-block"> <div className="relative inline-block">
<button <button

View File

@@ -62,7 +62,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store | | [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
| [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API | | [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data | | [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
| [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty | | [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
| [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library | | [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
| [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes | | [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
@@ -571,6 +570,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue | | [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue |
| [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear | | [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear |
| [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee | | [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee |
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
| [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear | | [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear |
## Hardware ## Hardware

View File

@@ -90,9 +90,9 @@ Searches for issues on Linear
### How it works ### How it works
<!-- MANUAL: how_it_works --> <!-- MANUAL: how_it_works -->
This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues. This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues. You can limit the number of results returned using the `max_results` parameter (default: 10, max: 100) to control token consumption and response size.
Returns a list of issues matching the search term. Optionally filter results by team name to narrow searches to specific workspaces. If a team name is provided, the block resolves it to a team ID before searching. Returns matching issues with their state, creation date, project, and assignee information. If the search or team resolution fails, an error message is returned.
<!-- END MANUAL --> <!-- END MANUAL -->
### Inputs ### Inputs
@@ -100,12 +100,14 @@ Returns a list of issues matching the search term.
| Input | Description | Type | Required | | Input | Description | Type | Required |
|-------|-------------|------|----------| |-------|-------------|------|----------|
| term | Term to search for issues | str | Yes | | term | Term to search for issues | str | Yes |
| max_results | Maximum number of results to return | int | No |
| team_name | Optional team name to filter results (e.g., 'Internal', 'Open Source') | str | No |
### Outputs ### Outputs
| Output | Description | Type | | Output | Description | Type |
|--------|-------------|------| |--------|-------------|------|
| error | Error message if the operation failed | str | | error | Error message if the search failed | str |
| issues | List of issues | List[Issue] | | issues | List of issues | List[Issue] |
### Possible use case ### Possible use case