mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
12 Commits
fix/cookie
...
zamilmajdy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23204b9314 | ||
|
|
1c6b829925 | ||
|
|
efa4b6d2a0 | ||
|
|
94aed94113 | ||
|
|
e701f41e66 | ||
|
|
a2d54c5fb4 | ||
|
|
568f5a449e | ||
|
|
3df6dcd26b | ||
|
|
aab40fe225 | ||
|
|
91ea322887 | ||
|
|
e183be08bd | ||
|
|
a541a3edd7 |
3
.github/workflows/platform-frontend-ci.yml
vendored
3
.github/workflows/platform-frontend-ci.yml
vendored
@@ -55,6 +55,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api-client
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -165,7 +165,7 @@ package-lock.json
|
||||
|
||||
# Allow for locally private items
|
||||
# private
|
||||
pri*
|
||||
pri*
|
||||
# ignore
|
||||
ig*
|
||||
.github_access_token
|
||||
@@ -177,3 +177,6 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
|
||||
# Auto generated client
|
||||
autogpt_platform/frontend/src/api/__generated__
|
||||
|
||||
@@ -32,6 +32,7 @@ poetry run test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
@@ -77,6 +78,7 @@ npm run type-check
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
@@ -129,4 +131,15 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
4. Test with Playwright if user-facing
|
||||
4. Test with Playwright if user-facing
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -62,6 +62,12 @@ To run the AutoGPT Platform, follow these steps:
|
||||
pnpm i
|
||||
```
|
||||
|
||||
Generate the API client (this step is required before running the frontend):
|
||||
|
||||
```
|
||||
pnpm generate:api-client
|
||||
```
|
||||
|
||||
Then start the frontend application in development mode:
|
||||
|
||||
```
|
||||
@@ -164,3 +170,27 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
|
||||
3. Save the file and run `docker compose up -d` to apply the changes.
|
||||
|
||||
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
|
||||
|
||||
### API Client Generation
|
||||
|
||||
The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api-all
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,7 +14,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,9 +30,9 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = SchemaField(default=None, hidden=True)
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
@@ -71,7 +73,7 @@ class AgentExecutorBlock(Block):
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
node_credentials_input_map=input_data.node_credentials_input_map,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
use_db_query=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ class ApolloClient:
|
||||
|
||||
async def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
|
||||
"""Search for people in Apollo"""
|
||||
response = await self.requests.get(
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
@@ -53,10 +53,10 @@ class ApolloClient:
|
||||
and len(parsed_response.people) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.get(
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
@@ -69,10 +69,10 @@ class ApolloClient:
|
||||
self, query: SearchOrganizationsRequest
|
||||
) -> List[Organization]:
|
||||
"""Search for organizations in Apollo"""
|
||||
response = await self.requests.get(
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
@@ -93,10 +93,10 @@ class ApolloClient:
|
||||
and len(parsed_response.organizations) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.get(
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
|
||||
@@ -1,17 +1,31 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel as OriginalBaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class BaseModel(OriginalBaseModel):
|
||||
def model_dump(self, *args, exclude: set[str] | None = None, **kwargs):
|
||||
if exclude is None:
|
||||
exclude = set("credentials")
|
||||
else:
|
||||
exclude.add("credentials")
|
||||
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
kwargs.setdefault("exclude_unset", True)
|
||||
kwargs.setdefault("exclude_defaults", True)
|
||||
return super().model_dump(*args, exclude=exclude, **kwargs)
|
||||
|
||||
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: str
|
||||
source: str
|
||||
sanitized_number: str
|
||||
number: str = ""
|
||||
source: str = ""
|
||||
sanitized_number: str = ""
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -42,88 +56,88 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: str
|
||||
created_at: str
|
||||
rule_action_config_id: str
|
||||
rule_config_id: str
|
||||
status_cd: str
|
||||
updated_at: str
|
||||
id: str
|
||||
key: str
|
||||
_id: str = ""
|
||||
created_at: str = ""
|
||||
rule_action_config_id: str = ""
|
||||
rule_config_id: str = ""
|
||||
status_cd: str = ""
|
||||
updated_at: str = ""
|
||||
id: str = ""
|
||||
key: str = ""
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: str
|
||||
emailer_campaign_id: str
|
||||
send_email_from_user_id: str
|
||||
inactive_reason: str
|
||||
status: str
|
||||
added_at: str
|
||||
added_by_user_id: str
|
||||
finished_at: str
|
||||
paused_at: str
|
||||
auto_unpause_at: str
|
||||
send_email_from_email_address: str
|
||||
send_email_from_email_account_id: str
|
||||
manually_set_unpause: str
|
||||
failure_reason: str
|
||||
current_step_id: str
|
||||
in_response_to_emailer_message_id: str
|
||||
cc_emails: str
|
||||
bcc_emails: str
|
||||
to_emails: str
|
||||
id: str = ""
|
||||
emailer_campaign_id: str = ""
|
||||
send_email_from_user_id: str = ""
|
||||
inactive_reason: str = ""
|
||||
status: str = ""
|
||||
added_at: str = ""
|
||||
added_by_user_id: str = ""
|
||||
finished_at: str = ""
|
||||
paused_at: str = ""
|
||||
auto_unpause_at: str = ""
|
||||
send_email_from_email_address: str = ""
|
||||
send_email_from_email_account_id: str = ""
|
||||
manually_set_unpause: str = ""
|
||||
failure_reason: str = ""
|
||||
current_step_id: str = ""
|
||||
in_response_to_emailer_message_id: str = ""
|
||||
cc_emails: str = ""
|
||||
bcc_emails: str = ""
|
||||
to_emails: str = ""
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
website_url: str
|
||||
blog_url: str
|
||||
angellist_url: str
|
||||
linkedin_url: str
|
||||
twitter_url: str
|
||||
facebook_url: str
|
||||
primary_phone: PrimaryPhone
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
website_url: str = ""
|
||||
blog_url: str = ""
|
||||
angellist_url: str = ""
|
||||
linkedin_url: str = ""
|
||||
twitter_url: str = ""
|
||||
facebook_url: str = ""
|
||||
primary_phone: PrimaryPhone = PrimaryPhone()
|
||||
languages: list[str]
|
||||
alexa_ranking: int
|
||||
phone: str
|
||||
linkedin_uid: str
|
||||
founded_year: int
|
||||
publicly_traded_symbol: str
|
||||
publicly_traded_exchange: str
|
||||
logo_url: str
|
||||
chrunchbase_url: str
|
||||
primary_domain: str
|
||||
domain: str
|
||||
team_id: str
|
||||
organization_id: str
|
||||
account_stage_id: str
|
||||
source: str
|
||||
original_source: str
|
||||
creator_id: str
|
||||
owner_id: str
|
||||
created_at: str
|
||||
phone_status: str
|
||||
hubspot_id: str
|
||||
salesforce_id: str
|
||||
crm_owner_id: str
|
||||
parent_account_id: str
|
||||
sanitized_phone: str
|
||||
alexa_ranking: int = 0
|
||||
phone: str = ""
|
||||
linkedin_uid: str = ""
|
||||
founded_year: int = 0
|
||||
publicly_traded_symbol: str = ""
|
||||
publicly_traded_exchange: str = ""
|
||||
logo_url: str = ""
|
||||
chrunchbase_url: str = ""
|
||||
primary_domain: str = ""
|
||||
domain: str = ""
|
||||
team_id: str = ""
|
||||
organization_id: str = ""
|
||||
account_stage_id: str = ""
|
||||
source: str = ""
|
||||
original_source: str = ""
|
||||
creator_id: str = ""
|
||||
owner_id: str = ""
|
||||
created_at: str = ""
|
||||
phone_status: str = ""
|
||||
hubspot_id: str = ""
|
||||
salesforce_id: str = ""
|
||||
crm_owner_id: str = ""
|
||||
parent_account_id: str = ""
|
||||
sanitized_phone: str = ""
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: list[Any]
|
||||
account_rule_config_statuses: list[RuleConfigStatus]
|
||||
existence_level: str
|
||||
label_ids: list[str]
|
||||
account_playbook_statues: list[Any] = []
|
||||
account_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
existence_level: str = ""
|
||||
label_ids: list[str] = []
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str
|
||||
source_display_name: str
|
||||
salesforce_record_id: str
|
||||
crm_record_url: str
|
||||
modality: str = ""
|
||||
source_display_name: str = ""
|
||||
salesforce_record_id: str = ""
|
||||
crm_record_url: str = ""
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
@@ -205,7 +219,7 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: str
|
||||
country_name: str = ""
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
|
||||
@@ -210,9 +210,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(
|
||||
**input_data.model_dump(exclude={"credentials"})
|
||||
)
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
|
||||
@@ -107,6 +107,7 @@ class SearchPeopleBlock(Block):
|
||||
default_factory=list,
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
title="Person",
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
@@ -387,7 +388,7 @@ class SearchPeopleBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
people = await self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
|
||||
@@ -14,6 +14,12 @@ class FileStoreBlock(Block):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
@@ -37,12 +43,11 @@ class FileStoreBlock(Block):
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
file_path = await store_media_file(
|
||||
yield "file_out", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
return_content=False,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
yield "file_out", file_path
|
||||
|
||||
|
||||
class StoreValueBlock(Block):
|
||||
@@ -456,6 +461,11 @@ class CreateListBlock(Block):
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
)
|
||||
max_size: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
@@ -492,8 +502,9 @@ class CreateListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "list", input_data.values
|
||||
max_size = input_data.max_size or len(input_data.values)
|
||||
for i in range(0, len(input_data.values), max_size):
|
||||
yield "list", input_data.values[i : i + max_size]
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create list: {str(e)}"
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ class IfInputMatchesBlock(Block):
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": None,
|
||||
"value": "None",
|
||||
"yes_value": "Yes",
|
||||
"no_value": "No",
|
||||
},
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -108,7 +108,7 @@ class AIImageEditorBlock(Block):
|
||||
output_schema=AIImageEditorBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Add a hat to the cat",
|
||||
"input_image": "https://example.com/cat.png",
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
@@ -128,13 +128,22 @@ class AIImageEditorBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
prompt=input_data.prompt,
|
||||
input_image=input_data.input_image,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
return_content=True,
|
||||
)
|
||||
if input_data.input_image
|
||||
else None
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
)
|
||||
@@ -145,14 +154,14 @@ class AIImageEditorBlock(Block):
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
input_image: Optional[MediaFileType],
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
@@ -413,6 +413,12 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
@@ -446,12 +452,11 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
file_path = await store_media_file(
|
||||
yield "result", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
return_content=False,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
yield "result", file_path
|
||||
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
|
||||
@@ -348,10 +348,10 @@ async def llm_call(
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or 4096
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -663,6 +663,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||
)
|
||||
list_result: bool = SchemaField(
|
||||
title="List Result",
|
||||
default=False,
|
||||
description="Whether the response should be a list of objects in the expected format.",
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
@@ -702,7 +707,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, Any] = SchemaField(
|
||||
response: dict[str, Any] | list[dict[str, Any]] = SchemaField(
|
||||
description="The response object generated by the language model."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
@@ -793,13 +798,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
]
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -807,17 +821,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.prompt:
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
|
||||
def validate_response(parsed: object) -> str | None:
|
||||
try:
|
||||
parsed = json.loads(resp)
|
||||
if not isinstance(parsed, dict):
|
||||
return {}, f"Expected a dictionary, but got {type(parsed)}"
|
||||
return f"Expected a dictionary, but got {type(parsed)}"
|
||||
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
|
||||
if miss_keys:
|
||||
return parsed, f"Missing keys: {miss_keys}"
|
||||
return parsed, None
|
||||
return f"Missing keys: {miss_keys}"
|
||||
return None
|
||||
except JSONDecodeError as e:
|
||||
return {}, f"JSON decode error: {e}"
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.info(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
@@ -843,18 +856,29 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {
|
||||
k: (
|
||||
json.loads(v)
|
||||
if isinstance(v, str)
|
||||
and v.startswith("[")
|
||||
and v.endswith("]")
|
||||
else (", ".join(v) if isinstance(v, list) else v)
|
||||
|
||||
response_obj = json.loads(response_text)
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
response_error = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
response_obj
|
||||
if isinstance(response_obj, list)
|
||||
else [response_obj]
|
||||
)
|
||||
for k, v in parsed_dict.items()
|
||||
}
|
||||
if (validation_error := validate_response(response_item))
|
||||
]
|
||||
)
|
||||
|
||||
if not response_error:
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
else:
|
||||
@@ -871,7 +895,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{parsed_error}
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -142,6 +142,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
title="Multiple Tool Calls",
|
||||
default=False,
|
||||
description="Whether to allow multiple tool calls in a single response.",
|
||||
advanced=True,
|
||||
)
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. "
|
||||
@@ -150,7 +156,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"matching the required jsonschema signature, no missing argument is allowed. "
|
||||
"If you have already completed the task objective, you can end the task "
|
||||
"by providing the end result of your work as a finish message. "
|
||||
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
|
||||
"Function parameters that has no default value and not optional typed has to be provided. ",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
@@ -273,29 +279,18 @@ class SmartDecisionMakerBlock(Block):
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
sink_block_input_schema = block.input_schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
**block.input_schema.jsonschema(),
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
@@ -335,25 +330,27 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
)
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
description = (
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
sink_block_properties["description"]
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
properties[sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
@@ -430,6 +427,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
@@ -469,6 +467,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
if input_data.multiple_tool_calls:
|
||||
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
|
||||
else:
|
||||
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
@@ -495,7 +497,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=False,
|
||||
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
@@ -506,8 +508,31 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = tool_args.keys()
|
||||
|
||||
# Yield provided arguments and None for missing ones
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
|
||||
@@ -118,7 +118,10 @@ class BlockSchema(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
return json.validate_with_jsonschema(
|
||||
schema=cls.jsonschema(),
|
||||
data={k: v for k, v in data.items() if v is not None},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
@@ -471,7 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
)
|
||||
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**input_data), **kwargs
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
):
|
||||
if output_name == "error":
|
||||
raise RuntimeError(output_data)
|
||||
@@ -481,6 +485,22 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
yield output_name, output_data
|
||||
|
||||
def is_triggered_by_event_type(
|
||||
self, trigger_config: dict[str, Any], event_type: str
|
||||
) -> bool:
|
||||
if not self.webhook_config:
|
||||
raise TypeError("This method can't be used on non-trigger blocks")
|
||||
if not self.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError("Event filter is not configured on trigger")
|
||||
return event_type in [
|
||||
self.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
|
||||
|
||||
# ======================= Block Helper Functions ======================= #
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ from typing import Type
|
||||
|
||||
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -24,6 +26,7 @@ from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
@@ -345,4 +348,28 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
)
|
||||
],
|
||||
SmartDecisionMakerBlock: LLM_COST,
|
||||
SearchOrganizationsBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
SearchPeopleBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ from prisma.types import (
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
@@ -54,7 +54,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -271,7 +271,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
node_credentials_input_map={}, # FIXME
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
)
|
||||
|
||||
|
||||
@@ -556,18 +556,18 @@ async def upsert_execution_input(
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
output_data: Any,
|
||||
output_data: Any | None,
|
||||
) -> None:
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
data=Json(output_data),
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = Json(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(
|
||||
@@ -783,7 +783,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
@@ -14,7 +14,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import create_model
|
||||
from pydantic import JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -32,7 +32,9 @@ from backend.util import type as type_utils
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,10 +83,12 @@ class NodeModel(Node):
|
||||
graph_version: int
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
webhook: Optional[Webhook] = None
|
||||
webhook: Optional["Webhook"] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
|
||||
from .integrations import Webhook
|
||||
|
||||
obj = NodeModel(
|
||||
id=node.id,
|
||||
block_id=node.agentBlockId,
|
||||
@@ -102,19 +106,7 @@ class NodeModel(Node):
|
||||
return obj
|
||||
|
||||
def is_triggered_by_event_type(self, event_type: str) -> bool:
|
||||
block = self.block
|
||||
if not block.webhook_config:
|
||||
raise TypeError("This method can't be used on non-webhook blocks")
|
||||
if not block.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError(f"Event filter is not configured on node #{self.id}")
|
||||
return event_type in [
|
||||
block.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
return self.block.is_triggered_by_event_type(self.input_default, event_type)
|
||||
|
||||
def stripped_for_export(self) -> "NodeModel":
|
||||
"""
|
||||
@@ -162,10 +154,6 @@ class NodeModel(Node):
|
||||
return result
|
||||
|
||||
|
||||
# Fix 2-way reference Node <-> Webhook
|
||||
Webhook.model_rebuild()
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
@@ -406,13 +394,21 @@ class GraphModel(Graph):
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
self._validate_graph(self, for_run)
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
self._validate_graph(self, for_run, nodes_input_masks)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._validate_graph(sub_graph, for_run)
|
||||
self._validate_graph(sub_graph, for_run, nodes_input_masks)
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph(graph: BaseGraph, for_run: bool = False):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
@@ -439,20 +435,18 @@ class GraphModel(Graph):
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
node_input_mask = (
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
)
|
||||
InputSchema = block.input_schema
|
||||
for name in (required_fields := InputSchema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
and not (
|
||||
name == "payload"
|
||||
and block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in InputSchema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
@@ -485,10 +479,18 @@ class GraphModel(Graph):
|
||||
|
||||
def has_value(node: Node, name: str):
|
||||
return (
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
(
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
)
|
||||
or (name in input_fields and input_fields[name].default is not None)
|
||||
or (
|
||||
name in node_input_mask
|
||||
and node_input_mask[name] is not None
|
||||
and str(node_input_mask[name]).strip() != ""
|
||||
)
|
||||
)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name in input_fields.keys():
|
||||
@@ -574,7 +576,7 @@ class GraphModel(Graph):
|
||||
graph: AgentGraph,
|
||||
for_export: bool = False,
|
||||
sub_graphs: list[AgentGraph] | None = None,
|
||||
):
|
||||
) -> "GraphModel":
|
||||
return GraphModel(
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
@@ -603,6 +605,7 @@ class GraphModel(Graph):
|
||||
|
||||
|
||||
async def get_node(node_id: str) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().find_unique_or_raise(
|
||||
where={"id": node_id},
|
||||
include=AGENT_NODE_INCLUDE,
|
||||
@@ -611,6 +614,7 @@ async def get_node(node_id: str) -> NodeModel:
|
||||
|
||||
|
||||
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().update(
|
||||
where={"id": node_id},
|
||||
data=(
|
||||
|
||||
@@ -60,7 +60,8 @@ def graph_execution_include(
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, Literal, Optional, overload
|
||||
|
||||
from prisma import Json
|
||||
from prisma.models import IntegrationWebhook
|
||||
from prisma.types import IntegrationWebhookCreateInput
|
||||
from prisma.types import (
|
||||
IntegrationWebhookCreateInput,
|
||||
IntegrationWebhookUpdateInput,
|
||||
IntegrationWebhookWhereInput,
|
||||
Serializable,
|
||||
)
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .db import BaseDbModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import NodeModel
|
||||
from .graph import NodeModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,8 +36,6 @@ class Webhook(BaseDbModel):
|
||||
|
||||
provider_webhook_id: str
|
||||
|
||||
attached_nodes: Optional[list["NodeModel"]] = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@@ -41,8 +43,6 @@ class Webhook(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
from .graph import NodeModel
|
||||
|
||||
return Webhook(
|
||||
id=webhook.id,
|
||||
user_id=webhook.userId,
|
||||
@@ -54,14 +54,33 @@ class Webhook(BaseDbModel):
|
||||
config=dict(webhook.config),
|
||||
secret=webhook.secret,
|
||||
provider_webhook_id=webhook.providerWebhookId,
|
||||
attached_nodes=(
|
||||
[NodeModel.from_db(node) for node in webhook.AgentNodes]
|
||||
if webhook.AgentNodes is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WebhookWithRelations(Webhook):
|
||||
triggered_nodes: list[NodeModel]
|
||||
triggered_presets: list[LibraryAgentPreset]
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
if webhook.AgentNodes is None or webhook.AgentPresets is None:
|
||||
raise ValueError(
|
||||
"AgentNodes and AgentPresets must be included in "
|
||||
"IntegrationWebhook query with relations"
|
||||
)
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
triggered_nodes=[NodeModel.from_db(node) for node in webhook.AgentNodes],
|
||||
triggered_presets=[
|
||||
LibraryAgentPreset.from_db(preset) for preset in webhook.AgentPresets
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# Fix Webhook <- NodeModel relations
|
||||
NodeModel.model_rebuild()
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
||||
@@ -83,7 +102,19 @@ async def create_webhook(webhook: Webhook) -> Webhook:
|
||||
return Webhook.from_db(created_webhook)
|
||||
|
||||
|
||||
async def get_webhook(webhook_id: str) -> Webhook:
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[True]
|
||||
) -> WebhookWithRelations: ...
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[False] = False
|
||||
) -> Webhook: ...
|
||||
|
||||
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: bool = False
|
||||
) -> Webhook | WebhookWithRelations:
|
||||
"""
|
||||
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
|
||||
|
||||
@@ -92,73 +123,113 @@ async def get_webhook(webhook_id: str) -> Webhook:
|
||||
"""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique(
|
||||
where={"id": webhook_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
)
|
||||
if not webhook:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(webhook)
|
||||
return (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[True]
|
||||
) -> list[WebhookWithRelations]: ...
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
|
||||
) -> list[Webhook]: ...
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: bool = False
|
||||
) -> list[Webhook] | list[WebhookWithRelations]:
|
||||
if not credentials_id:
|
||||
raise ValueError("credentials_id must not be empty")
|
||||
webhooks = await IntegrationWebhook.prisma().find_many(
|
||||
where={"credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
where={"userId": user_id, "credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
)
|
||||
return [Webhook.from_db(webhook) for webhook in webhooks]
|
||||
return [
|
||||
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
for webhook in webhooks
|
||||
]
|
||||
|
||||
|
||||
async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str, webhook_type: str, resource: str, events: list[str]
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
) -> Webhook | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
"events": {"has_every": events},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def find_webhook_by_graph_and_props(
|
||||
graph_id: str, provider: str, webhook_type: str, events: list[str]
|
||||
user_id: str,
|
||||
provider: str,
|
||||
webhook_type: str,
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> Webhook | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
"""Either `graph_id` or `preset_id` must be provided."""
|
||||
where_clause: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
}
|
||||
|
||||
if preset_id:
|
||||
where_clause["AgentPresets"] = {"some": {"id": preset_id}}
|
||||
elif graph_id:
|
||||
where_clause["AgentNodes"] = {"some": {"agentGraphId": graph_id}}
|
||||
else:
|
||||
raise ValueError("Either graph_id or preset_id must be provided")
|
||||
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
"events": {"has_every": events},
|
||||
"AgentNodes": {"some": {"agentGraphId": graph_id}},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
where=where_clause,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
|
||||
async def update_webhook(
|
||||
webhook_id: str,
|
||||
config: Optional[dict[str, Serializable]] = None,
|
||||
events: Optional[list[str]] = None,
|
||||
) -> Webhook:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
data: IntegrationWebhookUpdateInput = {}
|
||||
if config is not None:
|
||||
data["config"] = Json(config)
|
||||
if events is not None:
|
||||
data["events"] = events
|
||||
if not data:
|
||||
raise ValueError("Empty update query")
|
||||
|
||||
_updated_webhook = await IntegrationWebhook.prisma().update(
|
||||
where={"id": webhook_id},
|
||||
data={"config": Json(updated_config)},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
data=data,
|
||||
)
|
||||
if _updated_webhook is None:
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(_updated_webhook)
|
||||
|
||||
|
||||
async def delete_webhook(webhook_id: str) -> None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
deleted = await IntegrationWebhook.prisma().delete(where={"id": webhook_id})
|
||||
if not deleted:
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
async def delete_webhook(user_id: str, webhook_id: str) -> None:
|
||||
deleted = await IntegrationWebhook.prisma().delete_many(
|
||||
where={"id": webhook_id, "userId": user_id}
|
||||
)
|
||||
if deleted < 1:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
|
||||
|
||||
# --------------------- WEBHOOK EVENTS --------------------- #
|
||||
|
||||
@@ -12,14 +12,11 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionStats,
|
||||
NodeExecutionStats,
|
||||
)
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
@@ -138,9 +135,7 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -183,8 +178,8 @@ async def execute_node(
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
_input_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
_input_data.inputs = input_data
|
||||
if node_credentials_input_map:
|
||||
_input_data.node_credentials_input_map = node_credentials_input_map
|
||||
if nodes_input_masks:
|
||||
_input_data.nodes_input_masks = nodes_input_masks
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
@@ -255,7 +250,7 @@ async def _enqueue_next_nodes(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -289,8 +284,9 @@ async def _enqueue_next_nodes(
|
||||
next_input_name = node_link.sink_name
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
output_name, _ = output
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None:
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
|
||||
@@ -325,14 +321,12 @@ async def _enqueue_next_nodes(
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
|
||||
# Apply node credentials overrides
|
||||
node_credentials = None
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(next_node.id)
|
||||
# Apply node input overrides
|
||||
node_input_mask = None
|
||||
if nodes_input_masks and (
|
||||
node_input_mask := nodes_input_masks.get(next_node.id)
|
||||
):
|
||||
next_node_input.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
next_node_input.update(node_input_mask)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
@@ -376,11 +370,9 @@ async def _enqueue_next_nodes(
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials:
|
||||
idata.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
# Apply node input overrides
|
||||
if node_input_mask:
|
||||
idata.update(node_input_mask)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
@@ -434,9 +426,7 @@ class Executor:
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -457,7 +447,7 @@ class Executor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
stats=execution_stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
@@ -480,9 +470,7 @@ class Executor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
@@ -497,7 +485,7 @@ class Executor:
|
||||
creds_manager=cls.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
):
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -778,24 +766,19 @@ class Executor:
|
||||
)
|
||||
raise
|
||||
|
||||
# Add credential overrides -----------------------------
|
||||
# Add input overrides -----------------------------
|
||||
node_id = queued_node_exec.node_id
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
|
||||
node_input_mask := nodes_input_masks.get(node_id)
|
||||
):
|
||||
queued_node_exec.inputs.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
}
|
||||
)
|
||||
queued_node_exec.inputs.update(node_input_mask)
|
||||
|
||||
# Kick off async node execution -------------------------
|
||||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||||
cls.on_node_execution(
|
||||
node_exec=queued_node_exec,
|
||||
node_exec_progress=running_node_execution[node_id],
|
||||
node_credentials_input_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
),
|
||||
cls.node_execution_loop,
|
||||
)
|
||||
@@ -839,7 +822,7 @@ class Executor:
|
||||
node_id=node_id,
|
||||
graph_exec=graph_exec,
|
||||
log_metadata=log_metadata,
|
||||
node_creds_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
execution_queue=execution_queue,
|
||||
),
|
||||
cls.node_evaluation_loop,
|
||||
@@ -909,7 +892,7 @@ class Executor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
node_creds_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -919,7 +902,7 @@ class Executor:
|
||||
node_id: The ID of the node that produced the output
|
||||
graph_exec: The graph execution entry
|
||||
log_metadata: Logger metadata for consistent logging
|
||||
node_creds_map: Optional map of node credentials
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
@@ -943,7 +926,7 @@ class Executor:
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -402,12 +402,6 @@ def validate_exec(
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
@@ -419,6 +413,12 @@ def validate_exec(
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
@@ -435,9 +435,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -453,11 +451,13 @@ async def _validate_node_input_credentials(
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
if (
|
||||
node_credentials_input_map
|
||||
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
|
||||
and field_name in node_credentials_inputs
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
credentials_meta = node_credentials_input_map[node.id][field_name]
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
@@ -496,7 +496,7 @@ async def _validate_node_input_credentials(
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, CredentialsMetaInput]]:
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -505,9 +505,9 @@ def make_node_credentials_input_map(
|
||||
graph_credentials_input: A (graph_input_name, credentials_meta) map.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
|
||||
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
|
||||
"""
|
||||
result: dict[str, dict[str, CredentialsMetaInput]] = {}
|
||||
result: dict[str, dict[str, JsonValue]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
@@ -521,7 +521,9 @@ def make_node_credentials_input_map(
|
||||
for node_id, node_field_name in compatible_node_fields:
|
||||
if node_id not in result:
|
||||
result[node_id] = {}
|
||||
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
|
||||
result[node_id][node_field_name] = graph_credentials_input[
|
||||
graph_input_name
|
||||
].model_dump(exclude_none=True)
|
||||
|
||||
return result
|
||||
|
||||
@@ -530,9 +532,7 @@ async def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -550,8 +550,8 @@ async def construct_node_execution_input(
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True)
|
||||
await _validate_node_input_credentials(graph, user_id, node_credentials_input_map)
|
||||
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -568,23 +568,9 @@ async def construct_node_execution_input(
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in graph_inputs:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": graph_inputs[webhook_payload_key]}
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(node.id)
|
||||
):
|
||||
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
|
||||
# Apply node input overrides
|
||||
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
|
||||
input_data.update(node_input_mask)
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
@@ -600,6 +586,20 @@ async def construct_node_execution_input(
|
||||
return nodes_input
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = overrides_map_1.copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
else:
|
||||
result[node_id] = overrides2
|
||||
return result
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
@@ -730,13 +730,11 @@ async def stop_graph_execution(
|
||||
async def add_graph_execution(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: BlockInput,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
use_db_query: bool = True,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
@@ -750,7 +748,7 @@ async def add_graph_execution(
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
|
||||
nodes_input_masks: Node inputs to use in the execution.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
@@ -774,10 +772,19 @@ async def add_graph_execution(
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = node_credentials_input_map or (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
(
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else {}
|
||||
),
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
starting_nodes_input = await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs or {},
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
if use_db_query:
|
||||
@@ -785,12 +792,7 @@ async def add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
else:
|
||||
@@ -798,20 +800,15 @@ async def add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@functools.cache
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
if _WEBHOOK_MANAGERS:
|
||||
return _WEBHOOK_MANAGERS
|
||||
webhook_managers = {}
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS.update(
|
||||
webhook_managers.update(
|
||||
{
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
@@ -28,7 +27,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
]
|
||||
}
|
||||
)
|
||||
return _WEBHOOK_MANAGERS
|
||||
return webhook_managers
|
||||
|
||||
|
||||
# --8<-- [end:load_webhook_managers]
|
||||
|
||||
@@ -7,13 +7,14 @@ from uuid import uuid4
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
from backend.data import integrations
|
||||
import backend.data.integrations as integrations
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .utils import webhook_ingress_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
|
||||
@@ -41,44 +42,74 @@ class BaseWebhooksManager(ABC, Generic[WT]):
|
||||
)
|
||||
|
||||
if webhook := await integrations.find_webhook_by_credentials_and_props(
|
||||
credentials.id, webhook_type, resource, events
|
||||
user_id=user_id,
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
):
|
||||
return webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id, webhook_type, events, resource, credentials
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
resource=resource,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
async def get_manual_webhook(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
webhook_type: WT,
|
||||
events: list[str],
|
||||
):
|
||||
if current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
graph_id, self.PROVIDER_NAME, webhook_type, events
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> integrations.Webhook:
|
||||
"""
|
||||
Tries to find an existing webhook tied to `graph_id`/`preset_id`,
|
||||
or creates a new webhook if none exists.
|
||||
|
||||
Existing webhooks are matched by `user_id`, `webhook_type`,
|
||||
and `graph_id`/`preset_id`.
|
||||
|
||||
If an existing webhook is found, we check if the events match and update them
|
||||
if necessary. We do this rather than creating a new webhook
|
||||
to avoid changing the webhook URL for existing manual webhooks.
|
||||
"""
|
||||
if (graph_id or preset_id) and (
|
||||
current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
user_id=user_id,
|
||||
provider=self.PROVIDER_NAME.value,
|
||||
webhook_type=webhook_type.value,
|
||||
graph_id=graph_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
):
|
||||
if set(current_webhook.events) != set(events):
|
||||
current_webhook = await integrations.update_webhook(
|
||||
current_webhook.id, events=events
|
||||
)
|
||||
return current_webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id,
|
||||
webhook_type,
|
||||
events,
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
register=False,
|
||||
)
|
||||
|
||||
async def prune_webhook_if_dangling(
|
||||
self, webhook_id: str, credentials: Optional[Credentials]
|
||||
self, user_id: str, webhook_id: str, credentials: Optional[Credentials]
|
||||
) -> bool:
|
||||
webhook = await integrations.get_webhook(webhook_id)
|
||||
if webhook.attached_nodes is None:
|
||||
raise ValueError("Error retrieving webhook including attached nodes")
|
||||
if webhook.attached_nodes:
|
||||
webhook = await integrations.get_webhook(webhook_id, include_relations=True)
|
||||
if webhook.triggered_nodes or webhook.triggered_presets:
|
||||
# Don't prune webhook if in use
|
||||
return False
|
||||
|
||||
if credentials:
|
||||
await self._deregister_webhook(webhook, credentials)
|
||||
await integrations.delete_webhook(webhook.id)
|
||||
await integrations.delete_webhook(user_id, webhook.id)
|
||||
return True
|
||||
|
||||
# --8<-- [start:BaseWebhooksManager3]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
@@ -81,7 +83,9 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(node, credentials=node_credentials)
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
@@ -96,105 +100,25 @@ async def on_node_activate(
|
||||
) -> "NodeModel":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
block = node.block
|
||||
|
||||
if not block.webhook_config:
|
||||
return node
|
||||
|
||||
provider = block.webhook_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise ValueError(
|
||||
f"Block #{block.id} has webhook_config for provider {provider} "
|
||||
"which does not support webhooks"
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=node.graph_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Activating webhook node #{node.id} with config {block.webhook_config}"
|
||||
)
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = block.webhook_config.resource_format.format(**node.input_default)
|
||||
except KeyError:
|
||||
resource = None
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {node.input_default}"
|
||||
)
|
||||
else:
|
||||
resource = "" # not relevant for manual webhooks
|
||||
|
||||
block_input_schema = cast(BlockSchema, block.input_schema)
|
||||
credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
|
||||
credentials_meta = (
|
||||
node.input_default.get(credentials_field_name)
|
||||
if credentials_field_name
|
||||
else None
|
||||
)
|
||||
event_filter_input_name = block.webhook_config.event_filter_input
|
||||
has_everything_for_webhook = (
|
||||
resource is not None
|
||||
and (credentials_meta or not credentials_field_name)
|
||||
and (
|
||||
not event_filter_input_name
|
||||
or (
|
||||
event_filter_input_name in node.input_default
|
||||
and any(
|
||||
is_on
|
||||
for is_on in node.input_default[event_filter_input_name].values()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if has_everything_for_webhook and resource is not None:
|
||||
logger.debug(f"Node #{node} has everything for a webhook!")
|
||||
if credentials_meta and not credentials:
|
||||
raise ValueError(
|
||||
f"Cannot set up webhook for node #{node.id}: "
|
||||
f"credentials #{credentials_meta['id']} not available"
|
||||
)
|
||||
|
||||
if event_filter_input_name:
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
event_filter = cast(dict, node.input_default[event_filter_input_name])
|
||||
events = [
|
||||
block.webhook_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
events = []
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id,
|
||||
credentials,
|
||||
block.webhook_config.webhook_type,
|
||||
resource,
|
||||
events,
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
new_webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id,
|
||||
node.graph_id,
|
||||
block.webhook_config.webhook_type,
|
||||
events,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {new_webhook}")
|
||||
return await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(f"Node #{node.id} does not have everything for a webhook")
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
@@ -233,7 +157,9 @@ async def on_node_deactivate(
|
||||
f"Pruning{' and deregistering' if credentials else ''} "
|
||||
f"webhook #{webhook.id}"
|
||||
)
|
||||
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
|
||||
await webhooks_manager.prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
if (
|
||||
cast(BlockSchema, block.input_schema).get_credentials_fields()
|
||||
and not credentials
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
@@ -10,3 +25,122 @@ def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
)
|
||||
|
||||
|
||||
async def setup_webhook_for_block(
|
||||
user_id: str,
|
||||
trigger_block: "Block[BlockSchema, BlockSchema]",
|
||||
trigger_config: dict[str, JsonValue], # = Trigger block inputs
|
||||
for_graph_id: Optional[str] = None,
|
||||
for_preset_id: Optional[str] = None,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> tuple["Webhook", None] | tuple[None, str]:
|
||||
"""
|
||||
Utility function to create (and auto-setup if possible) a webhook for a given provider.
|
||||
|
||||
Returns:
|
||||
Webhook: The created or found webhook object, if successful.
|
||||
str: A feedback message, if any required inputs are missing.
|
||||
"""
|
||||
from backend.data.block import BlockWebhookConfig
|
||||
|
||||
if not (trigger_base_config := trigger_block.webhook_config):
|
||||
raise ValueError(f"Block #{trigger_block.id} does not have a webhook_config")
|
||||
|
||||
provider = trigger_base_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise NotImplementedError(
|
||||
f"Block #{trigger_block.id} has webhook_config for provider {provider} "
|
||||
"for which we do not have a WebhooksManager"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Setting up webhook for block #{trigger_block.id} with config {trigger_config}"
|
||||
)
|
||||
|
||||
# Check & parse the event filter input, if any
|
||||
events: list[str] = []
|
||||
if event_filter_input_name := trigger_base_config.event_filter_input:
|
||||
if not (event_filter := trigger_config.get(event_filter_input_name)):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without event filter input: "
|
||||
f"missing input for '{event_filter_input_name}'"
|
||||
)
|
||||
elif not (
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
any((event_filter := cast(dict[str, bool], event_filter)).values())
|
||||
):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without any enabled events "
|
||||
f"in event filter input '{event_filter_input_name}'"
|
||||
)
|
||||
|
||||
events = [
|
||||
trigger_base_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
|
||||
# Check & process prerequisites for auto-setup webhooks
|
||||
if auto_setup_webhook := isinstance(trigger_base_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = trigger_base_config.resource_format.format(**trigger_config)
|
||||
except KeyError as missing_key:
|
||||
return None, (
|
||||
f"Cannot auto-setup {provider.value} webhook without resource: "
|
||||
f"missing input for '{missing_key}'"
|
||||
)
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {trigger_config}"
|
||||
)
|
||||
|
||||
creds_field_name = next(
|
||||
# presence of this field is enforced in Block.__init__
|
||||
iter(trigger_block.input_schema.get_credentials_fields())
|
||||
)
|
||||
|
||||
if not (
|
||||
credentials_meta := cast(dict, trigger_config.get(creds_field_name, None))
|
||||
):
|
||||
return None, f"Cannot set up {provider.value} webhook without credentials"
|
||||
elif not (
|
||||
credentials := credentials
|
||||
or await credentials_manager.get(user_id, credentials_meta["id"])
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot set up {provider.value} webhook without credentials: "
|
||||
f"credentials #{credentials_meta['id']} not found for user #{user_id}"
|
||||
)
|
||||
elif credentials.provider != provider:
|
||||
raise ValueError(
|
||||
f"Credentials #{credentials.id} do not match provider {provider.value}"
|
||||
)
|
||||
else:
|
||||
# not relevant for manual webhooks:
|
||||
resource = ""
|
||||
credentials = None
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id=user_id,
|
||||
credentials=credentials,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id=user_id,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
events=events,
|
||||
graph_id=for_graph_id,
|
||||
preset_id=for_preset_id,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.v1 import v1_router
|
||||
|
||||
external_app = FastAPI(
|
||||
@@ -8,4 +10,6 @@ external_app = FastAPI(
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
@@ -2,11 +2,19 @@ import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.data.graph import get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
get_all_webhooks_by_creds,
|
||||
@@ -20,6 +28,7 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -95,7 +104,10 @@ async def callback(
|
||||
|
||||
if not valid_state:
|
||||
logger.warning(f"Invalid or expired state token for user {user_id}")
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired state token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state token",
|
||||
)
|
||||
try:
|
||||
scopes = valid_state.scopes
|
||||
logger.debug(f"Retrieved scopes from state token: {scopes}")
|
||||
@@ -122,17 +134,12 @@ async def callback(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"OAuth callback for provider %s failed during code exchange: %s. Confirm provider credentials.",
|
||||
provider.value,
|
||||
e,
|
||||
logger.error(
|
||||
f"OAuth2 Code->Token exchange failed for provider {provider.value}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify OAuth configuration and try again.",
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
|
||||
)
|
||||
|
||||
# TODO: Allow specifying `title` to set on `credentials`
|
||||
@@ -201,10 +208,13 @@ async def get_credential(
|
||||
) -> Credentials:
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
return credential
|
||||
|
||||
@@ -222,7 +232,8 @@ async def create_credentials(
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
)
|
||||
return credentials
|
||||
|
||||
@@ -256,14 +267,17 @@ async def delete_credentials(
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_all_webhooks_for_credentials(creds, force)
|
||||
await remove_all_webhooks_for_credentials(user_id, creds, force)
|
||||
except NeedConfirmation as e:
|
||||
return CredentialsDeletionNeedsConfirmationResponse(message=str(e))
|
||||
|
||||
@@ -294,16 +308,10 @@ async def webhook_ingress_generic(
|
||||
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
|
||||
webhook_manager = get_webhook_manager(provider)
|
||||
try:
|
||||
webhook = await get_webhook(webhook_id)
|
||||
webhook = await get_webhook(webhook_id, include_relations=True)
|
||||
except NotFoundError as e:
|
||||
logger.warning(
|
||||
"Webhook payload received for unknown webhook %s. Confirm the webhook ID.",
|
||||
webhook_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail={"message": str(e), "hint": "Check if the webhook ID is correct."},
|
||||
) from e
|
||||
logger.warning(f"Webhook payload received for unknown webhook #{webhook_id}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
logger.debug(f"Webhook #{webhook_id}: {webhook}")
|
||||
payload, event_type = await webhook_manager.validate_payload(webhook, request)
|
||||
logger.debug(
|
||||
@@ -320,11 +328,11 @@ async def webhook_ingress_generic(
|
||||
await publish_webhook_event(webhook_event)
|
||||
logger.debug(f"Webhook event published: {webhook_event}")
|
||||
|
||||
if not webhook.attached_nodes:
|
||||
if not (webhook.triggered_nodes or webhook.triggered_presets):
|
||||
return
|
||||
|
||||
executions: list[Awaitable] = []
|
||||
for node in webhook.attached_nodes:
|
||||
for node in webhook.triggered_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
|
||||
@@ -335,7 +343,48 @@ async def webhook_ingress_generic(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
inputs={f"webhook_{webhook_id}_payload": payload},
|
||||
nodes_input_masks={node.id: {"payload": payload}},
|
||||
)
|
||||
)
|
||||
for preset in webhook.triggered_presets:
|
||||
logger.debug(f"Webhook-attached preset: {preset}")
|
||||
if not preset.is_active:
|
||||
logger.debug(f"Preset #{preset.id} is inactive")
|
||||
continue
|
||||
|
||||
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
|
||||
if not graph:
|
||||
logger.error(
|
||||
f"User #{webhook.user_id} has preset #{preset.id} for graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version}, "
|
||||
"but no access to the graph itself."
|
||||
)
|
||||
logger.info(f"Automatically deactivating broken preset #{preset.id}")
|
||||
await update_preset(preset.user_id, preset.id, is_active=False)
|
||||
continue
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
# NOTE: this should NEVER happen, but we log and handle it gracefully
|
||||
logger.error(
|
||||
f"Preset #{preset.id} is triggered by webhook #{webhook.id}, but graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version} has no webhook input node"
|
||||
)
|
||||
await set_preset_webhook(preset.user_id, preset.id, None)
|
||||
continue
|
||||
if not trigger_node.block.is_triggered_by_event_type(preset.inputs, event_type):
|
||||
logger.debug(f"Preset #{preset.id} doesn't trigger on event {event_type}")
|
||||
continue
|
||||
logger.debug(f"Executing preset #{preset.id} for webhook #{webhook.id}")
|
||||
|
||||
executions.append(
|
||||
add_graph_execution(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=preset.graph_id,
|
||||
preset_id=preset.id,
|
||||
graph_version=preset.graph_version,
|
||||
graph_credentials_inputs=preset.credentials,
|
||||
nodes_input_masks={
|
||||
trigger_node.id: {**preset.inputs, "payload": payload}
|
||||
},
|
||||
)
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
@@ -360,7 +409,9 @@ async def webhook_ping(
|
||||
return False
|
||||
|
||||
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
|
||||
raise HTTPException(status_code=504, detail="Webhook ping timed out")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Webhook ping timed out"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -369,32 +420,37 @@ async def webhook_ping(
|
||||
|
||||
|
||||
async def remove_all_webhooks_for_credentials(
|
||||
credentials: Credentials, force: bool = False
|
||||
user_id: str, credentials: Credentials, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Remove and deregister all webhooks that were registered using the given credentials.
|
||||
|
||||
Params:
|
||||
user_id: The ID of the user who owns the credentials and webhooks.
|
||||
credentials: The credentials for which to remove the associated webhooks.
|
||||
force: Whether to proceed if any of the webhooks are still in use.
|
||||
|
||||
Raises:
|
||||
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
|
||||
"""
|
||||
webhooks = await get_all_webhooks_by_creds(credentials.id)
|
||||
if any(w.attached_nodes for w in webhooks) and not force:
|
||||
webhooks = await get_all_webhooks_by_creds(
|
||||
user_id, credentials.id, include_relations=True
|
||||
)
|
||||
if any(w.triggered_nodes or w.triggered_presets for w in webhooks) and not force:
|
||||
raise NeedConfirmation(
|
||||
"Some webhooks linked to these credentials are still in use by an agent"
|
||||
)
|
||||
for webhook in webhooks:
|
||||
# Unlink all nodes
|
||||
for node in webhook.attached_nodes or []:
|
||||
# Unlink all nodes & presets
|
||||
for node in webhook.triggered_nodes:
|
||||
await set_node_webhook(node.id, None)
|
||||
for preset in webhook.triggered_presets:
|
||||
await set_preset_webhook(user_id, preset.id, None)
|
||||
|
||||
# Prune the webhook
|
||||
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
|
||||
success = await webhook_manager.prune_webhook_if_dangling(
|
||||
webhook.id, credentials
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
if not success:
|
||||
logger.warning(f"Webhook #{webhook.id} failed to prune")
|
||||
@@ -405,7 +461,7 @@ def _get_provider_oauth_handler(
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
)
|
||||
|
||||
@@ -413,14 +469,13 @@ def _get_provider_oauth_handler(
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
logger.error(
|
||||
"OAuth credentials for provider %s are missing. Check environment configuration.",
|
||||
provider_name.value,
|
||||
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail={
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured",
|
||||
"hint": "Set client ID and secret in the environment.",
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured.",
|
||||
"hint": "Set client ID and secret in the application's deployment environment",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to add security headers to responses, with cache control
|
||||
disabled by default for all endpoints except those explicitly allowed.
|
||||
"""
|
||||
|
||||
CACHEABLE_PATHS: Set[str] = {
|
||||
# Static assets
|
||||
"/static",
|
||||
"/_next/static",
|
||||
"/assets",
|
||||
"/images",
|
||||
"/css",
|
||||
"/js",
|
||||
"/fonts",
|
||||
# Public API endpoints
|
||||
"/api/health",
|
||||
"/api/v1/health",
|
||||
"/api/status",
|
||||
# Public store/marketplace pages (read-only)
|
||||
"/api/store/agents",
|
||||
"/api/v1/store/agents",
|
||||
"/api/store/categories",
|
||||
"/api/v1/store/categories",
|
||||
"/api/store/featured",
|
||||
"/api/v1/store/featured",
|
||||
# Public graph templates (read-only, no user data)
|
||||
"/api/graphs/templates",
|
||||
"/api/v1/graphs/templates",
|
||||
# Documentation endpoints
|
||||
"/api/docs",
|
||||
"/api/v1/docs",
|
||||
"/docs",
|
||||
"/swagger",
|
||||
"/openapi.json",
|
||||
# Favicon and manifest
|
||||
"/favicon.ico",
|
||||
"/manifest.json",
|
||||
"/robots.txt",
|
||||
"/sitemap.xml",
|
||||
}
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
super().__init__(app)
|
||||
# Compile regex patterns for wildcard matching
|
||||
self.cacheable_patterns = [
|
||||
re.compile(pattern.replace("*", "[^/]+"))
|
||||
for pattern in self.CACHEABLE_PATHS
|
||||
if "*" in pattern
|
||||
]
|
||||
self.exact_paths = {path for path in self.CACHEABLE_PATHS if "*" not in path}
|
||||
|
||||
def is_cacheable_path(self, path: str) -> bool:
|
||||
"""Check if the given path is allowed to be cached."""
|
||||
# Check exact matches first
|
||||
for cacheable_path in self.exact_paths:
|
||||
if path.startswith(cacheable_path):
|
||||
return True
|
||||
|
||||
# Check pattern matches
|
||||
for pattern in self.cacheable_patterns:
|
||||
if pattern.match(path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Add general security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
response.headers["Cache-Control"] = (
|
||||
"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,143 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a test FastAPI app with security middleware."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
@app.get("/api/auth/user")
|
||||
def get_user():
|
||||
return {"user": "test"}
|
||||
|
||||
@app.get("/api/v1/integrations/oauth/google")
|
||||
def oauth_endpoint():
|
||||
return {"oauth": "data"}
|
||||
|
||||
@app.get("/api/graphs/123/execute")
|
||||
def execute_graph():
|
||||
return {"execution": "data"}
|
||||
|
||||
@app.get("/api/integrations/credentials")
|
||||
def get_credentials():
|
||||
return {"credentials": "sensitive"}
|
||||
|
||||
@app.get("/api/store/agents")
|
||||
def store_agents():
|
||||
return {"agents": "public list"}
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/static/logo.png")
|
||||
def static_file():
|
||||
return {"static": "content"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_non_cacheable_endpoints_have_cache_control_headers(client):
|
||||
"""Test that non-cacheable endpoints (most endpoints) have proper cache control headers."""
|
||||
non_cacheable_endpoints = [
|
||||
"/api/auth/user",
|
||||
"/api/v1/integrations/oauth/google",
|
||||
"/api/graphs/123/execute",
|
||||
"/api/integrations/credentials",
|
||||
]
|
||||
|
||||
for endpoint in non_cacheable_endpoints:
|
||||
response = client.get(endpoint)
|
||||
|
||||
# Check cache control headers are present (default behavior)
|
||||
assert (
|
||||
response.headers["Cache-Control"]
|
||||
== "no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
assert response.headers["Pragma"] == "no-cache"
|
||||
assert response.headers["Expires"] == "0"
|
||||
|
||||
# Check general security headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
|
||||
def test_cacheable_endpoints_dont_have_cache_control_headers(client):
|
||||
"""Test that explicitly cacheable endpoints don't have restrictive cache control headers."""
|
||||
cacheable_endpoints = [
|
||||
"/api/store/agents",
|
||||
"/api/health",
|
||||
"/static/logo.png",
|
||||
]
|
||||
|
||||
for endpoint in cacheable_endpoints:
|
||||
response = client.get(endpoint)
|
||||
|
||||
# Should NOT have restrictive cache control headers
|
||||
assert (
|
||||
"Cache-Control" not in response.headers
|
||||
or "no-store" not in response.headers.get("Cache-Control", "")
|
||||
)
|
||||
assert (
|
||||
"Pragma" not in response.headers
|
||||
or response.headers.get("Pragma") != "no-cache"
|
||||
)
|
||||
assert (
|
||||
"Expires" not in response.headers or response.headers.get("Expires") != "0"
|
||||
)
|
||||
|
||||
# Should still have general security headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
|
||||
def test_is_cacheable_path_detection():
|
||||
"""Test the path detection logic."""
|
||||
middleware = SecurityHeadersMiddleware(Starlette())
|
||||
|
||||
# Test cacheable paths (allow list)
|
||||
assert middleware.is_cacheable_path("/api/health")
|
||||
assert middleware.is_cacheable_path("/api/v1/health")
|
||||
assert middleware.is_cacheable_path("/static/image.png")
|
||||
assert middleware.is_cacheable_path("/api/store/agents")
|
||||
assert middleware.is_cacheable_path("/docs")
|
||||
assert middleware.is_cacheable_path("/favicon.ico")
|
||||
|
||||
# Test non-cacheable paths (everything else)
|
||||
assert not middleware.is_cacheable_path("/api/auth/user")
|
||||
assert not middleware.is_cacheable_path("/api/v1/integrations/oauth/callback")
|
||||
assert not middleware.is_cacheable_path("/api/integrations/credentials/123")
|
||||
assert not middleware.is_cacheable_path("/api/graphs/abc123/execute")
|
||||
assert not middleware.is_cacheable_path("/api/store/xyz/submissions")
|
||||
|
||||
|
||||
def test_path_prefix_matching():
|
||||
"""Test that path prefix matching works correctly."""
|
||||
middleware = SecurityHeadersMiddleware(Starlette())
|
||||
|
||||
# Test that paths starting with cacheable prefixes are cacheable
|
||||
assert middleware.is_cacheable_path("/static/css/style.css")
|
||||
assert middleware.is_cacheable_path("/static/js/app.js")
|
||||
assert middleware.is_cacheable_path("/assets/images/logo.png")
|
||||
assert middleware.is_cacheable_path("/_next/static/chunks/main.js")
|
||||
|
||||
# Test that other API paths are not cacheable by default
|
||||
assert not middleware.is_cacheable_path("/api/users/profile")
|
||||
assert not middleware.is_cacheable_path("/api/v1/private/data")
|
||||
assert not middleware.is_cacheable_path("/api/billing/subscription")
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
@@ -14,6 +15,7 @@ from autogpt_libs.feature_flag.client import (
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
@@ -36,6 +38,7 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -67,6 +70,33 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
||||
def custom_generate_unique_id(route: APIRoute):
|
||||
"""Generate clean operation IDs for OpenAPI spec following the format:
|
||||
{method}{tag}{summary}
|
||||
"""
|
||||
if not route.tags or not route.methods:
|
||||
return f"{route.name}"
|
||||
|
||||
method = list(route.methods)[0].lower()
|
||||
first_tag = route.tags[0]
|
||||
if isinstance(first_tag, Enum):
|
||||
tag_str = first_tag.name
|
||||
else:
|
||||
tag_str = str(first_tag)
|
||||
|
||||
tag = "".join(word.capitalize() for word in tag_str.split("_")) # v1/v2
|
||||
|
||||
summary = (
|
||||
route.summary if route.summary else route.name
|
||||
) # need to be unique, a different version could have the same summary
|
||||
summary = "".join(word.capitalize() for word in str(summary).split("_"))
|
||||
|
||||
if tag:
|
||||
return f"{method}{tag}{summary}"
|
||||
else:
|
||||
return f"{method}{summary}"
|
||||
|
||||
|
||||
docs_url = (
|
||||
"/docs"
|
||||
if settings.config.app_env == backend.util.settings.AppEnvironment.LOCAL
|
||||
@@ -82,8 +112,11 @@ app = fastapi.FastAPI(
|
||||
version="0.1",
|
||||
lifespan=lifespan_context,
|
||||
docs_url=docs_url,
|
||||
generate_unique_id_function=custom_generate_unique_id,
|
||||
)
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -158,10 +191,12 @@ app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
|
||||
backend.server.v2.otto.routes.router, tags=["v2", "otto"], prefix="/api/otto"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
|
||||
backend.server.v2.turnstile.routes.router,
|
||||
tags=["v2", "turnstile"],
|
||||
prefix="/api/turnstile",
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
@@ -288,18 +323,14 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
user_id: str,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
preset_id=preset_id,
|
||||
node_input=node_input or {},
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -34,7 +34,7 @@ router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/unsubscribe")
|
||||
@router.post("/unsubscribe", summary="One Click Email Unsubscribe")
|
||||
async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
|
||||
logger.info("Received unsubscribe request from One Click Unsubscribe")
|
||||
try:
|
||||
@@ -48,7 +48,11 @@ async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
|
||||
return JSONResponse(status_code=200, content={"status": "ok"})
|
||||
|
||||
|
||||
@router.post("/", dependencies=[Depends(postmark_validator.get_dependency())])
|
||||
@router.post(
|
||||
"/",
|
||||
dependencies=[Depends(postmark_validator.get_dependency())],
|
||||
summary="Handle Postmark Email Webhooks",
|
||||
)
|
||||
async def postmark_webhook_handler(
|
||||
webhook: Annotated[
|
||||
PostmarkWebhook,
|
||||
|
||||
@@ -113,14 +113,22 @@ v1_router.include_router(
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.post("/auth/user", tags=["auth"], dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/email", tags=["auth"], dependencies=[Depends(auth_middleware)]
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Depends(get_user_id)], email: str = Body(...)
|
||||
@@ -132,6 +140,7 @@ async def update_user_email_route(
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/preferences",
|
||||
summary="Get notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -144,6 +153,7 @@ async def get_preferences(
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/preferences",
|
||||
summary="Update notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -161,14 +171,20 @@ async def update_preferences(
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
|
||||
"/onboarding",
|
||||
summary="Get onboarding status",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
return await get_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
|
||||
"/onboarding",
|
||||
summary="Update onboarding progress",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_onboarding(
|
||||
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
|
||||
@@ -178,6 +194,7 @@ async def update_onboarding(
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/agents",
|
||||
summary="Get recommended agents",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -189,6 +206,7 @@ async def get_onboarding_agents(
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Check onboarding enabled",
|
||||
tags=["onboarding", "public"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -201,7 +219,12 @@ async def is_onboarding_enabled():
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
@@ -212,6 +235,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
|
||||
@v1_router.post(
|
||||
path="/blocks/{block_id}/execute",
|
||||
summary="Execute graph block",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -231,7 +255,12 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/credits",
|
||||
tags=["credits"],
|
||||
summary="Get user credits",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, int]:
|
||||
@@ -239,7 +268,10 @@ async def get_user_credits(
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||
path="/credits",
|
||||
summary="Request credit top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def request_top_up(
|
||||
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
|
||||
@@ -252,6 +284,7 @@ async def request_top_up(
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/{transaction_key}/refund",
|
||||
summary="Refund credit transaction",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -264,7 +297,10 @@ async def refund_top_up(
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||
path="/credits",
|
||||
summary="Fulfill checkout session",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
@@ -273,6 +309,7 @@ async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/auto-top-up",
|
||||
summary="Configure auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -301,6 +338,7 @@ async def configure_user_auto_top_up(
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/auto-top-up",
|
||||
summary="Get auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -310,7 +348,9 @@ async def get_user_auto_top_up(
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
|
||||
@v1_router.post(path="/credits/stripe_webhook", tags=["credits"])
|
||||
@v1_router.post(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
@@ -345,14 +385,24 @@ async def stripe_webhook(request: Request):
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@v1_router.get(path="/credits/manage", dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/credits/manage",
|
||||
tags=["credits"],
|
||||
summary="Manage payment methods",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
|
||||
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/credits/transactions",
|
||||
tags=["credits"],
|
||||
summary="Get credit history",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_credit_history(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
transaction_time: datetime | None = None,
|
||||
@@ -370,7 +420,12 @@ async def get_credit_history(
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/credits/refunds",
|
||||
tags=["credits"],
|
||||
summary="Get refund requests",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
@@ -386,7 +441,12 @@ class DeleteGraphResponse(TypedDict):
|
||||
version_counts: int
|
||||
|
||||
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
@v1_router.get(
|
||||
path="/graphs",
|
||||
summary="List user graphs",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
@@ -394,10 +454,14 @@ async def get_graphs(
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Get specific graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
summary="Get graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -421,6 +485,7 @@ async def get_graph(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
summary="Get all graph versions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -434,7 +499,10 @@ async def get_graph_all_versions(
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
path="/graphs",
|
||||
summary="Create new graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph,
|
||||
@@ -457,7 +525,10 @@ async def create_new_graph(
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Delete graph permanently",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def delete_graph(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
@@ -469,7 +540,10 @@ async def delete_graph(
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Update graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
@@ -515,6 +589,7 @@ async def update_graph(
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
summary="Set active graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -553,6 +628,7 @@ async def set_graph_active_version(
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
summary="Execute graph agent",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -586,6 +662,7 @@ async def execute_graph(
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
summary="Stop graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -613,6 +690,7 @@ async def stop_graph_run(
|
||||
|
||||
@v1_router.get(
|
||||
path="/executions",
|
||||
summary="Get all executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -624,6 +702,7 @@ async def get_graphs_executions(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
summary="Get graph executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -636,6 +715,7 @@ async def get_graph_executions(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
summary="Get execution details",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -665,6 +745,7 @@ async def get_graph_execution(
|
||||
|
||||
@v1_router.delete(
|
||||
path="/executions/{graph_exec_id}",
|
||||
summary="Delete graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
@@ -692,6 +773,7 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
|
||||
@v1_router.post(
|
||||
path="/schedules",
|
||||
summary="Create execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -719,6 +801,7 @@ async def create_schedule(
|
||||
|
||||
@v1_router.delete(
|
||||
path="/schedules/{schedule_id}",
|
||||
summary="Delete execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -732,6 +815,7 @@ async def delete_schedule(
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
summary="List execution schedules",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -752,6 +836,7 @@ async def get_execution_schedules(
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -782,6 +867,7 @@ async def create_api_key(
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys",
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -802,6 +888,7 @@ async def get_api_keys(
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -825,6 +912,7 @@ async def get_api_key(
|
||||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -853,6 +941,7 @@ async def delete_api_key(
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -878,6 +967,7 @@ async def suspend_key(
|
||||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
|
||||
@@ -22,7 +22,9 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/add_credits", response_model=AddUserCreditsResponse)
|
||||
@router.post(
|
||||
"/add_credits", response_model=AddUserCreditsResponse, summary="Add Credits to User"
|
||||
)
|
||||
async def add_user_credits(
|
||||
user_id: typing.Annotated[str, Body()],
|
||||
amount: typing.Annotated[int, Body()],
|
||||
@@ -49,6 +51,7 @@ async def add_user_credits(
|
||||
@router.get(
|
||||
"/users_history",
|
||||
response_model=UserHistoryResponse,
|
||||
summary="Get All Users History",
|
||||
)
|
||||
async def admin_get_all_user_history(
|
||||
admin_user: typing.Annotated[
|
||||
|
||||
@@ -19,6 +19,7 @@ router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
|
||||
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
@@ -63,6 +64,7 @@ async def get_admin_listings_with_versions(
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
@@ -104,6 +106,7 @@ async def review_submission(
|
||||
|
||||
@router.get(
|
||||
"/submissions/download/{store_listing_version_id}",
|
||||
summary="Admin Download Agent File",
|
||||
tags=["store", "admin"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
|
||||
@@ -7,17 +7,17 @@ import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.data.graph as graph_db
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data import db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import locked_transaction
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import locked_transaction, transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -216,7 +216,7 @@ async def get_library_agent_by_store_version_id(
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
graph: graph_db.GraphModel,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ async def add_generated_agent_image(
|
||||
|
||||
|
||||
async def create_library_agent(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
graph: graph_db.GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
@@ -525,7 +525,10 @@ async def list_presets(
|
||||
)
|
||||
raise store_exceptions.DatabaseError("Invalid pagination parameters")
|
||||
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_id:
|
||||
query_filter["agentGraphId"] = graph_id
|
||||
|
||||
@@ -581,7 +584,7 @@ async def get_preset(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not preset or preset.userId != user_id:
|
||||
if not preset or preset.userId != user_id or preset.isDeleted:
|
||||
return None
|
||||
return library_model.LibraryAgentPreset.from_db(preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -618,12 +621,19 @@ async def create_preset(
|
||||
agentGraphId=preset.graph_id,
|
||||
agentGraphVersion=preset.graph_version,
|
||||
isActive=preset.is_active,
|
||||
webhookId=preset.webhook_id,
|
||||
InputPresets={
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
}.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
@@ -664,6 +674,7 @@ async def create_preset_from_graph_execution(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
inputs=graph_execution.inputs,
|
||||
credentials={}, # FIXME
|
||||
graph_id=graph_execution.graph_id,
|
||||
graph_version=graph_execution.graph_version,
|
||||
name=create_request.name,
|
||||
@@ -676,7 +687,11 @@ async def create_preset_from_graph_execution(
|
||||
async def update_preset(
|
||||
user_id: str,
|
||||
preset_id: str,
|
||||
preset: library_model.LibraryAgentPresetUpdatable,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Updates an existing AgentPreset for a user.
|
||||
@@ -684,49 +699,95 @@ async def update_preset(
|
||||
Args:
|
||||
user_id: The ID of the user updating the preset.
|
||||
preset_id: The ID of the preset to update.
|
||||
preset: The preset data used for the update.
|
||||
inputs: New inputs object to set on the preset.
|
||||
credentials: New credentials to set on the preset.
|
||||
name: New name for the preset.
|
||||
description: New description for the preset.
|
||||
is_active: New active status for the preset.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgentPreset.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's a database error in updating the preset.
|
||||
ValueError: If attempting to update a non-existent preset.
|
||||
NotFoundError: If attempting to update a non-existent preset.
|
||||
"""
|
||||
current = await get_preset(user_id, preset_id) # assert ownership
|
||||
if not current:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found for user #{user_id}")
|
||||
logger.debug(
|
||||
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
|
||||
f"Updating preset #{preset_id} ({repr(current.name)}) for user #{user_id}",
|
||||
)
|
||||
try:
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if preset.name:
|
||||
update_data["name"] = preset.name
|
||||
if preset.description:
|
||||
update_data["description"] = preset.description
|
||||
if preset.inputs:
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
async with transaction() as tx:
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if description:
|
||||
update_data["description"] = description
|
||||
if is_active is not None:
|
||||
update_data["isActive"] = is_active
|
||||
if inputs or credentials:
|
||||
if not (inputs and credentials):
|
||||
raise ValueError(
|
||||
"Preset inputs and credentials must be provided together"
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
}
|
||||
if preset.is_active:
|
||||
update_data["isActive"] = preset.is_active
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in {
|
||||
**inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in credentials.items()
|
||||
},
|
||||
}.items()
|
||||
],
|
||||
}
|
||||
# Existing InputPresets must be deleted, in a separate query
|
||||
await prisma.models.AgentNodeExecutionInputOutput.prisma(
|
||||
tx
|
||||
).delete_many(where={"agentPresetId": preset_id})
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
updated = await prisma.models.AgentPreset.prisma(tx).update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise ValueError(f"AgentPreset #{preset_id} not found")
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to update preset") from e
|
||||
|
||||
|
||||
async def set_preset_webhook(
|
||||
user_id: str, preset_id: str, webhook_id: str | None
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not current or current.userId != user_id:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found")
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=(
|
||||
{"Webhook": {"connect": {"id": webhook_id}}}
|
||||
if webhook_id
|
||||
else {"Webhook": {"disconnect": True}}
|
||||
),
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
|
||||
|
||||
async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
"""
|
||||
Soft-deletes a preset by marking it as isDeleted = True.
|
||||
@@ -738,7 +799,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during deletion.
|
||||
"""
|
||||
logger.info(f"Deleting preset {preset_id} for user {user_id}")
|
||||
logger.debug(f"Setting preset #{preset_id} for user #{user_id} to deleted")
|
||||
try:
|
||||
await prisma.models.AgentPreset.prisma().update_many(
|
||||
where={"id": preset_id, "userId": user_id},
|
||||
@@ -765,7 +826,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
"""
|
||||
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
|
||||
try:
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
async with locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
# Fetch the original agent
|
||||
original_agent = await get_library_agent(library_agent_id, user_id)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import pydantic
|
||||
import backend.data.block as block_model
|
||||
import backend.data.graph as graph_model
|
||||
import backend.server.model as server_model
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
@@ -18,6 +20,14 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class LibraryAgentTriggerInfo(pydantic.BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -40,8 +50,15 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
# Made input_schema and output_schema match GraphMeta's type
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
credentials_input_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
)
|
||||
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
@@ -106,6 +123,32 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
credentials_input_schema=graph.credentials_input_schema,
|
||||
has_external_trigger=graph.has_webhook_trigger,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema["required"] or []
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
if graph.webhook_input_node
|
||||
and (trigger_block := graph.webhook_input_node.block).webhook_config
|
||||
else None
|
||||
),
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
@@ -177,12 +220,15 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
||||
graph_version: int
|
||||
|
||||
inputs: block_model.BlockInput
|
||||
credentials: dict[str, CredentialsMetaInput]
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
|
||||
|
||||
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
|
||||
"""
|
||||
@@ -203,6 +249,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
inputs: Optional[block_model.BlockInput] = None
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -214,20 +261,28 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
"""Represents a preset configuration for a library agent."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
|
||||
if preset.InputPresets is None:
|
||||
raise ValueError("Input values must be included in object")
|
||||
raise ValueError("InputPresets must be included in AgentPreset query")
|
||||
|
||||
input_data: block_model.BlockInput = {}
|
||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
|
||||
for preset_input in preset.InputPresets:
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
if not is_credentials_field_name(preset_input.name):
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
else:
|
||||
input_credentials[preset_input.name] = (
|
||||
CredentialsMetaInput.model_validate(preset_input.data)
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
@@ -235,6 +290,8 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
inputs=input_data,
|
||||
credentials=input_credentials,
|
||||
webhook_id=preset.webhookId,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,6 +26,7 @@ router = APIRouter(
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="List Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
@@ -70,14 +77,14 @@ async def list_library_agents(
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Listing library agents failed for user %s: %s", user_id, e)
|
||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect database connectivity."},
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}")
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
@@ -87,6 +94,7 @@ async def get_library_agent(
|
||||
|
||||
@router.get(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store, library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
@@ -101,23 +109,22 @@ async def get_library_agent_by_store_listing_version_id(
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Retrieving library agent by store version failed for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Check if the store listing ID is valid.",
|
||||
},
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Add Marketplace Agent",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
responses={
|
||||
201: {"description": "Agent added successfully"},
|
||||
@@ -149,26 +156,20 @@ async def add_marketplace_agent_to_library(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
except store_exceptions.AgentNotFoundError:
|
||||
except store_exceptions.AgentNotFoundError as e:
|
||||
logger.warning(
|
||||
"Store listing version %s not found when adding to library",
|
||||
store_listing_version_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": f"Store listing version {store_listing_version_id} not found",
|
||||
"hint": "Confirm the ID provided.",
|
||||
},
|
||||
f"Could not find store listing version {store_listing_version_id} "
|
||||
"to add to library"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.exception("Database error whilst adding agent to library: %s", e)
|
||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error while adding agent to library: %s", e)
|
||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
@@ -180,6 +181,7 @@ async def add_marketplace_agent_to_library(
|
||||
|
||||
@router.put(
|
||||
"/{library_agent_id}",
|
||||
summary="Update Library Agent",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "Agent updated successfully"},
|
||||
@@ -219,20 +221,20 @@ async def update_library_agent(
|
||||
content={"message": "Agent updated successfully"},
|
||||
)
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.exception("Database error while updating library agent: %s", e)
|
||||
logger.error(f"Database error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error while updating library agent: %s", e)
|
||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check server logs."},
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork")
|
||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
@@ -241,3 +243,81 @@ async def fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TriggeredPresetSetupParams(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
trigger_config: dict[str, Any]
|
||||
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/setup-trigger")
|
||||
async def setup_trigger(
|
||||
library_agent_id: str = Path(..., description="ID of the library agent"),
|
||||
params: TriggeredPresetSetupParams = Body(),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
|
||||
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
|
||||
"""
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=library_agent_id, user_id=user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent #{library_agent_id} not found",
|
||||
)
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{library_agent.graph_id} not accessible (anymore)",
|
||||
)
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
|
||||
)
|
||||
|
||||
trigger_config_with_credentials = {
|
||||
**params.trigger_config,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, params.agent_credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=trigger_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
)
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not set up webhook: {feedback}",
|
||||
)
|
||||
|
||||
new_preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_version=library_agent.graph_version,
|
||||
name=params.name,
|
||||
description=params.description,
|
||||
inputs=trigger_config_with_credentials,
|
||||
credentials=params.agent_credentials,
|
||||
webhook_id=new_webhook.id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
return new_preset
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
import logging
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
router = APIRouter(tags=["presets"])
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -49,11 +55,7 @@ async def list_presets(
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list presets for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Ensure the presets DB table is accessible.",
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -81,21 +83,21 @@ async def get_preset(
|
||||
"""
|
||||
try:
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset {preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Validate preset ID and retry."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
|
||||
|
||||
@router.post(
|
||||
"/presets",
|
||||
@@ -132,8 +134,7 @@ async def create_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset creation failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset payload format."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -161,17 +162,85 @@ async def update_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while updating the preset.
|
||||
"""
|
||||
current = await get_preset(preset_id, user_id=user_id)
|
||||
if not current:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Preset #{preset_id} not found")
|
||||
|
||||
graph = await get_graph(
|
||||
current.graph_id,
|
||||
current.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{current.graph_id} not accessible (anymore)",
|
||||
)
|
||||
|
||||
trigger_inputs_updated, new_webhook, feedback = False, None, None
|
||||
if (trigger_node := graph.webhook_input_node) and (
|
||||
preset.inputs is not None and preset.credentials is not None
|
||||
):
|
||||
trigger_config_with_credentials = {
|
||||
**preset.inputs,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, preset.credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=graph.webhook_input_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
for_preset_id=preset_id,
|
||||
)
|
||||
trigger_inputs_updated = True
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not update trigger configuration: {feedback}",
|
||||
)
|
||||
|
||||
try:
|
||||
return await db.update_preset(
|
||||
user_id=user_id, preset_id=preset_id, preset=preset
|
||||
updated = await db.update_preset(
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
inputs=preset.inputs,
|
||||
credentials=preset.credentials,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.is_active,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Preset update failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset data and try again."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
# Update the webhook as well, if necessary
|
||||
if trigger_inputs_updated:
|
||||
updated = await db.set_preset_webhook(
|
||||
user_id, preset_id, new_webhook.id if new_webhook else None
|
||||
)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
if current.webhook_id and (
|
||||
current.webhook_id != (new_webhook.id if new_webhook else None)
|
||||
):
|
||||
current_webhook = await get_webhook(current.webhook_id)
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, current_webhook.credentials_id)
|
||||
if current_webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(
|
||||
current_webhook.provider
|
||||
).prune_webhook_if_dangling(user_id, current_webhook.id, credentials)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/presets/{preset_id}",
|
||||
@@ -193,6 +262,28 @@ async def delete_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while deleting the preset.
|
||||
"""
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found for user #{user_id}",
|
||||
)
|
||||
|
||||
# Detach and clean up the attached webhook, if any
|
||||
if preset.webhook_id:
|
||||
webhook = await get_webhook(preset.webhook_id)
|
||||
await db.set_preset_webhook(user_id, preset_id, None)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, webhook.credentials_id)
|
||||
if webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(webhook.provider).prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
|
||||
try:
|
||||
await db.delete_preset(user_id, preset_id)
|
||||
except Exception as e:
|
||||
@@ -201,7 +292,7 @@ async def delete_preset(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Ensure preset exists before deleting."},
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@@ -212,24 +303,20 @@ async def delete_preset(
|
||||
description="Execute a preset with the given graph and node input for the current user.",
|
||||
)
|
||||
async def execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
"""
|
||||
Execute a preset given graph parameters, returning the execution ID on success.
|
||||
|
||||
Args:
|
||||
graph_id (str): ID of the graph to execute.
|
||||
graph_version (int): Version of the graph to execute.
|
||||
preset_id (str): ID of the preset to execute.
|
||||
node_input (Dict[Any, Any]): Input data for the node.
|
||||
user_id (str): ID of the authenticated user.
|
||||
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A response containing the execution ID.
|
||||
{id: graph_exec_id}: A response containing the execution ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
@@ -239,18 +326,18 @@ async def execute_preset(
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Preset not found",
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
merged_node_input = preset.inputs | inputs
|
||||
|
||||
execution = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=merged_node_input,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
inputs=merged_node_input,
|
||||
)
|
||||
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
@@ -261,9 +348,6 @@ async def execute_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset execution failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Review preset configuration and graph ID.",
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
@@ -50,6 +50,8 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
@@ -66,6 +68,8 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
|
||||
@@ -14,7 +14,10 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/ask", response_model=ApiResponse, dependencies=[Depends(auth_middleware)]
|
||||
"/ask",
|
||||
response_model=ApiResponse,
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
summary="Proxy Otto Chat Request",
|
||||
)
|
||||
async def proxy_otto_request(
|
||||
request: ChatRequest, user_id: str = Depends(get_user_id)
|
||||
|
||||
@@ -29,6 +29,7 @@ router = fastapi.APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
response_model=backend.server.v2.store.model.ProfileDetails,
|
||||
)
|
||||
@@ -61,6 +62,7 @@ async def get_profile(
|
||||
|
||||
@router.post(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
@@ -107,6 +109,7 @@ async def update_or_create_profile(
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentsResponse,
|
||||
)
|
||||
@@ -179,6 +182,7 @@ async def get_agents(
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
)
|
||||
@@ -208,6 +212,7 @@ async def get_agent(username: str, agent_name: str):
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
@@ -232,6 +237,7 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
)
|
||||
@@ -257,6 +263,7 @@ async def get_store_agent(
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreReview,
|
||||
@@ -308,6 +315,7 @@ async def create_review(
|
||||
|
||||
@router.get(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorsResponse,
|
||||
)
|
||||
@@ -359,6 +367,7 @@ async def get_creators(
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
)
|
||||
@@ -390,6 +399,7 @@ async def get_creator(
|
||||
############################################
|
||||
@router.get(
|
||||
"/myagents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.MyAgentsResponse,
|
||||
@@ -412,6 +422,7 @@ async def get_my_agents(
|
||||
|
||||
@router.delete(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=bool,
|
||||
@@ -448,6 +459,7 @@ async def delete_submission(
|
||||
|
||||
@router.get(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmissionsResponse,
|
||||
@@ -501,6 +513,7 @@ async def get_submissions(
|
||||
|
||||
@router.post(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
@@ -548,6 +561,7 @@ async def create_submission(
|
||||
|
||||
@router.post(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
@@ -585,6 +599,7 @@ async def upload_submission_media(
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
@@ -646,6 +661,7 @@ async def generate_image(
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
|
||||
@@ -13,7 +13,9 @@ router = APIRouter()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@router.post("/verify", response_model=TurnstileVerifyResponse)
|
||||
@router.post(
|
||||
"/verify", response_model=TurnstileVerifyResponse, summary="Verify Turnstile Token"
|
||||
)
|
||||
async def verify_turnstile_token(
|
||||
request: TurnstileVerifyRequest,
|
||||
) -> TurnstileVerifyResponse:
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentNodeExecutionInputOutput" ALTER COLUMN "data" DROP NOT NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Add webhookId column
|
||||
ALTER TABLE "AgentPreset" ADD COLUMN "webhookId" TEXT;
|
||||
|
||||
-- Add AgentPreset<->IntegrationWebhook relation
|
||||
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_webhookId_fkey" FOREIGN KEY ("webhookId") REFERENCES "IntegrationWebhook"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -80,8 +80,8 @@ enum OnboardingStep {
|
||||
}
|
||||
|
||||
model UserOnboarding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
completedSteps OnboardingStep[] @default([])
|
||||
@@ -122,7 +122,7 @@ model AgentGraph {
|
||||
|
||||
forkedFromId String?
|
||||
forkedFromVersion Int?
|
||||
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
|
||||
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
|
||||
forks AgentGraph[] @relation("AgentGraphForks")
|
||||
|
||||
Nodes AgentNode[]
|
||||
@@ -169,6 +169,10 @@ model AgentPreset {
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
Executions AgentGraphExecution[]
|
||||
|
||||
// For webhook-triggered agents: reference to the webhook that triggers the agent
|
||||
webhookId String?
|
||||
Webhook IntegrationWebhook? @relation(fields: [webhookId], references: [id])
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
|
||||
@@index([userId])
|
||||
@@ -390,7 +394,7 @@ model AgentNodeExecutionInputOutput {
|
||||
id String @id @default(uuid())
|
||||
|
||||
name String
|
||||
data Json
|
||||
data Json?
|
||||
time DateTime @default(now())
|
||||
|
||||
// Prisma requires explicit back-references.
|
||||
@@ -428,7 +432,8 @@ model IntegrationWebhook {
|
||||
|
||||
providerWebhookId String // Webhook ID assigned by the provider
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentNodes AgentNode[]
|
||||
AgentPresets AgentPreset[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
@@ -366,14 +366,13 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
credentials={},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
@@ -455,16 +454,15 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
credentials={},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={"selected_value": "key1"},
|
||||
inputs={"selected_value": "key1"},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=
|
||||
NEXT_PUBLIC_APP_ENV=local
|
||||
|
||||
NEXT_PUBLIC_AGPT_SERVER_BASE_URL=http://localhost:8006
|
||||
|
||||
## Locale settings
|
||||
|
||||
NEXT_PUBLIC_DEFAULT_LOCALE=en
|
||||
@@ -33,3 +35,6 @@ NEXT_PUBLIC_SHOW_BILLING_PAGE=false
|
||||
## This is the frontend site key
|
||||
NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
|
||||
# Devtools
|
||||
NEXT_PUBLIC_REACT_QUERY_DEVTOOL=true
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
"extends": [
|
||||
"next/core-web-vitals",
|
||||
"next/typescript",
|
||||
"plugin:storybook/recommended"
|
||||
"plugin:storybook/recommended",
|
||||
"plugin:@tanstack/query/recommended"
|
||||
],
|
||||
"rules": {
|
||||
// Disabling exhaustive-deps to avoid forcing unnecessary dependencies and useCallback proliferation.
|
||||
|
||||
@@ -156,3 +156,9 @@ By integrating Storybook into our development workflow, we can streamline UI dev
|
||||
- [**Zod**](https://zod.dev/) - TypeScript-first schema validation
|
||||
- [**React Table**](https://tanstack.com/table) - Headless table library
|
||||
- [**React Flow**](https://reactflow.dev/) - Interactive node-based diagrams
|
||||
- [**React Query**](https://tanstack.com/query/latest/docs/framework/react/overview) - Data fetching and caching
|
||||
- [**React Query DevTools**](https://tanstack.com/query/latest/docs/framework/react/devtools) - Debugging tool for React Query
|
||||
|
||||
### Development Tools
|
||||
|
||||
- `NEXT_PUBLIC_REACT_QUERY_DEVTOOL` - Enable React Query DevTools. Set to `true` to enable.
|
||||
|
||||
59
autogpt_platform/frontend/orval.config.ts
Normal file
59
autogpt_platform/frontend/orval.config.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { defineConfig } from "orval";
|
||||
|
||||
export default defineConfig({
|
||||
autogpt_api_client: {
|
||||
input: {
|
||||
target: `./src/api/openapi.json`,
|
||||
override: {
|
||||
transformer: "./src/api/transformers/fix-tags.mjs",
|
||||
},
|
||||
},
|
||||
output: {
|
||||
workspace: "./src/api",
|
||||
target: `./__generated__/endpoints`,
|
||||
schemas: "./__generated__/models",
|
||||
mode: "tags-split",
|
||||
client: "react-query",
|
||||
httpClient: "fetch",
|
||||
indexFiles: false,
|
||||
mock: {
|
||||
type: "msw",
|
||||
delay: 1000, // artifical latency
|
||||
generateEachHttpStatus: true, // helps us test error-handling scenarios and generate mocks for all HTTP statuses
|
||||
},
|
||||
override: {
|
||||
mutator: {
|
||||
path: "./mutators/custom-mutator.ts",
|
||||
name: "customMutator",
|
||||
},
|
||||
query: {
|
||||
useQuery: true,
|
||||
useMutation: true,
|
||||
// Will add more as their use cases arise
|
||||
},
|
||||
},
|
||||
},
|
||||
hooks: {
|
||||
afterAllFilesWrite: "prettier --write",
|
||||
},
|
||||
},
|
||||
autogpt_zod_schema: {
|
||||
input: {
|
||||
target: `./src/api/openapi.json`,
|
||||
override: {
|
||||
transformer: "./src/api/transformers/fix-tags.mjs",
|
||||
},
|
||||
},
|
||||
output: {
|
||||
workspace: "./src/api",
|
||||
target: `./__generated__/zod-schema`,
|
||||
schemas: "./__generated__/models",
|
||||
mode: "tags-split",
|
||||
client: "zod",
|
||||
indexFiles: false,
|
||||
},
|
||||
hooks: {
|
||||
afterAllFilesWrite: "prettier --write",
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -5,7 +5,7 @@
|
||||
"scripts": {
|
||||
"dev": "next dev --turbo",
|
||||
"dev:test": "NODE_ENV=test && next dev --turbo",
|
||||
"build": "SKIP_STORYBOOK_TESTS=true next build",
|
||||
"build": "pnpm run generate:api-client && SKIP_STORYBOOK_TESTS=true next build",
|
||||
"start": "next start",
|
||||
"start:standalone": "cd .next/standalone && node server.js",
|
||||
"lint": "next lint && prettier --check .",
|
||||
@@ -18,7 +18,10 @@
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
"test-storybook": "test-storybook",
|
||||
"test-storybook:ci": "concurrently -k -s first -n \"SB,TEST\" -c \"magenta,blue\" \"pnpm run build-storybook -- --quiet && npx http-server storybook-static --port 6006 --silent\" \"wait-on tcp:6006 && pnpm run test-storybook\""
|
||||
"test-storybook:ci": "concurrently -k -s first -n \"SB,TEST\" -c \"magenta,blue\" \"pnpm run build-storybook -- --quiet && npx http-server storybook-static --port 6006 --silent\" \"wait-on tcp:6006 && pnpm run test-storybook\"",
|
||||
"fetch:openapi": "curl http://localhost:8006/openapi.json > ./src/api/openapi.json && prettier --write ./src/api/openapi.json",
|
||||
"generate:api-client": "orval --config ./orval.config.ts",
|
||||
"generate:api-all": "pnpm run fetch:openapi && pnpm run generate:api-client"
|
||||
},
|
||||
"browserslist": [
|
||||
"defaults"
|
||||
@@ -98,6 +101,8 @@
|
||||
"@storybook/addon-links": "9.0.12",
|
||||
"@storybook/addon-onboarding": "9.0.12",
|
||||
"@storybook/nextjs": "9.0.12",
|
||||
"@tanstack/eslint-plugin-query": "5.78.0",
|
||||
"@tanstack/react-query-devtools": "5.80.10",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.18",
|
||||
"@types/negotiator": "0.6.4",
|
||||
@@ -114,6 +119,7 @@
|
||||
"import-in-the-middle": "1.14.2",
|
||||
"msw": "2.10.2",
|
||||
"msw-storybook-addon": "2.0.5",
|
||||
"orval": "7.10.0",
|
||||
"postcss": "8.5.6",
|
||||
"prettier": "3.5.3",
|
||||
"prettier-plugin-tailwindcss": "0.6.12",
|
||||
|
||||
1256
autogpt_platform/frontend/pnpm-lock.yaml
generated
1256
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
81
autogpt_platform/frontend/src/api/mutators/custom-mutator.ts
Normal file
81
autogpt_platform/frontend/src/api/mutators/custom-mutator.ts
Normal file
@@ -0,0 +1,81 @@
|
||||
import { getSupabaseClient } from "@/lib/supabase/getSupabaseClient";
|
||||
|
||||
const BASE_URL =
|
||||
process.env.NEXT_PUBLIC_AGPT_SERVER_BASE_URL || "http://localhost:8006";
|
||||
|
||||
const getBody = <T>(c: Response | Request): Promise<T> => {
|
||||
const contentType = c.headers.get("content-type");
|
||||
|
||||
if (contentType && contentType.includes("application/json")) {
|
||||
return c.json();
|
||||
}
|
||||
|
||||
if (contentType && contentType.includes("application/pdf")) {
|
||||
return c.blob() as Promise<T>;
|
||||
}
|
||||
|
||||
return c.text() as Promise<T>;
|
||||
};
|
||||
|
||||
const getSupabaseToken = async () => {
|
||||
const supabase = await getSupabaseClient();
|
||||
|
||||
const {
|
||||
data: { session },
|
||||
} = (await supabase?.auth.getSession()) || {
|
||||
data: { session: null },
|
||||
};
|
||||
|
||||
return session?.access_token;
|
||||
};
|
||||
|
||||
export const customMutator = async <T = any>(
|
||||
url: string,
|
||||
options: RequestInit & {
|
||||
params?: any;
|
||||
} = {},
|
||||
): Promise<T> => {
|
||||
const { params, ...requestOptions } = options;
|
||||
const method = (requestOptions.method || "GET") as
|
||||
| "GET"
|
||||
| "POST"
|
||||
| "PUT"
|
||||
| "DELETE"
|
||||
| "PATCH";
|
||||
const data = requestOptions.body;
|
||||
const headers: Record<string, string> = {
|
||||
...((requestOptions.headers as Record<string, string>) || {}),
|
||||
};
|
||||
|
||||
const token = await getSupabaseToken();
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const isFormData = data instanceof FormData;
|
||||
|
||||
// Currently, only two content types are handled here: application/json and multipart/form-data
|
||||
if (!isFormData && data && !headers["Content-Type"]) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
}
|
||||
|
||||
const queryString = params
|
||||
? "?" + new URLSearchParams(params).toString()
|
||||
: "";
|
||||
|
||||
const response = await fetch(`${BASE_URL}${url}${queryString}`, {
|
||||
...requestOptions,
|
||||
method,
|
||||
headers,
|
||||
body: data,
|
||||
});
|
||||
|
||||
const response_data = await getBody<T>(response);
|
||||
|
||||
return {
|
||||
status: response.status,
|
||||
response_data,
|
||||
headers: response.headers,
|
||||
} as T;
|
||||
};
|
||||
6093
autogpt_platform/frontend/src/api/openapi.json
Normal file
6093
autogpt_platform/frontend/src/api/openapi.json
Normal file
File diff suppressed because it is too large
Load Diff
57
autogpt_platform/frontend/src/api/transformers/fix-tags.mjs
Normal file
57
autogpt_platform/frontend/src/api/transformers/fix-tags.mjs
Normal file
@@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Transformer function for orval that fixes tags in OpenAPI spec.
|
||||
* 1. Create a set of tags so we have unique values
|
||||
* 2. Then remove public, private, v1, and v2 tags from tags array
|
||||
* 3. Then arrange remaining tags alphabetically and only keep the first one
|
||||
*
|
||||
* @param {OpenAPIObject} inputSchema
|
||||
* @return {OpenAPIObject}
|
||||
*/
|
||||
|
||||
export const tagTransformer = (inputSchema) => {
|
||||
const processedPaths = Object.entries(inputSchema.paths || {}).reduce(
|
||||
(acc, [path, pathItem]) => ({
|
||||
...acc,
|
||||
[path]: Object.entries(pathItem || {}).reduce(
|
||||
(pathItemAcc, [verb, operation]) => {
|
||||
if (typeof operation === "object" && operation !== null) {
|
||||
// 1. Create a set of tags so we have unique values
|
||||
const uniqueTags = Array.from(new Set(operation.tags || []));
|
||||
|
||||
// 2. Remove public, private, v1, and v2 tags from tags array
|
||||
const filteredTags = uniqueTags.filter(
|
||||
(tag) =>
|
||||
!["public", "private"].includes(tag.toLowerCase()) &&
|
||||
!/^v[12]$/i.test(tag),
|
||||
);
|
||||
|
||||
// 3. Arrange tags alphabetically and only keep the first one
|
||||
const sortedTags = filteredTags.sort((a, b) => a.localeCompare(b));
|
||||
const firstTag = sortedTags.length > 0 ? [sortedTags[0]] : [];
|
||||
|
||||
return {
|
||||
...pathItemAcc,
|
||||
[verb]: {
|
||||
...operation,
|
||||
tags: firstTag,
|
||||
},
|
||||
};
|
||||
}
|
||||
return {
|
||||
...pathItemAcc,
|
||||
[verb]: operation,
|
||||
};
|
||||
},
|
||||
{},
|
||||
),
|
||||
}),
|
||||
{},
|
||||
);
|
||||
|
||||
return {
|
||||
...inputSchema,
|
||||
paths: processedPaths,
|
||||
};
|
||||
};
|
||||
|
||||
export default tagTransformer;
|
||||
@@ -20,6 +20,8 @@ import {
|
||||
LibraryAgentID,
|
||||
Schedule,
|
||||
ScheduleID,
|
||||
LibraryAgentPreset,
|
||||
LibraryAgentPresetID,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
|
||||
import type { ButtonAction } from "@/components/agptui/types";
|
||||
@@ -52,9 +54,11 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
|
||||
const [agent, setAgent] = useState<LibraryAgent | null>(null);
|
||||
const [agentRuns, setAgentRuns] = useState<GraphExecutionMeta[]>([]);
|
||||
const [agentPresets, setAgentPresets] = useState<LibraryAgentPreset[]>([]);
|
||||
const [schedules, setSchedules] = useState<Schedule[]>([]);
|
||||
const [selectedView, selectView] = useState<
|
||||
| { type: "run"; id?: GraphExecutionID }
|
||||
| { type: "preset"; id: LibraryAgentPresetID }
|
||||
| { type: "schedule"; id: ScheduleID }
|
||||
>({ type: "run" });
|
||||
const [selectedRun, setSelectedRun] = useState<
|
||||
@@ -68,6 +72,8 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
useState<boolean>(false);
|
||||
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
|
||||
useState<GraphExecutionMeta | null>(null);
|
||||
const [confirmingDeleteAgentPreset, setConfirmingDeleteAgentPreset] =
|
||||
useState<LibraryAgentPresetID | null>(null);
|
||||
const {
|
||||
state: onboardingState,
|
||||
updateState: updateOnboardingState,
|
||||
@@ -90,6 +96,10 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
selectView({ type: "run", id });
|
||||
}, []);
|
||||
|
||||
const selectPreset = useCallback((id: LibraryAgentPresetID) => {
|
||||
selectView({ type: "preset", id });
|
||||
}, []);
|
||||
|
||||
const selectSchedule = useCallback((schedule: Schedule) => {
|
||||
selectView({ type: "schedule", id: schedule.id });
|
||||
setSelectedSchedule(schedule);
|
||||
@@ -143,12 +153,19 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
(_graph) =>
|
||||
(graph && graph.version == _graph.version) || setGraph(_graph),
|
||||
);
|
||||
api.getGraphExecutions(agent.graph_id).then((agentRuns) => {
|
||||
setAgentRuns(agentRuns);
|
||||
Promise.all([
|
||||
api.getGraphExecutions(agent.graph_id),
|
||||
api.listLibraryAgentPresets({
|
||||
graph_id: agent.graph_id,
|
||||
page_size: 100,
|
||||
}),
|
||||
]).then(([runs, presets]) => {
|
||||
setAgentRuns(runs);
|
||||
setAgentPresets(presets.presets);
|
||||
|
||||
// Preload the corresponding graph versions
|
||||
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
|
||||
getGraphVersion(agent.graph_id, version),
|
||||
// Preload the corresponding graph versions for the latest 10 runs
|
||||
new Set(runs.slice(0, 10).map((run) => run.graph_version)).forEach(
|
||||
(version) => getGraphVersion(agent.graph_id, version),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -157,16 +174,33 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
// On first load: select the latest run
|
||||
useEffect(() => {
|
||||
// Only for first load or first execution
|
||||
if (selectedView.id || !isFirstLoad || agentRuns.length == 0) return;
|
||||
setIsFirstLoad(false);
|
||||
if (selectedView.id || !isFirstLoad) return;
|
||||
if (agentRuns.length == 0 && agentPresets.length == 0) return;
|
||||
|
||||
const latestRun = agentRuns.reduce((latest, current) => {
|
||||
if (latest.started_at && !current.started_at) return current;
|
||||
else if (!latest.started_at) return latest;
|
||||
return latest.started_at > current.started_at ? latest : current;
|
||||
}, agentRuns[0]);
|
||||
selectView({ type: "run", id: latestRun.id });
|
||||
}, [agentRuns, isFirstLoad, selectedView.id, selectView]);
|
||||
setIsFirstLoad(false);
|
||||
if (agentRuns.length > 0) {
|
||||
// select latest run
|
||||
const latestRun = agentRuns.reduce((latest, current) => {
|
||||
if (latest.started_at && !current.started_at) return current;
|
||||
else if (!latest.started_at) return latest;
|
||||
return latest.started_at > current.started_at ? latest : current;
|
||||
}, agentRuns[0]);
|
||||
selectRun(latestRun.id);
|
||||
} else {
|
||||
// select top preset
|
||||
const latestPreset = agentPresets.toSorted(
|
||||
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
|
||||
)[0];
|
||||
selectPreset(latestPreset.id);
|
||||
}
|
||||
}, [
|
||||
isFirstLoad,
|
||||
selectedView.id,
|
||||
agentRuns,
|
||||
agentPresets,
|
||||
selectRun,
|
||||
selectPreset,
|
||||
]);
|
||||
|
||||
// Initial load
|
||||
useEffect(() => {
|
||||
@@ -304,9 +338,22 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
if (selectedView.type == "run" && selectedView.id == run.id) {
|
||||
openRunDraftView();
|
||||
}
|
||||
setAgentRuns(agentRuns.filter((r) => r.id !== run.id));
|
||||
setAgentRuns((runs) => runs.filter((r) => r.id !== run.id));
|
||||
},
|
||||
[agentRuns, api, selectedView, openRunDraftView],
|
||||
[api, selectedView, openRunDraftView],
|
||||
);
|
||||
|
||||
const deletePreset = useCallback(
|
||||
async (presetID: LibraryAgentPresetID) => {
|
||||
await api.deleteLibraryAgentPreset(presetID);
|
||||
|
||||
setConfirmingDeleteAgentPreset(null);
|
||||
if (selectedView.type == "preset" && selectedView.id == presetID) {
|
||||
openRunDraftView();
|
||||
}
|
||||
setAgentPresets((presets) => presets.filter((p) => p.id !== presetID));
|
||||
},
|
||||
[api, selectedView, openRunDraftView],
|
||||
);
|
||||
|
||||
const deleteSchedule = useCallback(
|
||||
@@ -370,11 +417,22 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
[agent, downloadGraph],
|
||||
);
|
||||
|
||||
const onRun = useCallback(
|
||||
(runID: GraphExecutionID) => {
|
||||
selectRun(runID);
|
||||
const onCreatePreset = useCallback(
|
||||
(preset: LibraryAgentPreset) => {
|
||||
setAgentPresets((prev) => [...prev, preset]);
|
||||
selectPreset(preset.id);
|
||||
},
|
||||
[selectRun],
|
||||
[selectPreset],
|
||||
);
|
||||
|
||||
const onUpdatePreset = useCallback(
|
||||
(updated: LibraryAgentPreset) => {
|
||||
setAgentPresets((prev) =>
|
||||
prev.map((p) => (p.id === updated.id ? updated : p)),
|
||||
);
|
||||
selectPreset(updated.id);
|
||||
},
|
||||
[selectPreset],
|
||||
);
|
||||
|
||||
if (!agent || !graph) {
|
||||
@@ -389,14 +447,16 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
className="agpt-div w-full border-b lg:w-auto lg:border-b-0 lg:border-r"
|
||||
agent={agent}
|
||||
agentRuns={agentRuns}
|
||||
agentPresets={agentPresets}
|
||||
schedules={schedules}
|
||||
selectedView={selectedView}
|
||||
allowDraftNewRun={!graph.has_webhook_trigger}
|
||||
onSelectRun={selectRun}
|
||||
onSelectPreset={selectPreset}
|
||||
onSelectSchedule={selectSchedule}
|
||||
onSelectDraftNewRun={openRunDraftView}
|
||||
onDeleteRun={setConfirmingDeleteAgentRun}
|
||||
onDeleteSchedule={(id) => deleteSchedule(id)}
|
||||
onDeletePreset={setConfirmingDeleteAgentPreset}
|
||||
onDeleteSchedule={deleteSchedule}
|
||||
/>
|
||||
|
||||
<div className="flex-1">
|
||||
@@ -417,14 +477,28 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
graph={graphVersions.current[selectedRun.graph_version] ?? graph}
|
||||
run={selectedRun}
|
||||
agentActions={agentActions}
|
||||
onRun={onRun}
|
||||
onRun={selectRun}
|
||||
deleteRun={() => setConfirmingDeleteAgentRun(selectedRun)}
|
||||
/>
|
||||
)
|
||||
) : selectedView.type == "run" ? (
|
||||
/* Draft new runs / Create new presets */
|
||||
<AgentRunDraftView
|
||||
graph={graph}
|
||||
onRun={onRun}
|
||||
agent={agent}
|
||||
onRun={selectRun}
|
||||
onCreatePreset={onCreatePreset}
|
||||
agentActions={agentActions}
|
||||
/>
|
||||
) : selectedView.type == "preset" ? (
|
||||
/* Edit & update presets */
|
||||
<AgentRunDraftView
|
||||
agent={agent}
|
||||
agentPreset={
|
||||
agentPresets.find((preset) => preset.id == selectedView.id)!
|
||||
}
|
||||
onRun={selectRun}
|
||||
onUpdatePreset={onUpdatePreset}
|
||||
doDeletePreset={setConfirmingDeleteAgentPreset}
|
||||
agentActions={agentActions}
|
||||
/>
|
||||
) : selectedView.type == "schedule" ? (
|
||||
@@ -432,7 +506,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
<AgentScheduleDetailsView
|
||||
graph={graph}
|
||||
schedule={selectedSchedule}
|
||||
onForcedRun={onRun}
|
||||
onForcedRun={selectRun}
|
||||
agentActions={agentActions}
|
||||
/>
|
||||
)
|
||||
@@ -458,6 +532,15 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
confirmingDeleteAgentRun && deleteRun(confirmingDeleteAgentRun)
|
||||
}
|
||||
/>
|
||||
<DeleteConfirmDialog
|
||||
entityType={agent.has_external_trigger ? "trigger" : "agent preset"}
|
||||
open={!!confirmingDeleteAgentPreset}
|
||||
onOpenChange={(open) => !open && setConfirmingDeleteAgentPreset(null)}
|
||||
onDoDelete={() =>
|
||||
confirmingDeleteAgentPreset &&
|
||||
deletePreset(confirmingDeleteAgentPreset)
|
||||
}
|
||||
/>
|
||||
{/* Copy agent confirmation dialog */}
|
||||
<Dialog
|
||||
onOpenChange={setCopyAgentDialogOpen}
|
||||
|
||||
@@ -8,6 +8,7 @@ import { Toaster } from "@/components/ui/toaster";
|
||||
import { Providers } from "@/app/providers";
|
||||
import TallyPopupSimple from "@/components/TallyPopup";
|
||||
import { GoogleAnalytics } from "@/components/analytics/google-analytics";
|
||||
import { ReactQueryDevtools } from "@tanstack/react-query-devtools";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "AutoGPT Platform",
|
||||
@@ -41,6 +42,14 @@ export default async function RootLayout({
|
||||
<div className="flex min-h-screen flex-col items-stretch justify-items-stretch">
|
||||
{children}
|
||||
<TallyPopupSimple />
|
||||
|
||||
{/* React Query DevTools is only available in development */}
|
||||
{process.env.NEXT_PUBLIC_REACT_QUERY_DEVTOOL && (
|
||||
<ReactQueryDevtools
|
||||
initialIsOpen={false}
|
||||
buttonPosition={"bottom-left"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Toaster />
|
||||
</Providers>
|
||||
|
||||
@@ -268,7 +268,7 @@ export const CustomNode = React.memo(
|
||||
|
||||
default:
|
||||
const getInputPropKey = (key: string) =>
|
||||
nodeType == BlockUIType.AGENT ? `data.${key}` : key;
|
||||
nodeType == BlockUIType.AGENT ? `inputs.${key}` : key;
|
||||
|
||||
return keys.map(([propKey, propSchema]) => {
|
||||
const isRequired = data.inputSchema.required?.includes(propKey);
|
||||
|
||||
@@ -26,15 +26,23 @@ export default function NodeOutputs({
|
||||
<div className="mt-2">
|
||||
<strong className="mr-2">Data:</strong>
|
||||
<div className="mt-1">
|
||||
{dataArray.map((item, index) => (
|
||||
{dataArray.slice(0, 10).map((item, index) => (
|
||||
<React.Fragment key={index}>
|
||||
<ContentRenderer
|
||||
value={item}
|
||||
truncateLongData={truncateLongData}
|
||||
/>
|
||||
{index < dataArray.length - 1 && ", "}
|
||||
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||
</React.Fragment>
|
||||
))}
|
||||
{dataArray.length > 10 && (
|
||||
<span style={{ color: "#888" }}>
|
||||
<br />
|
||||
<b>⋮</b>
|
||||
<br />
|
||||
<span>and {dataArray.length - 10} more</span>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<Separator.Root className="my-4 h-[1px] bg-gray-300" />
|
||||
</div>
|
||||
|
||||
@@ -1,73 +1,434 @@
|
||||
"use client";
|
||||
import React, { useCallback, useMemo, useState } from "react";
|
||||
import React, { useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { GraphExecutionID, GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionID,
|
||||
LibraryAgent,
|
||||
LibraryAgentPreset,
|
||||
LibraryAgentPresetID,
|
||||
LibraryAgentPresetUpdatable,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
|
||||
import type { ButtonAction } from "@/components/agptui/types";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { IconCross, IconPlay, IconSave } from "@/components/ui/icons";
|
||||
import { CredentialsInput } from "@/components/integrations/credentials-input";
|
||||
import { TypeBasedInput } from "@/components/type-based-input";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
import ActionButtonGroup from "@/components/agptui/action-button-group";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
import { Trash2Icon } from "lucide-react";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import { IconPlay } from "@/components/ui/icons";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { isEmpty } from "lodash";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
||||
export default function AgentRunDraftView({
|
||||
graph,
|
||||
agent,
|
||||
agentPreset,
|
||||
onRun,
|
||||
onCreatePreset,
|
||||
onUpdatePreset,
|
||||
doDeletePreset,
|
||||
agentActions,
|
||||
}: {
|
||||
graph: GraphMeta;
|
||||
onRun: (runID: GraphExecutionID) => void;
|
||||
agent: LibraryAgent;
|
||||
agentActions: ButtonAction[];
|
||||
}): React.ReactNode {
|
||||
onRun: (runID: GraphExecutionID) => void;
|
||||
} & (
|
||||
| {
|
||||
onCreatePreset: (preset: LibraryAgentPreset) => void;
|
||||
agentPreset?: never;
|
||||
onUpdatePreset?: never;
|
||||
doDeletePreset?: never;
|
||||
}
|
||||
| {
|
||||
onCreatePreset?: never;
|
||||
agentPreset: LibraryAgentPreset;
|
||||
onUpdatePreset: (preset: LibraryAgentPreset) => void;
|
||||
doDeletePreset: (presetID: LibraryAgentPresetID) => void;
|
||||
}
|
||||
)): React.ReactNode {
|
||||
const api = useBackendAPI();
|
||||
const { toast } = useToast();
|
||||
const toastOnFail = useToastOnFail();
|
||||
|
||||
const agentInputs = graph.input_schema.properties;
|
||||
const agentCredentialsInputs = graph.credentials_input_schema.properties;
|
||||
const [inputValues, setInputValues] = useState<Record<string, any>>({});
|
||||
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
|
||||
{},
|
||||
const [inputCredentials, setInputCredentials] = useState<
|
||||
Record<string, CredentialsMetaInput>
|
||||
>({});
|
||||
const [presetName, setPresetName] = useState<string>("");
|
||||
const [presetDescription, setPresetDescription] = useState<string>("");
|
||||
const [changedPresetAttributes, setChangedPresetAttributes] = useState<
|
||||
Set<keyof LibraryAgentPresetUpdatable>
|
||||
>(new Set());
|
||||
const { state: onboardingState, completeStep: completeOnboardingStep } =
|
||||
useOnboarding();
|
||||
|
||||
// Update values if agentPreset parameter is changed
|
||||
useEffect(() => {
|
||||
setInputValues(agentPreset?.inputs ?? {});
|
||||
setInputCredentials(agentPreset?.credentials ?? {});
|
||||
setPresetName(agentPreset?.name ?? "");
|
||||
setPresetDescription(agentPreset?.description ?? "");
|
||||
setChangedPresetAttributes(new Set());
|
||||
}, [agentPreset]);
|
||||
|
||||
const agentInputSchema = useMemo(
|
||||
() =>
|
||||
agent.has_external_trigger
|
||||
? agent.trigger_setup_info.config_schema
|
||||
: agent.input_schema,
|
||||
[agent],
|
||||
);
|
||||
const agentInputFields = useMemo(
|
||||
() =>
|
||||
Object.fromEntries(
|
||||
Object.entries(agentInputSchema.properties).filter(
|
||||
([_, subSchema]) => !subSchema.hidden,
|
||||
),
|
||||
),
|
||||
[agentInputSchema],
|
||||
);
|
||||
const agentCredentialsInputFields = useMemo(
|
||||
() => agent.credentials_input_schema.properties,
|
||||
[agent],
|
||||
);
|
||||
|
||||
const [allRequiredInputsAreSet, missingInputs] = useMemo(() => {
|
||||
const nonEmptyInputs = new Set(
|
||||
Object.keys(inputValues).filter((k) => !isEmpty(inputValues[k])),
|
||||
);
|
||||
const requiredInputs = new Set(
|
||||
agentInputSchema.required as string[] | undefined,
|
||||
);
|
||||
return [
|
||||
nonEmptyInputs.isSupersetOf(requiredInputs),
|
||||
[...requiredInputs.difference(nonEmptyInputs)],
|
||||
];
|
||||
}, [agentInputSchema.required, inputValues]);
|
||||
const [allCredentialsAreSet, missingCredentials] = useMemo(() => {
|
||||
const availableCredentials = new Set(Object.keys(inputCredentials));
|
||||
const allCredentials = new Set(Object.keys(agentCredentialsInputFields));
|
||||
return [
|
||||
availableCredentials.isSupersetOf(allCredentials),
|
||||
[...allCredentials.difference(availableCredentials)],
|
||||
];
|
||||
}, [agentCredentialsInputFields, inputCredentials]);
|
||||
const notifyMissingInputs = useCallback(
|
||||
(needPresetName: boolean = true) => {
|
||||
const allMissingFields = (
|
||||
needPresetName && !presetName
|
||||
? [agent.has_external_trigger ? "trigger_name" : "preset_name"]
|
||||
: []
|
||||
)
|
||||
.concat(missingInputs)
|
||||
.concat(missingCredentials);
|
||||
toast({
|
||||
title: "⚠️ Not all required inputs are set",
|
||||
description: `Please set ${allMissingFields.map((k) => `\`${k}\``).join(", ")}`,
|
||||
});
|
||||
},
|
||||
[missingInputs, missingCredentials],
|
||||
);
|
||||
const { state, completeStep } = useOnboarding();
|
||||
|
||||
const doRun = useCallback(() => {
|
||||
api
|
||||
.executeGraph(graph.id, graph.version, inputValues, inputCredentials)
|
||||
.then((newRun) => onRun(newRun.graph_exec_id))
|
||||
.catch(toastOnFail("execute agent"));
|
||||
// Manually running webhook-triggered agents is not supported
|
||||
if (agent.has_external_trigger) return;
|
||||
|
||||
if (!agentPreset || changedPresetAttributes.size > 0) {
|
||||
if (!allRequiredInputsAreSet || !allCredentialsAreSet) {
|
||||
notifyMissingInputs(false);
|
||||
return;
|
||||
}
|
||||
// TODO: on executing preset with changes, ask for confirmation and offer save+run
|
||||
api
|
||||
.executeGraph(
|
||||
agent.graph_id,
|
||||
agent.graph_version,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
)
|
||||
.then((newRun) => onRun(newRun.graph_exec_id))
|
||||
.catch(toastOnFail("execute agent"));
|
||||
} else {
|
||||
api
|
||||
.executeLibraryAgentPreset(agentPreset.id)
|
||||
.then((newRun) => onRun(newRun.id))
|
||||
.catch(toastOnFail("execute agent preset"));
|
||||
}
|
||||
// Mark run agent onboarding step as completed
|
||||
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
|
||||
completeStep("MARKETPLACE_RUN_AGENT");
|
||||
if (onboardingState?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
|
||||
completeOnboardingStep("MARKETPLACE_RUN_AGENT");
|
||||
}
|
||||
}, [
|
||||
api,
|
||||
graph,
|
||||
agent,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onRun,
|
||||
toastOnFail,
|
||||
state,
|
||||
completeStep,
|
||||
onboardingState,
|
||||
completeOnboardingStep,
|
||||
]);
|
||||
|
||||
const doCreatePreset = useCallback(() => {
|
||||
if (!onCreatePreset) return;
|
||||
|
||||
if (!presetName || !allRequiredInputsAreSet || !allCredentialsAreSet) {
|
||||
notifyMissingInputs();
|
||||
return;
|
||||
}
|
||||
|
||||
api
|
||||
.createLibraryAgentPreset({
|
||||
name: presetName,
|
||||
description: presetDescription,
|
||||
graph_id: agent.graph_id,
|
||||
graph_version: agent.graph_version,
|
||||
inputs: inputValues,
|
||||
credentials: inputCredentials,
|
||||
})
|
||||
.then((newPreset) => {
|
||||
onCreatePreset(newPreset);
|
||||
setChangedPresetAttributes(new Set()); // reset change tracker
|
||||
})
|
||||
.catch(toastOnFail("save agent preset"));
|
||||
}, [
|
||||
api,
|
||||
agent,
|
||||
presetName,
|
||||
presetDescription,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onCreatePreset,
|
||||
toast,
|
||||
toastOnFail,
|
||||
onboardingState,
|
||||
completeOnboardingStep,
|
||||
]);
|
||||
|
||||
const doUpdatePreset = useCallback(() => {
|
||||
if (!agentPreset || changedPresetAttributes.size == 0) return;
|
||||
|
||||
if (!presetName || !allRequiredInputsAreSet || !allCredentialsAreSet) {
|
||||
notifyMissingInputs();
|
||||
return;
|
||||
}
|
||||
|
||||
const updatePreset: LibraryAgentPresetUpdatable = {};
|
||||
if (changedPresetAttributes.has("name")) updatePreset["name"] = presetName;
|
||||
if (changedPresetAttributes.has("description"))
|
||||
updatePreset["description"] = presetDescription;
|
||||
if (
|
||||
changedPresetAttributes.has("inputs") ||
|
||||
changedPresetAttributes.has("credentials")
|
||||
) {
|
||||
updatePreset["inputs"] = inputValues;
|
||||
updatePreset["credentials"] = inputCredentials;
|
||||
}
|
||||
api
|
||||
.updateLibraryAgentPreset(agentPreset.id, updatePreset)
|
||||
.then((updatedPreset) => {
|
||||
onUpdatePreset(updatedPreset);
|
||||
setChangedPresetAttributes(new Set()); // reset change tracker
|
||||
})
|
||||
.catch(toastOnFail("update agent preset"));
|
||||
}, [
|
||||
api,
|
||||
agent,
|
||||
presetName,
|
||||
presetDescription,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onUpdatePreset,
|
||||
toast,
|
||||
toastOnFail,
|
||||
onboardingState,
|
||||
completeOnboardingStep,
|
||||
]);
|
||||
|
||||
const doSetPresetActive = useCallback(
|
||||
async (active: boolean) => {
|
||||
if (!agentPreset) return;
|
||||
const updatedPreset = await api.updateLibraryAgentPreset(agentPreset.id, {
|
||||
is_active: active,
|
||||
});
|
||||
onUpdatePreset(updatedPreset);
|
||||
},
|
||||
[agentPreset, api, onUpdatePreset],
|
||||
);
|
||||
|
||||
const doSetupTrigger = useCallback(() => {
|
||||
// Setting up a trigger for non-webhook-triggered agents is not supported
|
||||
if (!agent.has_external_trigger || !onCreatePreset) return;
|
||||
|
||||
if (!presetName || !allRequiredInputsAreSet || !allCredentialsAreSet) {
|
||||
notifyMissingInputs();
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialsInputName =
|
||||
agent.trigger_setup_info.credentials_input_name;
|
||||
|
||||
if (!credentialsInputName) {
|
||||
// FIXME: implement support for manual-setup webhooks
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: "🚧 Feature under construction",
|
||||
description: "Setting up non-auto-setup triggers is not yet supported.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
api
|
||||
.setupAgentTrigger(agent.id, {
|
||||
name: presetName,
|
||||
description: presetDescription,
|
||||
trigger_config: inputValues,
|
||||
agent_credentials: inputCredentials,
|
||||
})
|
||||
.then((newPreset) => {
|
||||
onCreatePreset(newPreset);
|
||||
setChangedPresetAttributes(new Set()); // reset change tracker
|
||||
})
|
||||
.catch(toastOnFail("set up agent trigger"));
|
||||
|
||||
// Mark run agent onboarding step as completed(?)
|
||||
if (onboardingState?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
|
||||
completeOnboardingStep("MARKETPLACE_RUN_AGENT");
|
||||
}
|
||||
}, [
|
||||
api,
|
||||
agent,
|
||||
presetName,
|
||||
presetDescription,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onCreatePreset,
|
||||
toast,
|
||||
toastOnFail,
|
||||
onboardingState,
|
||||
completeOnboardingStep,
|
||||
]);
|
||||
|
||||
const runActions: ButtonAction[] = useMemo(
|
||||
() => [
|
||||
{
|
||||
label: (
|
||||
<>
|
||||
<IconPlay className="mr-2 size-5" />
|
||||
Run
|
||||
</>
|
||||
),
|
||||
variant: "accent",
|
||||
callback: doRun,
|
||||
},
|
||||
// "Regular" agent: [run] + [save as preset] buttons
|
||||
...(!agent.has_external_trigger
|
||||
? ([
|
||||
{
|
||||
label: (
|
||||
<>
|
||||
<IconPlay className="mr-2 size-4" /> Run
|
||||
</>
|
||||
),
|
||||
variant: "accent",
|
||||
callback: doRun,
|
||||
},
|
||||
// {
|
||||
// label: (
|
||||
// <>
|
||||
// <IconSave className="mr-2 size-4" /> Save as a preset
|
||||
// </>
|
||||
// ),
|
||||
// callback: doCreatePreset,
|
||||
// disabled: !(
|
||||
// presetName &&
|
||||
// allRequiredInputsAreSet &&
|
||||
// allCredentialsAreSet
|
||||
// ),
|
||||
// },
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
// Triggered agent: [setup] button
|
||||
...(agent.has_external_trigger && !agentPreset?.webhook_id
|
||||
? ([
|
||||
{
|
||||
label: (
|
||||
<>
|
||||
<IconPlay className="mr-2 size-4" /> Set up trigger
|
||||
</>
|
||||
),
|
||||
variant: "accent",
|
||||
callback: doSetupTrigger,
|
||||
disabled: !(
|
||||
presetName &&
|
||||
allRequiredInputsAreSet &&
|
||||
allCredentialsAreSet
|
||||
),
|
||||
},
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
// Existing agent trigger: [enable]/[disable] button
|
||||
...(agentPreset?.webhook_id
|
||||
? ([
|
||||
agentPreset.is_active
|
||||
? {
|
||||
label: (
|
||||
<>
|
||||
<IconCross className="mr-2.5 size-3.5" /> Disable trigger
|
||||
</>
|
||||
),
|
||||
variant: "destructive",
|
||||
callback: () => doSetPresetActive(false),
|
||||
}
|
||||
: {
|
||||
label: (
|
||||
<>
|
||||
<IconPlay className="mr-2 size-4" /> Enable trigger
|
||||
</>
|
||||
),
|
||||
variant: "accent",
|
||||
callback: () => doSetPresetActive(true),
|
||||
},
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
// Existing agent preset/trigger: [save] and [delete] buttons
|
||||
...(agentPreset
|
||||
? ([
|
||||
{
|
||||
label: (
|
||||
<>
|
||||
<IconSave className="mr-2 size-4" /> Save changes
|
||||
</>
|
||||
),
|
||||
callback: doUpdatePreset,
|
||||
disabled: !(
|
||||
changedPresetAttributes.size > 0 &&
|
||||
presetName &&
|
||||
allRequiredInputsAreSet &&
|
||||
allCredentialsAreSet
|
||||
),
|
||||
},
|
||||
{
|
||||
label: (
|
||||
<>
|
||||
<Trash2Icon className="mr-2 size-4" />
|
||||
Delete {agent.has_external_trigger ? "trigger" : "preset"}
|
||||
</>
|
||||
),
|
||||
callback: () => doDeletePreset(agentPreset.id),
|
||||
},
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
],
|
||||
[
|
||||
agent.has_external_trigger,
|
||||
agentPreset,
|
||||
api,
|
||||
doRun,
|
||||
doSetupTrigger,
|
||||
doCreatePreset,
|
||||
doUpdatePreset,
|
||||
doDeletePreset,
|
||||
changedPresetAttributes,
|
||||
presetName,
|
||||
allRequiredInputsAreSet,
|
||||
allCredentialsAreSet,
|
||||
],
|
||||
[doRun],
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -78,8 +439,49 @@ export default function AgentRunDraftView({
|
||||
<CardTitle className="font-poppins text-lg">Input</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="flex flex-col gap-4">
|
||||
{(agentPreset || agent.has_external_trigger) && (
|
||||
<>
|
||||
{/* Preset name and description */}
|
||||
<div className="flex flex-col space-y-2">
|
||||
<label className="flex items-center gap-1 text-sm font-medium">
|
||||
{agent.has_external_trigger ? "Trigger" : "Preset"} Name
|
||||
<SchemaTooltip
|
||||
description={`Name of the ${agent.has_external_trigger ? "trigger" : "preset"} you are setting up`}
|
||||
/>
|
||||
</label>
|
||||
<Input
|
||||
value={presetName}
|
||||
placeholder={`Enter ${agent.has_external_trigger ? "trigger" : "preset"} name`}
|
||||
onChange={(e) => {
|
||||
setPresetName(e.target.value);
|
||||
setChangedPresetAttributes((prev) => prev.add("name"));
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-2">
|
||||
<label className="flex items-center gap-1 text-sm font-medium">
|
||||
{agent.has_external_trigger ? "Trigger" : "Preset"}{" "}
|
||||
Description
|
||||
<SchemaTooltip
|
||||
description={`Description of the ${agent.has_external_trigger ? "trigger" : "preset"} you are setting up`}
|
||||
/>
|
||||
</label>
|
||||
<Input
|
||||
value={presetDescription}
|
||||
placeholder={`Enter ${agent.has_external_trigger ? "trigger" : "preset"} description`}
|
||||
onChange={(e) => {
|
||||
setPresetDescription(e.target.value);
|
||||
setChangedPresetAttributes((prev) =>
|
||||
prev.add("description"),
|
||||
);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Credentials inputs */}
|
||||
{Object.entries(agentCredentialsInputs).map(
|
||||
{Object.entries(agentCredentialsInputFields).map(
|
||||
([key, inputSubSchema]) => (
|
||||
<CredentialsInput
|
||||
key={key}
|
||||
@@ -87,18 +489,31 @@ export default function AgentRunDraftView({
|
||||
selectedCredentials={
|
||||
inputCredentials[key] ?? inputSubSchema.default
|
||||
}
|
||||
onSelectCredentials={(value) =>
|
||||
setInputCredentials((obj) => ({
|
||||
...obj,
|
||||
[key]: value,
|
||||
}))
|
||||
onSelectCredentials={(value) => {
|
||||
setInputCredentials((obj) => {
|
||||
const newObj = { ...obj };
|
||||
if (value === undefined) {
|
||||
delete newObj[key];
|
||||
return newObj;
|
||||
}
|
||||
return {
|
||||
...obj,
|
||||
[key]: value,
|
||||
};
|
||||
});
|
||||
setChangedPresetAttributes((prev) =>
|
||||
prev.add("credentials"),
|
||||
);
|
||||
}}
|
||||
hideIfSingleCredentialAvailable={
|
||||
!agentPreset && !agent.has_external_trigger
|
||||
}
|
||||
/>
|
||||
),
|
||||
)}
|
||||
|
||||
{/* Regular inputs */}
|
||||
{Object.entries(agentInputs).map(([key, inputSubSchema]) => (
|
||||
{Object.entries(agentInputFields).map(([key, inputSubSchema]) => (
|
||||
<div key={key} className="flex flex-col space-y-2">
|
||||
<label className="flex items-center gap-1 text-sm font-medium">
|
||||
{inputSubSchema.title || key}
|
||||
@@ -109,12 +524,13 @@ export default function AgentRunDraftView({
|
||||
schema={inputSubSchema}
|
||||
value={inputValues[key] ?? inputSubSchema.default}
|
||||
placeholder={inputSubSchema.description}
|
||||
onChange={(value) =>
|
||||
onChange={(value) => {
|
||||
setInputValues((obj) => ({
|
||||
...obj,
|
||||
[key]: value,
|
||||
}))
|
||||
}
|
||||
}));
|
||||
setChangedPresetAttributes((prev) => prev.add("inputs"));
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
@@ -125,7 +541,10 @@ export default function AgentRunDraftView({
|
||||
{/* Actions */}
|
||||
<aside className="w-48 xl:w-56">
|
||||
<div className="flex flex-col gap-8">
|
||||
<ActionButtonGroup title="Run actions" actions={runActions} />
|
||||
<ActionButtonGroup
|
||||
title={`${agent.has_external_trigger ? "Trigger" : agentPreset ? "Preset" : "Run"} actions`}
|
||||
actions={runActions}
|
||||
/>
|
||||
|
||||
<ActionButtonGroup title="Agent actions" actions={agentActions} />
|
||||
</div>
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import React from "react";
|
||||
import moment from "moment";
|
||||
import { MoreVertical } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Link2Icon, Link2OffIcon, MoreVertical } from "lucide-react";
|
||||
import { Card, CardContent } from "@/components/ui/card";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
@@ -16,11 +16,26 @@ import {
|
||||
import AgentRunStatusChip, {
|
||||
AgentRunStatus,
|
||||
} from "@/components/agents/agent-run-status-chip";
|
||||
import AgentStatusChip, {
|
||||
AgentStatus,
|
||||
} from "@/components/agents/agent-status-chip";
|
||||
|
||||
export type AgentRunSummaryProps = {
|
||||
status: AgentRunStatus;
|
||||
export type AgentRunSummaryProps = (
|
||||
| {
|
||||
type: "run";
|
||||
status: AgentRunStatus;
|
||||
}
|
||||
| {
|
||||
type: "preset";
|
||||
status: AgentStatus;
|
||||
}
|
||||
| {
|
||||
type: "schedule";
|
||||
status: "scheduled";
|
||||
}
|
||||
) & {
|
||||
title: string;
|
||||
timestamp: number | Date;
|
||||
timestamp?: number | Date;
|
||||
selected?: boolean;
|
||||
onClick?: () => void;
|
||||
// onRename: () => void;
|
||||
@@ -29,6 +44,7 @@ export type AgentRunSummaryProps = {
|
||||
};
|
||||
|
||||
export default function AgentRunSummaryCard({
|
||||
type,
|
||||
status,
|
||||
title,
|
||||
timestamp,
|
||||
@@ -48,7 +64,23 @@ export default function AgentRunSummaryCard({
|
||||
onClick={onClick}
|
||||
>
|
||||
<CardContent className="relative p-2.5 lg:p-4">
|
||||
<AgentRunStatusChip status={status} />
|
||||
{(type == "run" || type == "schedule") && (
|
||||
<AgentRunStatusChip status={status} />
|
||||
)}
|
||||
{type == "preset" && (
|
||||
<div className="flex items-center justify-between">
|
||||
<AgentStatusChip status={status} />
|
||||
|
||||
<div className="flex items-center text-sm text-zinc-400">
|
||||
{status == "inactive" ? (
|
||||
<Link2OffIcon className="mr-1 size-4" />
|
||||
) : (
|
||||
<Link2Icon className="mr-1 size-4" />
|
||||
)}{" "}
|
||||
Trigger
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mt-5 flex items-center justify-between">
|
||||
<h3 className="truncate pr-2 text-base font-medium text-neutral-900">
|
||||
@@ -75,12 +107,15 @@ export default function AgentRunSummaryCard({
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
|
||||
<p
|
||||
className="mt-1 text-sm font-normal text-neutral-500"
|
||||
title={moment(timestamp).toString()}
|
||||
>
|
||||
Ran {moment(timestamp).fromNow()}
|
||||
</p>
|
||||
{timestamp && (
|
||||
<p
|
||||
className="mt-1 text-sm font-normal text-neutral-500"
|
||||
title={moment(timestamp).toString()}
|
||||
>
|
||||
{moment(timestamp).isBefore() ? "Ran" : "Runs in"}{" "}
|
||||
{moment(timestamp).fromNow()}
|
||||
</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
|
||||
@@ -7,11 +7,14 @@ import {
|
||||
GraphExecutionID,
|
||||
GraphExecutionMeta,
|
||||
LibraryAgent,
|
||||
LibraryAgentPreset,
|
||||
LibraryAgentPresetID,
|
||||
Schedule,
|
||||
ScheduleID,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Button } from "@/components/agptui/Button";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
||||
@@ -21,13 +24,16 @@ import AgentRunSummaryCard from "@/components/agents/agent-run-summary-card";
|
||||
interface AgentRunsSelectorListProps {
|
||||
agent: LibraryAgent;
|
||||
agentRuns: GraphExecutionMeta[];
|
||||
agentPresets: LibraryAgentPreset[];
|
||||
schedules: Schedule[];
|
||||
selectedView: { type: "run" | "schedule"; id?: string };
|
||||
selectedView: { type: "run" | "preset" | "schedule"; id?: string };
|
||||
allowDraftNewRun?: boolean;
|
||||
onSelectRun: (id: GraphExecutionID) => void;
|
||||
onSelectPreset: (preset: LibraryAgentPresetID) => void;
|
||||
onSelectSchedule: (schedule: Schedule) => void;
|
||||
onSelectDraftNewRun: () => void;
|
||||
onDeleteRun: (id: GraphExecutionMeta) => void;
|
||||
onDeletePreset: (id: LibraryAgentPresetID) => void;
|
||||
onDeleteSchedule: (id: ScheduleID) => void;
|
||||
className?: string;
|
||||
}
|
||||
@@ -35,13 +41,16 @@ interface AgentRunsSelectorListProps {
|
||||
export default function AgentRunsSelectorList({
|
||||
agent,
|
||||
agentRuns,
|
||||
agentPresets,
|
||||
schedules,
|
||||
selectedView,
|
||||
allowDraftNewRun = true,
|
||||
onSelectRun,
|
||||
onSelectPreset,
|
||||
onSelectSchedule,
|
||||
onSelectDraftNewRun,
|
||||
onDeleteRun,
|
||||
onDeletePreset,
|
||||
onDeleteSchedule,
|
||||
className,
|
||||
}: AgentRunsSelectorListProps): React.ReactElement {
|
||||
@@ -49,6 +58,8 @@ export default function AgentRunsSelectorList({
|
||||
"runs",
|
||||
);
|
||||
|
||||
const listItemClasses = "h-28 w-72 lg:h-32 xl:w-80";
|
||||
|
||||
return (
|
||||
<aside className={cn("flex flex-col gap-4", className)}>
|
||||
{allowDraftNewRun && (
|
||||
@@ -63,7 +74,7 @@ export default function AgentRunsSelectorList({
|
||||
onClick={onSelectDraftNewRun}
|
||||
>
|
||||
<Plus className="h-6 w-6" />
|
||||
<span>New run</span>
|
||||
<span>New {agent.has_external_trigger ? "trigger" : "run"}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -105,41 +116,69 @@ export default function AgentRunsSelectorList({
|
||||
onClick={onSelectDraftNewRun}
|
||||
>
|
||||
<Plus className="h-6 w-6" />
|
||||
<span>New run</span>
|
||||
<span>New {agent.has_external_trigger ? "trigger" : "run"}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{activeListTab === "runs"
|
||||
? agentRuns
|
||||
{activeListTab === "runs" ? (
|
||||
<>
|
||||
{agentPresets
|
||||
.toSorted(
|
||||
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
|
||||
)
|
||||
.map((preset) => (
|
||||
<AgentRunSummaryCard
|
||||
className={cn(listItemClasses, "lg:h-auto")}
|
||||
key={preset.id}
|
||||
type="preset"
|
||||
status={preset.is_active ? "active" : "inactive"}
|
||||
title={preset.name}
|
||||
// timestamp={preset.last_run_time} // TODO: implement this
|
||||
selected={selectedView.id === preset.id}
|
||||
onClick={() => onSelectPreset(preset.id)}
|
||||
onDelete={() => onDeletePreset(preset.id)}
|
||||
/>
|
||||
))}
|
||||
{agentPresets.length > 0 && <Separator className="my-1" />}
|
||||
{agentRuns
|
||||
.toSorted(
|
||||
(a, b) => b.started_at.getTime() - a.started_at.getTime(),
|
||||
)
|
||||
.map((run) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
className={listItemClasses}
|
||||
key={run.id}
|
||||
type="run"
|
||||
status={agentRunStatusMap[run.status]}
|
||||
title={agent.name}
|
||||
title={
|
||||
(run.preset_id
|
||||
? agentPresets.find((p) => p.id == run.preset_id)?.name
|
||||
: null) ?? agent.name
|
||||
}
|
||||
timestamp={run.started_at}
|
||||
selected={selectedView.id === run.id}
|
||||
onClick={() => onSelectRun(run.id)}
|
||||
onDelete={() => onDeleteRun(run)}
|
||||
/>
|
||||
))
|
||||
: schedules
|
||||
.filter((schedule) => schedule.graph_id === agent.graph_id)
|
||||
.map((schedule) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
key={schedule.id}
|
||||
status="scheduled"
|
||||
title={schedule.name}
|
||||
timestamp={schedule.next_run_time}
|
||||
selected={selectedView.id === schedule.id}
|
||||
onClick={() => onSelectSchedule(schedule)}
|
||||
onDelete={() => onDeleteSchedule(schedule.id)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
schedules
|
||||
.filter((schedule) => schedule.graph_id === agent.graph_id)
|
||||
.map((schedule) => (
|
||||
<AgentRunSummaryCard
|
||||
className={listItemClasses}
|
||||
key={schedule.id}
|
||||
type="schedule"
|
||||
status="scheduled" // TODO: implement active/inactive status for schedules
|
||||
title={schedule.name}
|
||||
timestamp={schedule.next_run_time}
|
||||
selected={selectedView.id === schedule.id}
|
||||
onClick={() => onSelectSchedule(schedule)}
|
||||
onDelete={() => onDeleteSchedule(schedule.id)}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</aside>
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import React from "react";
|
||||
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
||||
export type AgentStatus = "active" | "inactive" | "error";
|
||||
|
||||
const statusData: Record<
|
||||
AgentStatus,
|
||||
{ label: string; variant: keyof typeof statusStyles }
|
||||
> = {
|
||||
active: { label: "Active", variant: "success" },
|
||||
error: { label: "Error", variant: "destructive" },
|
||||
inactive: { label: "Inactive", variant: "secondary" },
|
||||
};
|
||||
|
||||
const statusStyles = {
|
||||
success:
|
||||
"bg-green-100 text-green-800 hover:bg-green-100 hover:text-green-800",
|
||||
destructive: "bg-red-100 text-red-800 hover:bg-red-100 hover:text-red-800",
|
||||
warning:
|
||||
"bg-yellow-100 text-yellow-800 hover:bg-yellow-100 hover:text-yellow-800",
|
||||
info: "bg-blue-100 text-blue-800 hover:bg-blue-100 hover:text-blue-800",
|
||||
secondary:
|
||||
"bg-slate-100 text-slate-800 hover:bg-slate-100 hover:text-slate-800",
|
||||
};
|
||||
|
||||
export default function AgentStatusChip({
|
||||
status,
|
||||
}: {
|
||||
status: AgentStatus;
|
||||
}): React.ReactElement {
|
||||
return (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className={`text-xs font-medium ${statusStyles[statusData[status].variant]} rounded-[45px] px-[9px] py-[3px]`}
|
||||
>
|
||||
{statusData[status].label}
|
||||
</Badge>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { IconClose } from "../ui/icons";
|
||||
import { IconCross } from "../ui/icons";
|
||||
import Image from "next/image";
|
||||
import { Button } from "../agptui/Button";
|
||||
|
||||
@@ -50,7 +50,7 @@ export const PublishAgentAwaitingReview: React.FC<
|
||||
className="absolute right-4 top-4 flex h-[38px] w-[38px] items-center justify-center rounded-full bg-gray-100 transition-colors hover:bg-gray-200 dark:bg-neutral-700 dark:hover:bg-neutral-600"
|
||||
aria-label="Close dialog"
|
||||
>
|
||||
<IconClose
|
||||
<IconCross
|
||||
size="default"
|
||||
className="text-neutral-600 dark:text-neutral-300"
|
||||
/>
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import * as React from "react";
|
||||
import Image from "next/image";
|
||||
import { Button } from "../agptui/Button";
|
||||
import { IconClose } from "../ui/icons";
|
||||
import { IconCross } from "../ui/icons";
|
||||
|
||||
export interface Agent {
|
||||
name: string;
|
||||
@@ -56,7 +56,7 @@ export const PublishAgentSelect: React.FC<PublishAgentSelectProps> = ({
|
||||
className="flex h-8 w-8 items-center justify-center rounded-full bg-gray-100 transition-colors hover:bg-gray-200 dark:bg-gray-700 dark:hover:bg-gray-600"
|
||||
aria-label="Close"
|
||||
>
|
||||
<IconClose
|
||||
<IconCross
|
||||
size="default"
|
||||
className="text-neutral-600 dark:text-neutral-400"
|
||||
/>
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import * as React from "react";
|
||||
import Image from "next/image";
|
||||
import { Button } from "../agptui/Button";
|
||||
import { IconClose, IconPlus } from "../ui/icons";
|
||||
import { IconCross, IconPlus } from "../ui/icons";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { toast } from "../ui/use-toast";
|
||||
|
||||
@@ -180,7 +180,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
className="flex h-[38px] w-[38px] items-center justify-center rounded-full bg-gray-100 transition-colors hover:bg-gray-200 dark:bg-gray-700 dark:hover:bg-gray-600"
|
||||
aria-label="Close"
|
||||
>
|
||||
<IconClose
|
||||
<IconCross
|
||||
size="default"
|
||||
className="text-neutral-600 dark:text-neutral-300"
|
||||
/>
|
||||
@@ -313,7 +313,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
className="absolute right-1 top-1 flex h-5 w-5 items-center justify-center rounded-full bg-white bg-opacity-70 transition-opacity hover:bg-opacity-100 dark:bg-gray-800 dark:bg-opacity-70 dark:hover:bg-opacity-100"
|
||||
aria-label="Remove image"
|
||||
>
|
||||
<IconClose
|
||||
<IconCross
|
||||
size="sm"
|
||||
className="text-neutral-600 dark:text-neutral-300"
|
||||
/>
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { Input } from "./Input";
|
||||
|
||||
const meta: Meta<typeof Input> = {
|
||||
title: "Atoms/Input",
|
||||
tags: ["autodocs"],
|
||||
component: Input,
|
||||
parameters: {
|
||||
layout: "centered",
|
||||
docs: {
|
||||
description: {
|
||||
component:
|
||||
"Input component based on our design system. Built on top of shadcn/ui input with custom styling matching Figma designs.",
|
||||
},
|
||||
},
|
||||
},
|
||||
argTypes: {
|
||||
type: {
|
||||
control: "select",
|
||||
options: ["text", "email", "password", "number", "amount", "tel", "url"],
|
||||
description: "Input type",
|
||||
},
|
||||
placeholder: {
|
||||
control: "text",
|
||||
description: "Placeholder text",
|
||||
},
|
||||
value: {
|
||||
control: "text",
|
||||
description: "The value of the input",
|
||||
},
|
||||
label: {
|
||||
control: "text",
|
||||
description:
|
||||
"Label text (used as placeholder if no placeholder provided)",
|
||||
},
|
||||
disabled: {
|
||||
control: "boolean",
|
||||
description: "Disable the input",
|
||||
},
|
||||
hideLabel: {
|
||||
control: "boolean",
|
||||
description: "Hide the label",
|
||||
},
|
||||
decimalCount: {
|
||||
control: "number",
|
||||
description:
|
||||
"Number of decimal places allowed (only for amount type). Default is 4.",
|
||||
},
|
||||
error: {
|
||||
control: "text",
|
||||
description: "Error message to display below the input",
|
||||
},
|
||||
},
|
||||
args: {
|
||||
placeholder: "Enter text...",
|
||||
type: "text",
|
||||
value: "",
|
||||
disabled: false,
|
||||
hideLabel: false,
|
||||
decimalCount: 4,
|
||||
},
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
// Basic variants
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
placeholder: "Enter your text",
|
||||
label: "Full Name",
|
||||
},
|
||||
};
|
||||
|
||||
export const WithoutLabel: Story = {
|
||||
args: {
|
||||
label: "Full Name",
|
||||
hideLabel: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
placeholder: "This field is disabled",
|
||||
label: "Full Name",
|
||||
disabled: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithError: Story = {
|
||||
args: {
|
||||
label: "Email",
|
||||
type: "email",
|
||||
placeholder: "Enter your email",
|
||||
error: "Please enter a valid email address",
|
||||
},
|
||||
};
|
||||
|
||||
export const InputTypes: Story = {
|
||||
render: renderInputTypes,
|
||||
parameters: {
|
||||
controls: {
|
||||
disable: true,
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Complete showcase of all input types with their specific behaviors. Test each input type to verify filtering and formatting works correctly.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Render functions as function declarations
|
||||
function renderInputTypes() {
|
||||
return (
|
||||
<div className="w-full max-w-md space-y-8">
|
||||
<Input label="Full Name" type="text" placeholder="Enter your full name" />
|
||||
<Input label="Email" type="email" placeholder="your.email@example.com" />
|
||||
<Input
|
||||
label="Password"
|
||||
type="password"
|
||||
placeholder="Enter your password"
|
||||
/>
|
||||
<div className="flex flex-col gap-4">
|
||||
<p className="font-mono text-sm">
|
||||
If type="number" prop is provided, the input will allow only
|
||||
positive or negative numbers. No decimal limiting.
|
||||
</p>
|
||||
<Input label="Age" type="number" placeholder="Enter your age" />
|
||||
</div>
|
||||
<div className="flex flex-col gap-4">
|
||||
<p className="font-mono text-sm">
|
||||
If type="amount" prop is provided, it formats numbers with
|
||||
commas (1000 → 1,000) and limits decimals via decimalCount prop.
|
||||
</p>
|
||||
<Input
|
||||
label="Price"
|
||||
type="amount"
|
||||
placeholder="Enter amount"
|
||||
decimalCount={2}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-4">
|
||||
<p className="font-mono text-sm">
|
||||
If type="tel" prop is provided, the input will allow only
|
||||
numbers, spaces, parentheses (), plus +, and brackets [].
|
||||
</p>
|
||||
<Input label="Phone" type="tel" placeholder="+1 (555) 123-4567" />
|
||||
</div>
|
||||
<Input label="Website" type="url" placeholder="https://example.com" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
import { Input as BaseInput, type InputProps } from "@/components/ui/input";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Text } from "../Text/Text";
|
||||
import { useInput } from "./useInput";
|
||||
|
||||
export interface TextFieldProps extends InputProps {
|
||||
label: string;
|
||||
hideLabel?: boolean;
|
||||
decimalCount?: number; // Only used for type="amount"
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export function Input({
|
||||
className,
|
||||
label,
|
||||
placeholder,
|
||||
hideLabel = false,
|
||||
decimalCount,
|
||||
error,
|
||||
...props
|
||||
}: TextFieldProps) {
|
||||
const { handleInputChange } = useInput({ ...props, decimalCount });
|
||||
|
||||
const input = (
|
||||
<BaseInput
|
||||
className={cn(
|
||||
// Override the default input styles with Figma design
|
||||
"h-[2.875rem] rounded-3xl border border-zinc-200 bg-white px-4 py-2.5 shadow-none",
|
||||
"font-normal leading-6 text-black",
|
||||
"placeholder:font-normal placeholder:text-zinc-400",
|
||||
// Focus and hover states
|
||||
"focus:border-zinc-400 focus:shadow-none focus:outline-none focus:ring-1 focus:ring-zinc-400 focus:ring-offset-0",
|
||||
// Error state
|
||||
error &&
|
||||
"border-2 border-red-500 focus:border-red-500 focus:ring-red-500",
|
||||
className,
|
||||
)}
|
||||
type={props.type}
|
||||
placeholder={placeholder || label}
|
||||
onChange={handleInputChange}
|
||||
{...(hideLabel ? { "aria-label": label } : {})}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
|
||||
const inputWithError = (
|
||||
<div className="flex flex-col gap-1">
|
||||
{input}
|
||||
{error && (
|
||||
<Text variant="small-medium" as="span" className="!text-red-500">
|
||||
{error}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
return hideLabel ? (
|
||||
inputWithError
|
||||
) : (
|
||||
<label className="flex flex-col gap-2">
|
||||
<Text variant="body-medium" as="span" className="text-black">
|
||||
{label}
|
||||
</Text>
|
||||
{inputWithError}
|
||||
</label>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
export const NUMBER_REGEX = /[^0-9.-]/g;
|
||||
export const PHONE_REGEX = /[^0-9\s()\+\[\]]/g;
|
||||
|
||||
export function formatAmountWithCommas(value: string): string {
|
||||
if (!value) return value;
|
||||
|
||||
const parts = value.split(".");
|
||||
const integerPart = parts[0];
|
||||
const decimalPart = parts[1];
|
||||
|
||||
// Add commas to integer part
|
||||
const formattedInteger = integerPart.replace(/\B(?=(\d{3})+(?!\d))/g, ",");
|
||||
|
||||
// Check if there was a decimal point in the original value
|
||||
if (value.includes(".")) {
|
||||
return decimalPart
|
||||
? `${formattedInteger}.${decimalPart}`
|
||||
: `${formattedInteger}.`;
|
||||
}
|
||||
|
||||
return formattedInteger;
|
||||
}
|
||||
|
||||
export function filterNumberInput(value: string): string {
|
||||
let filteredValue = value;
|
||||
|
||||
// Remove all non-numeric characters except . and -
|
||||
filteredValue = value.replace(NUMBER_REGEX, "");
|
||||
|
||||
// Handle multiple decimal points - keep only the first one
|
||||
const parts = filteredValue.split(".");
|
||||
if (parts.length > 2) {
|
||||
filteredValue = parts[0] + "." + parts.slice(1).join("");
|
||||
}
|
||||
|
||||
// Handle minus signs - only allow at the beginning
|
||||
if (filteredValue.indexOf("-") > 0) {
|
||||
const hadMinusAtStart = value.startsWith("-");
|
||||
filteredValue = filteredValue.replace(/-/g, "");
|
||||
if (hadMinusAtStart) {
|
||||
filteredValue = "-" + filteredValue;
|
||||
}
|
||||
}
|
||||
|
||||
return filteredValue;
|
||||
}
|
||||
|
||||
export function limitDecimalPlaces(
|
||||
value: string,
|
||||
decimalCount: number,
|
||||
): string {
|
||||
const [integerPart, decimalPart] = value.split(".");
|
||||
if (decimalPart && decimalPart.length > decimalCount) {
|
||||
return `${integerPart}.${decimalPart.substring(0, decimalCount)}`;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
export function filterPhoneInput(value: string): string {
|
||||
return value.replace(PHONE_REGEX, "");
|
||||
}
|
||||
|
||||
export function removeCommas(value: string): string {
|
||||
return value.replace(/,/g, "");
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
import { InputProps } from "@/components/ui/input";
|
||||
import {
|
||||
filterNumberInput,
|
||||
filterPhoneInput,
|
||||
formatAmountWithCommas,
|
||||
limitDecimalPlaces,
|
||||
removeCommas,
|
||||
} from "./helpers";
|
||||
|
||||
interface ExtendedInputProps extends InputProps {
|
||||
decimalCount?: number;
|
||||
}
|
||||
|
||||
export function useInput(args: ExtendedInputProps) {
|
||||
function handleInputChange(e: React.ChangeEvent<HTMLInputElement>) {
|
||||
const { value } = e.target;
|
||||
const decimalCount = args.decimalCount ?? 4;
|
||||
|
||||
let processedValue = value;
|
||||
|
||||
if (args.type === "number") {
|
||||
// Basic number filtering - no decimal limiting
|
||||
const filteredValue = filterNumberInput(value);
|
||||
processedValue = filteredValue;
|
||||
} else if (args.type === "amount") {
|
||||
// Amount type with decimal limiting and comma formatting
|
||||
const cleanValue = removeCommas(value);
|
||||
let filteredValue = filterNumberInput(cleanValue);
|
||||
filteredValue = limitDecimalPlaces(filteredValue, decimalCount);
|
||||
|
||||
const displayValue = formatAmountWithCommas(filteredValue);
|
||||
e.target.value = displayValue;
|
||||
processedValue = filteredValue; // Pass clean value to parent
|
||||
} else if (args.type === "tel") {
|
||||
processedValue = filterPhoneInput(value);
|
||||
}
|
||||
|
||||
// Call onChange with processed value
|
||||
if (args.onChange) {
|
||||
// Only create synthetic event if we need to change the value
|
||||
if (processedValue !== value || args.type === "amount") {
|
||||
const syntheticEvent = {
|
||||
...e,
|
||||
target: {
|
||||
...e.target,
|
||||
value: processedValue,
|
||||
},
|
||||
} as React.ChangeEvent<HTMLInputElement>;
|
||||
|
||||
args.onChange(syntheticEvent);
|
||||
} else {
|
||||
args.onChange(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { handleInputChange };
|
||||
}
|
||||
@@ -114,12 +114,14 @@ export const CredentialsInput: FC<{
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
hideIfSingleCredentialAvailable?: boolean;
|
||||
}> = ({
|
||||
schema,
|
||||
className,
|
||||
selectedCredentials,
|
||||
onSelectCredentials,
|
||||
siblingInputs,
|
||||
hideIfSingleCredentialAvailable = true,
|
||||
}) => {
|
||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||
useState(false);
|
||||
@@ -162,7 +164,11 @@ export const CredentialsInput: FC<{
|
||||
}
|
||||
}, [singleCredential, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
if (!credentials || credentials.isLoading || singleCredential) {
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
(singleCredential && hideIfSingleCredentialAvailable)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,20 @@ import {
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
} from "@/components/ui/select";
|
||||
import { determineDataType, DataType } from "@/lib/autogpt-server-api/types";
|
||||
import { BlockIOSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
MultiSelector,
|
||||
MultiSelectorContent,
|
||||
MultiSelectorInput,
|
||||
MultiSelectorItem,
|
||||
MultiSelectorList,
|
||||
MultiSelectorTrigger,
|
||||
} from "@/components/ui/multiselect";
|
||||
import {
|
||||
BlockIOObjectSubSchema,
|
||||
BlockIOSubSchema,
|
||||
DataType,
|
||||
determineDataType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
/**
|
||||
* A generic prop structure for the TypeBasedInput.
|
||||
@@ -37,7 +49,7 @@ export interface TypeBasedInputProps {
|
||||
onChange: (value: any) => void;
|
||||
}
|
||||
|
||||
const inputClasses = "min-h-11 rounded-full border px-4 py-2.5";
|
||||
const inputClasses = "min-h-11 rounded-[1.375rem] border px-4 py-2.5 bg-text";
|
||||
|
||||
function Input({
|
||||
className,
|
||||
@@ -171,6 +183,46 @@ export const TypeBasedInput: FC<
|
||||
break;
|
||||
}
|
||||
|
||||
case DataType.MULTI_SELECT:
|
||||
const _schema = schema as BlockIOObjectSubSchema;
|
||||
|
||||
innerInputElement = (
|
||||
<MultiSelector
|
||||
className="nodrag"
|
||||
values={Object.entries(value || {})
|
||||
.filter(([_, v]) => v)
|
||||
.map(([k, _]) => k)}
|
||||
onValuesChange={(values: string[]) => {
|
||||
const allKeys = Object.keys(_schema.properties);
|
||||
onChange(
|
||||
Object.fromEntries(
|
||||
allKeys.map((opt) => [opt, values.includes(opt)]),
|
||||
),
|
||||
);
|
||||
}}
|
||||
>
|
||||
<MultiSelectorTrigger className={inputClasses}>
|
||||
<MultiSelectorInput
|
||||
placeholder={schema.placeholder ?? `Select ${schema.title}...`}
|
||||
/>
|
||||
</MultiSelectorTrigger>
|
||||
<MultiSelectorContent className="nowheel">
|
||||
<MultiSelectorList
|
||||
className={cn(inputClasses, "agpt-border-input bg-white")}
|
||||
>
|
||||
{Object.keys(_schema.properties)
|
||||
.map((key) => ({ ..._schema.properties[key], key }))
|
||||
.map(({ key, title, description }) => (
|
||||
<MultiSelectorItem key={key} value={key} title={description}>
|
||||
{title ?? key}
|
||||
</MultiSelectorItem>
|
||||
))}
|
||||
</MultiSelectorList>
|
||||
</MultiSelectorContent>
|
||||
</MultiSelector>
|
||||
);
|
||||
break;
|
||||
|
||||
case DataType.SHORT_TEXT:
|
||||
default:
|
||||
innerInputElement = (
|
||||
|
||||
@@ -1308,21 +1308,21 @@ export const IconTiktok = createIcon((props) => (
|
||||
));
|
||||
|
||||
/**
|
||||
* Close (X) icon component.
|
||||
* Cross (X) icon component.
|
||||
*
|
||||
* @component IconClose
|
||||
* @component IconCross
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The close icon.
|
||||
* @returns {JSX.Element} - The cross icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage
|
||||
* <IconClose />
|
||||
* <IconCross />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size
|
||||
* <IconClose className="text-primary" size="lg" />
|
||||
* <IconCross className="text-primary" size="lg" />
|
||||
*/
|
||||
export const IconClose = createIcon((props) => (
|
||||
export const IconCross = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 14 14"
|
||||
|
||||
@@ -9,16 +9,24 @@ const getYouTubeVideoId = (url: string) => {
|
||||
return match && match[7].length === 11 ? match[7] : null;
|
||||
};
|
||||
|
||||
const isValidMediaUri = (url: string): boolean => {
|
||||
if (url.startsWith("data:")) {
|
||||
return true;
|
||||
}
|
||||
const validUrl = /^(https?:\/\/)(www\.)?/i;
|
||||
const cleanedUrl = url.split("?")[0];
|
||||
return validUrl.test(cleanedUrl);
|
||||
};
|
||||
|
||||
const isValidVideoUrl = (url: string): boolean => {
|
||||
if (url.startsWith("data:video")) {
|
||||
return true;
|
||||
}
|
||||
const validUrl = /^(https?:\/\/)(www\.)?/i;
|
||||
const videoExtensions = /\.(mp4|webm|ogg)$/i;
|
||||
const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.?be)\/.+$/;
|
||||
const cleanedUrl = url.split("?")[0];
|
||||
return (
|
||||
(validUrl.test(cleanedUrl) && videoExtensions.test(cleanedUrl)) ||
|
||||
(isValidMediaUri(url) && videoExtensions.test(cleanedUrl)) ||
|
||||
youtubeRegex.test(cleanedUrl)
|
||||
);
|
||||
};
|
||||
@@ -29,17 +37,16 @@ const isValidImageUrl = (url: string): boolean => {
|
||||
}
|
||||
const imageExtensions = /\.(jpeg|jpg|gif|png|svg|webp)$/i;
|
||||
const cleanedUrl = url.split("?")[0];
|
||||
return imageExtensions.test(cleanedUrl);
|
||||
return isValidMediaUri(url) && imageExtensions.test(cleanedUrl);
|
||||
};
|
||||
|
||||
const isValidAudioUrl = (url: string): boolean => {
|
||||
if (url.startsWith("data:audio")) {
|
||||
return true;
|
||||
}
|
||||
const validUrl = /^(https?:\/\/)(www\.)?/i;
|
||||
const audioExtensions = /\.(mp3|wav)$/i;
|
||||
const cleanedUrl = url.split("?")[0];
|
||||
return validUrl.test(cleanedUrl) && audioExtensions.test(cleanedUrl);
|
||||
return isValidMediaUri(url) && audioExtensions.test(cleanedUrl);
|
||||
};
|
||||
|
||||
const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
|
||||
|
||||
@@ -646,6 +646,24 @@ export default class BackendAPI {
|
||||
return this._request("POST", `/library/agents/${libraryAgentId}/fork`);
|
||||
}
|
||||
|
||||
async setupAgentTrigger(
|
||||
libraryAgentID: LibraryAgentID,
|
||||
params: {
|
||||
name: string;
|
||||
description?: string;
|
||||
trigger_config: Record<string, any>;
|
||||
agent_credentials: Record<string, CredentialsMetaInput>;
|
||||
},
|
||||
): Promise<LibraryAgentPreset> {
|
||||
return parseLibraryAgentPresetTimestamp(
|
||||
await this._request(
|
||||
"POST",
|
||||
`/library/agents/${libraryAgentID}/setup-trigger`,
|
||||
params,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async listLibraryAgentPresets(params?: {
|
||||
graph_id?: GraphID;
|
||||
page?: number;
|
||||
@@ -697,14 +715,10 @@ export default class BackendAPI {
|
||||
|
||||
executeLibraryAgentPreset(
|
||||
presetID: LibraryAgentPresetID,
|
||||
graphID: GraphID,
|
||||
graphVersion: number,
|
||||
nodeInput: { [key: string]: any },
|
||||
inputs?: { [key: string]: any },
|
||||
): Promise<{ id: GraphExecutionID }> {
|
||||
return this._request("POST", `/library/presets/${presetID}/execute`, {
|
||||
graph_id: graphID,
|
||||
graph_version: graphVersion,
|
||||
node_input: nodeInput,
|
||||
inputs,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -401,11 +401,29 @@ export type LibraryAgent = {
|
||||
updated_at: Date;
|
||||
name: string;
|
||||
description: string;
|
||||
input_schema: BlockIOObjectSubSchema;
|
||||
input_schema: GraphIOSchema;
|
||||
credentials_input_schema: {
|
||||
type: "object";
|
||||
properties: { [key: string]: BlockIOCredentialsSubSchema };
|
||||
required: (keyof LibraryAgent["credentials_input_schema"]["properties"])[];
|
||||
};
|
||||
new_output: boolean;
|
||||
can_access_graph: boolean;
|
||||
is_latest_version: boolean;
|
||||
};
|
||||
} & (
|
||||
| {
|
||||
has_external_trigger: true;
|
||||
trigger_setup_info: {
|
||||
provider: CredentialsProviderName;
|
||||
config_schema: BlockIORootSchema;
|
||||
credentials_input_name?: string;
|
||||
};
|
||||
}
|
||||
| {
|
||||
has_external_trigger: false;
|
||||
trigger_setup_info?: null;
|
||||
}
|
||||
);
|
||||
|
||||
export type LibraryAgentID = Brand<string, "LibraryAgentID">;
|
||||
|
||||
@@ -432,9 +450,11 @@ export type LibraryAgentPreset = {
|
||||
graph_id: GraphID;
|
||||
graph_version: number;
|
||||
inputs: { [key: string]: any };
|
||||
credentials: Record<string, CredentialsMetaInput>;
|
||||
name: string;
|
||||
description: string;
|
||||
is_active: boolean;
|
||||
webhook_id?: string;
|
||||
};
|
||||
|
||||
export type LibraryAgentPresetID = Brand<string, "LibraryAgentPresetID">;
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { createBrowserClient } from "@supabase/ssr";
|
||||
|
||||
const isClient = typeof window !== "undefined";
|
||||
|
||||
export const getSupabaseClient = async () => {
|
||||
return isClient
|
||||
? createBrowserClient(
|
||||
process.env.NEXT_PUBLIC_SUPABASE_URL!,
|
||||
process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!,
|
||||
{ isSingleton: true },
|
||||
)
|
||||
: await getServerSupabase();
|
||||
};
|
||||
Reference in New Issue
Block a user