mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-06 21:05:13 -05:00
## Problem When marketplace agents are included in the `library_agents` payload sent to the Agent Generator service, they were missing required fields (`graph_id`, `graph_version`, `input_schema`, `output_schema`). This caused Pydantic validation to fail with HTTP 422 Unprocessable Entity. **Root cause:** The `MarketplaceAgentSummary` TypedDict had a different shape than `LibraryAgentInfo` expected by the Agent Generator: - Agent Generator expects: `graph_id`, `graph_version`, `name`, `description`, `input_schema`, `output_schema` - MarketplaceAgentSummary had: `name`, `description`, `sub_heading`, `creator`, `is_marketplace_agent` ## Solution 1. **Add `agent_graph_id` to `StoreAgent` model** - The field was already in the database view but not exposed 2. **Include `agentGraphId` in hybrid search SQL query** - Carry the field through the search CTEs 3. **Update `search_marketplace_agents_for_generation()`** - Now fetches full graph schemas using `get_graph()` and returns `LibraryAgentSummary` (same type as library agents) 4. **Update deduplication logic** - Use `graph_id` instead of name for more accurate deduplication ## Changes - `backend/api/features/store/model.py`: Add optional `agent_graph_id` field to `StoreAgent` - `backend/api/features/store/hybrid_search.py`: Include `agentGraphId` in SQL query columns - `backend/api/features/store/db.py`: Map `agentGraphId` when creating `StoreAgent` objects - `backend/api/features/chat/tools/agent_generator/core.py`: Update `search_marketplace_agents_for_generation()` to fetch and include full graph schemas ## Testing - [ ] Agent creation on dev with marketplace agents in context - [ ] Verify no 422 errors from Agent Generator - [ ] Verify marketplace agents can be used as sub-agents Fixes: SECRT-1817 --------- Co-authored-by: majdyz <majdyz@users.noreply.github.com> Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
1599 lines
55 KiB
Python
1599 lines
55 KiB
Python
import asyncio
|
|
import logging
|
|
import uuid
|
|
from collections import defaultdict
|
|
from datetime import datetime, timezone
|
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
|
|
|
from prisma.enums import SubmissionStatus
|
|
from prisma.models import (
|
|
AgentGraph,
|
|
AgentNode,
|
|
AgentNodeLink,
|
|
LibraryAgent,
|
|
StoreListingVersion,
|
|
)
|
|
from prisma.types import (
|
|
AgentGraphCreateInput,
|
|
AgentGraphWhereInput,
|
|
AgentNodeCreateInput,
|
|
AgentNodeLinkCreateInput,
|
|
StoreListingVersionWhereInput,
|
|
)
|
|
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
|
from pydantic.fields import computed_field
|
|
|
|
from backend.blocks.agent import AgentExecutorBlock
|
|
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
|
from backend.blocks.llm import LlmModel
|
|
from backend.data.db import prisma as db
|
|
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
|
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
|
from backend.data.model import (
|
|
CredentialsField,
|
|
CredentialsFieldInfo,
|
|
CredentialsMetaInput,
|
|
is_credentials_field_name,
|
|
)
|
|
from backend.integrations.providers import ProviderName
|
|
from backend.util import type as type_utils
|
|
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
|
from backend.util.json import SafeJson
|
|
from backend.util.models import Pagination
|
|
|
|
from .block import (
|
|
AnyBlockSchema,
|
|
Block,
|
|
BlockInput,
|
|
BlockSchema,
|
|
BlockType,
|
|
EmptySchema,
|
|
get_block,
|
|
get_blocks,
|
|
)
|
|
from .db import BaseDbModel, query_raw_with_schema, transaction
|
|
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
|
|
|
if TYPE_CHECKING:
|
|
from .execution import NodesInputMasks
|
|
from .integrations import Webhook
|
|
|
|
# Module-level logger shared by the models and DB helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GraphSettings(BaseModel):
    # Per-graph execution safety switches.
    # Each field's BeforeValidator maps an explicit None (e.g. a NULL column
    # coming back from the database) to that field's default, so a missing
    # value never fails validation. Unknown keys are ignored.
    model_config = {"extra": "ignore"}

    human_in_the_loop_safe_mode: Annotated[
        bool, BeforeValidator(lambda value: True if value is None else value)
    ] = True
    sensitive_action_safe_mode: Annotated[
        bool, BeforeValidator(lambda value: False if value is None else value)
    ] = False

    @classmethod
    def from_graph(
        cls,
        graph: "GraphModel",
        hitl_safe_mode: bool | None = None,
        sensitive_action_safe_mode: bool = False,
    ) -> "GraphSettings":
        """
        Build settings for a graph from explicit flags.

        `graph` is accepted for interface compatibility but is not read.
        An unset (`None`) `hitl_safe_mode` is treated as True.
        """
        effective_hitl_mode = True if hitl_safe_mode is None else hitl_safe_mode
        return cls(
            human_in_the_loop_safe_mode=effective_hitl_mode,
            sensitive_action_safe_mode=sensitive_action_safe_mode,
        )
|
|
|
|
|
|
class Link(BaseDbModel):
    # A directed connection from an output pin (`source_name`) of node
    # `source_id` to an input pin (`sink_name`) of node `sink_id`.
    source_id: str
    sink_id: str
    source_name: str
    sink_name: str
    is_static: bool = False

    @staticmethod
    def from_db(link: AgentNodeLink):
        """Convert a Prisma `AgentNodeLink` record into a `Link`."""
        return Link(
            id=link.id,
            source_id=link.agentNodeSourceId,
            source_name=link.sourceName,
            sink_id=link.agentNodeSinkId,
            sink_name=link.sinkName,
            is_static=link.isStatic,
        )

    def __hash__(self):
        # The hash covers only the endpoints and pin names; `id` and
        # `is_static` do not participate.
        endpoint_key = (self.source_id, self.sink_id, self.source_name, self.sink_name)
        return hash(endpoint_key)
|
|
|
|
|
|
class Node(BaseDbModel):
    # One block instance inside a graph: which block it runs, its default
    # inputs, free-form metadata, and the links attached to it.
    block_id: str
    input_default: BlockInput = {}  # dict[input_name, default_value]
    metadata: dict[str, Any] = {}
    input_links: list[Link] = []
    output_links: list[Link] = []

    @property
    def credentials_optional(self) -> bool:
        """
        Whether this node may run without configured credentials.

        Read from `metadata["credentials_optional"]`, defaulting to False.
        When True and credentials are not configured, the node will be skipped
        during execution rather than causing a validation error.
        """
        return self.metadata.get("credentials_optional", False)

    @property
    def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
        """
        Resolve `block_id` to its block.

        A deleted/missing block does not raise: a warning is logged and an
        `_UnknownBlockBase` placeholder is returned instead.
        """
        resolved = get_block(self.block_id)
        if resolved:
            return resolved
        logger.warning(
            f"Block #{self.block_id} does not exist for Node #{self.id} (deleted/missing block), using UnknownBlock"
        )
        return _UnknownBlockBase(self.block_id)
|
|
|
|
|
|
class NodeModel(Node):
    # A `Node` bound to a specific graph id/version, plus its optional webhook.
    graph_id: str
    graph_version: int

    webhook_id: Optional[str] = None
    webhook: Optional["Webhook"] = None

    @staticmethod
    def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
        """
        Build a `NodeModel` from a Prisma `AgentNode` record.

        Links come from `node.Input` / `node.Output` (either may be None).
        With `for_export=True` the result is passed through
        `stripped_for_export()` before being returned.
        """
        from .integrations import Webhook

        model = NodeModel(
            id=node.id,
            block_id=node.agentBlockId,
            input_default=type_utils.convert(node.constantInput, dict[str, Any]),
            metadata=type_utils.convert(node.metadata, dict[str, Any]),
            graph_id=node.agentGraphId,
            graph_version=node.agentGraphVersion,
            webhook_id=node.webhookId,
            webhook=Webhook.from_db(node.Webhook) if node.Webhook else None,
        )
        model.input_links = [Link.from_db(db_link) for db_link in node.Input or []]
        model.output_links = [Link.from_db(db_link) for db_link in node.Output or []]
        return model.stripped_for_export() if for_export else model

    def is_triggered_by_event_type(self, event_type: str) -> bool:
        """Ask this node's block whether `event_type` triggers it, given `input_default`."""
        return self.block.is_triggered_by_event_type(self.input_default, event_type)

    def stripped_for_export(self) -> "NodeModel":
        """
        Return a deep copy of this node with non-transferable data removed:
        secret-looking inputs, the default `value` of secret INPUT nodes, and
        the webhook binding. `self` is not modified.
        """
        export_copy = self.model_copy(deep=True)

        # Drop credentials and other (possible) secrets from the node's inputs.
        if export_copy.input_default:
            export_copy.input_default = NodeModel._filter_secrets_from_node_input(
                export_copy.input_default, self.block.input_schema.jsonschema()
            )

        # A secret INPUT node must not carry its default value into the export.
        defaults = export_copy.input_default
        is_secret_input_node = (
            export_copy.block.block_type == BlockType.INPUT
            and defaults.get("secret", False) is True
        )
        if is_secret_input_node and "value" in defaults:
            del defaults["value"]

        # Remove webhook info.
        export_copy.webhook_id = None
        export_copy.webhook = None

        return export_copy

    @staticmethod
    def _filter_secrets_from_node_input(
        input_data: dict[str, Any], schema: dict[str, Any] | None
    ) -> dict[str, Any]:
        """
        Return a copy of `input_data` without secret-looking entries.

        An entry is dropped when its JSON-schema property is marked
        `"secret"`, or when its key contains a sensitive word and its value is
        not a bool (bools are kept so an input node's `secret` flag survives).
        Nested dicts are filtered recursively against their own sub-schema.
        """
        sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
        properties = schema.get("properties", {}) if schema else {}

        filtered: dict[str, Any] = {}
        for key, value in input_data.items():
            prop_schema: dict | None = properties.get(key)
            if prop_schema and prop_schema.get("secret", False):
                continue  # explicitly marked secret in the schema
            if (
                any(word in key.lower() for word in sensitive_keys)
                and type(value) is not bool
            ):
                continue  # key name looks sensitive
            if isinstance(value, dict):
                filtered[key] = NodeModel._filter_secrets_from_node_input(
                    value, prop_schema
                )
            else:
                filtered[key] = value
        return filtered
|
|
|
|
|
|
class BaseGraph(BaseDbModel):
    # Core graph definition: metadata, nodes and links. The `@computed_field`
    # properties below are derived from `nodes` and included on serialization
    # (they are documented with comments, not docstrings, so their generated
    # JSON-schema descriptions stay unchanged).
    version: int = 1
    is_active: bool = True
    name: str
    description: str
    instructions: str | None = None
    recommended_schedule_cron: str | None = None
    nodes: list[Node] = []
    links: list[Link] = []
    forked_from_id: str | None = None
    forked_from_version: int | None = None

    @computed_field
    @property
    def input_schema(self) -> dict[str, Any]:
        # JSON schema of the graph's inputs, built from its AgentInputBlock nodes.
        return self._generate_schema(
            *(
                (block.input_schema, node.input_default)
                for node in self.nodes
                if (block := node.block).block_type == BlockType.INPUT
                and issubclass(block.input_schema, AgentInputBlock.Input)
            )
        )

    @computed_field
    @property
    def output_schema(self) -> dict[str, Any]:
        # JSON schema of the graph's outputs, built from its AgentOutputBlock nodes.
        return self._generate_schema(
            *(
                (block.input_schema, node.input_default)
                for node in self.nodes
                if (block := node.block).block_type == BlockType.OUTPUT
                and issubclass(block.input_schema, AgentOutputBlock.Input)
            )
        )

    @computed_field
    @property
    def has_external_trigger(self) -> bool:
        # True when the graph contains a WEBHOOK / WEBHOOK_MANUAL node.
        return self.webhook_input_node is not None

    @computed_field
    @property
    def has_human_in_the_loop(self) -> bool:
        # True when any node runs a HUMAN_IN_THE_LOOP block. Test the predicate
        # directly rather than the truthiness of the node's `block_id`.
        return any(
            node.block.block_type == BlockType.HUMAN_IN_THE_LOOP for node in self.nodes
        )

    @computed_field
    @property
    def has_sensitive_action(self) -> bool:
        # True when any node's block is flagged `is_sensitive_action`.
        return any(node.block.is_sensitive_action for node in self.nodes)

    @property
    def webhook_input_node(self) -> Node | None:
        """First node whose block is a WEBHOOK or WEBHOOK_MANUAL block, else None."""
        return next(
            (
                node
                for node in self.nodes
                if node.block.block_type
                in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            ),
            None,
        )

    @computed_field
    @property
    def trigger_setup_info(self) -> "GraphTriggerInfo | None":
        # Setup info for the graph's webhook trigger, or None when there is no
        # webhook node or its block has no `webhook_config`. Credentials fields
        # are removed from `config_schema` (properties and `required`) and the
        # first credentials field name is reported separately.
        # Resolve the webhook node once: the property scans every node.
        webhook_node = self.webhook_input_node
        if not (webhook_node and (trigger_block := webhook_node.block).webhook_config):
            return None

        return GraphTriggerInfo(
            provider=trigger_block.webhook_config.provider,
            config_schema={
                **(json_schema := trigger_block.input_schema.jsonschema()),
                "properties": {
                    pn: sub_schema
                    for pn, sub_schema in json_schema["properties"].items()
                    if not is_credentials_field_name(pn)
                },
                "required": [
                    pn
                    for pn in json_schema.get("required", [])
                    if not is_credentials_field_name(pn)
                ],
            },
            credentials_input_name=next(
                iter(trigger_block.input_schema.get_credentials_fields()), None
            ),
        )

    @staticmethod
    def _generate_schema(
        *props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
    ) -> dict[str, Any]:
        """
        Build an object JSON schema from I/O block configurations.

        Each `(schema_class, input_default)` pair becomes one property named by
        the block's `name`. A property is listed in `required` when it has no
        default `value`. Pairs whose `model_construct` raises are logged and
        skipped.
        """
        schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
        for type_class, input_default in props:
            try:
                schema_fields.append(type_class.model_construct(**input_default))
            except Exception as e:
                logger.error(f"Invalid {type_class}: {input_default}, {e}")

        return {
            "type": "object",
            "properties": {
                p.name: {
                    **{
                        k: v
                        for k, v in p.generate_schema().items()
                        if k not in ["description", "default"]
                    },
                    "secret": p.secret,
                    # Default value has to be set for advanced fields.
                    "advanced": p.advanced and p.value is not None,
                    "title": p.title or p.name,
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
                }
                for p in schema_fields
            },
            "required": [p.name for p in schema_fields if p.value is None],
        }
|
|
|
|
|
|
class GraphTriggerInfo(BaseModel):
    # Setup info for a webhook-triggered graph, as built by
    # `BaseGraph.trigger_setup_info`: the trigger's provider, the trigger
    # block's input schema (without credentials fields) and the name of its
    # credentials input, if the block has one.
    provider: ProviderName
    config_schema: dict[str, Any] = Field(
        description="Input schema for the trigger block"
    )
    credentials_input_name: str | None
|
|
|
|
|
|
class Graph(BaseGraph):
    # A graph together with its sub-graphs; credentials inputs are aggregated
    # across the graph and all of its sub-graphs.
    sub_graphs: list[BaseGraph] = []  # Flattened sub-graphs

    @computed_field
    @property
    def credentials_input_schema(self) -> dict[str, Any]:
        # JSON schema of every credentials input needed by the graph and its
        # sub-graphs. A field is required if ANY node using it has
        # credentials_optional=False.
        # Aggregate once and reuse it for both the schema model and the
        # `required` list; aggregation walks every node of every sub-graph.
        graph_credentials_inputs = self.aggregate_credentials_inputs()
        schema = self._build_credentials_input_schema(
            graph_credentials_inputs
        ).jsonschema()

        # Build a map of node_id -> node (graph + sub-graphs) for quick lookup
        all_nodes = {node.id: node for node in self.nodes}
        for sub_graph in self.sub_graphs:
            for node in sub_graph.nodes:
                all_nodes[node.id] = node

        required_fields = []
        for field_key, (_field_info, node_field_pairs) in graph_credentials_inputs.items():
            nodes_using_field = (
                all_nodes.get(node_id) for node_id, _field_name in node_field_pairs
            )
            if any(node and not node.credentials_optional for node in nodes_using_field):
                required_fields.append(field_key)

        schema["required"] = required_fields
        return schema

    @property
    def _credentials_input_schema(self) -> type[BlockSchema]:
        """`BlockSchema` subclass with one field per aggregated credentials input."""
        return self._build_credentials_input_schema(self.aggregate_credentials_inputs())

    def _build_credentials_input_schema(
        self,
        graph_credentials_inputs: dict[
            str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]
        ],
    ) -> type[BlockSchema]:
        """
        Build the credentials input schema model from already-aggregated inputs
        (the output of `aggregate_credentials_inputs`).
        """
        logger.debug(
            f"Combined credentials input fields for graph #{self.id} ({self.name}): "
            f"{graph_credentials_inputs}"
        )

        # Warn if same-provider credentials inputs can't be combined (= bad UX)
        graph_cred_fields = list(graph_credentials_inputs.values())
        for i, (field, keys) in enumerate(graph_cred_fields):
            for other_field, other_keys in graph_cred_fields[i + 1 :]:
                if field.provider != other_field.provider:
                    continue
                if ProviderName.HTTP in field.provider:
                    continue

                # If this happens, that means a block implementation probably needs
                # to be updated.
                logger.warning(
                    "Multiple combined credentials fields "
                    f"for provider {field.provider} "
                    f"on graph #{self.id} ({self.name}); "
                    f"fields: {field} <> {other_field};"
                    f"keys: {keys} <> {other_keys}."
                )

        fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
            agg_field_key: (
                CredentialsMetaInput[
                    Literal[tuple(field_info.provider)],  # type: ignore
                    Literal[tuple(field_info.supported_types)],  # type: ignore
                ],
                CredentialsField(
                    required_scopes=set(field_info.required_scopes or []),
                    discriminator=field_info.discriminator,
                    discriminator_mapping=field_info.discriminator_mapping,
                    discriminator_values=field_info.discriminator_values,
                ),
            )
            for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
        }

        return create_model(
            self.name.replace(" ", "") + "CredentialsInputSchema",
            __base__=BlockSchema,
            **fields,  # type: ignore
        )

    def aggregate_credentials_inputs(
        self,
    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
        """
        Returns:
            dict[aggregated_field_key, tuple(
                CredentialsFieldInfo: A spec for one aggregated credentials field
                    (now includes discriminator_values from matching nodes)
                set[(node_id, field_name)]: Node credentials fields that are
                    compatible with this aggregated field spec
            )]
        """
        # First collect all credential field data with input defaults
        node_credential_data = []

        for graph in [self] + self.sub_graphs:
            for node in graph.nodes:
                for (
                    field_name,
                    field_info,
                ) in node.block.input_schema.get_credentials_fields_info().items():

                    discriminator = field_info.discriminator
                    if not discriminator:
                        node_credential_data.append((field_info, (node.id, field_name)))
                        continue

                    discriminator_value = node.input_default.get(discriminator)
                    if discriminator_value is None:
                        node_credential_data.append((field_info, (node.id, field_name)))
                        continue

                    # Narrow the field spec using the node's configured discriminator value
                    discriminated_info = field_info.discriminate(discriminator_value)
                    discriminated_info.discriminator_values.add(discriminator_value)

                    node_credential_data.append(
                        (discriminated_info, (node.id, field_name))
                    )

        # Combine credential field info (this will merge discriminator_values automatically)
        return CredentialsFieldInfo.combine(*node_credential_data)
|
|
|
|
|
|
class GraphModel(Graph):
    # A stored graph owned by `user_id`, with DB-backed `NodeModel` nodes.
    user_id: str
    nodes: list[NodeModel] = []  # type: ignore

    created_at: datetime

    @property
    def starting_nodes(self) -> list[NodeModel]:
        """
        Nodes where execution can begin: every node that no link points into,
        plus every INPUT-block node (even if a link points into it).
        """
        nodes_with_inbound_links = {link.sink_id for link in self.links}
        input_nodes = {
            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
        }
        return [
            node
            for node in self.nodes
            if node.id not in nodes_with_inbound_links or node.id in input_nodes
        ]

    @property
    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
        """Same as `BaseGraph.webhook_input_node`, typed as `NodeModel`."""
        return cast(NodeModel, super().webhook_input_node)

    def meta(self) -> "GraphMeta":
        """
        Returns a GraphMeta object with metadata about the graph.
        This is used to return metadata about the graph without exposing nodes and links.
        """
        return GraphMeta.from_graph(self)

    def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
        """
        Assign fresh UUIDs to the nodes of this graph and its sub-graphs.
        This method can be used before storing a new graph to the database.

        Node IDs are always regenerated and links are re-pointed to them. The
        graph IDs of this graph and its sub-graphs are regenerated only when
        `reassign_graph_id` is True. AgentExecutorBlock nodes get `user_id`
        written into their inputs and their `graph_id` input remapped.
        """
        if reassign_graph_id:
            graph_id_map = {
                self.id: str(uuid.uuid4()),
                **{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
            }
        else:
            graph_id_map = {}

        self._reassign_ids(self, user_id, graph_id_map)
        for sub_graph in self.sub_graphs:
            self._reassign_ids(sub_graph, user_id, graph_id_map)

    @staticmethod
    def _reassign_ids(
        graph: BaseGraph,
        user_id: str,
        graph_id_map: dict[str, str],
    ):
        """Reassign IDs in a single graph in place (see `reassign_ids`)."""
        # Reassign Graph ID
        if graph.id in graph_id_map:
            graph.id = graph_id_map[graph.id]

        # Reassign Node IDs
        id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
        for node in graph.nodes:
            node.id = id_map[node.id]

        # Re-point link endpoints at the new node IDs (link IDs are unchanged)
        for link in graph.links:
            if link.source_id in id_map:
                link.source_id = id_map[link.source_id]
            if link.sink_id in id_map:
                link.sink_id = id_map[link.sink_id]

        # Reassign User IDs for agent blocks.
        # Instantiate the block once rather than once per node just to read its id.
        agent_executor_block_id = AgentExecutorBlock().id
        for node in graph.nodes:
            if node.block_id != agent_executor_block_id:
                continue
            node.input_default["user_id"] = user_id
            node.input_default.setdefault("inputs", {})
            if (
                graph_id := node.input_default.get("graph_id")
            ) and graph_id in graph_id_map:
                node.input_default["graph_id"] = graph_id_map[graph_id]

    def validate_graph(
        self,
        for_run: bool = False,
        nodes_input_masks: Optional["NodesInputMasks"] = None,
    ):
        """
        Validate graph structure and raise `ValueError` on issues.
        For structured error reporting, use `validate_graph_get_errors`.
        """
        self._validate_graph(self, for_run, nodes_input_masks)
        for sub_graph in self.sub_graphs:
            self._validate_graph(sub_graph, for_run, nodes_input_masks)

    @staticmethod
    def _validate_graph(
        graph: BaseGraph,
        for_run: bool = False,
        nodes_input_masks: Optional["NodesInputMasks"] = None,
    ) -> None:
        """Validate one graph, raising `ValueError` with the first error found."""
        errors = GraphModel._validate_graph_get_errors(
            graph, for_run, nodes_input_masks
        )
        if errors:
            # Just raise the first error for backward compatibility
            first_error = next(iter(errors.values()))
            first_field_error = next(iter(first_error.values()))
            raise ValueError(first_field_error)

    def validate_graph_get_errors(
        self,
        for_run: bool = False,
        nodes_input_masks: Optional["NodesInputMasks"] = None,
    ) -> dict[str, dict[str, str]]:
        """
        Validate graph and return structured errors per node.

        Returns: dict[node_id, dict[field_name, error_message]]
        """
        return {
            **self._validate_graph_get_errors(self, for_run, nodes_input_masks),
            **{
                node_id: error
                for sub_graph in self.sub_graphs
                for node_id, error in self._validate_graph_get_errors(
                    sub_graph, for_run, nodes_input_masks
                ).items()
            },
        }

    @staticmethod
    def _validate_graph_get_errors(
        graph: BaseGraph,
        for_run: bool = False,
        nodes_input_masks: Optional["NodesInputMasks"] = None,
    ) -> dict[str, dict[str, str]]:
        """
        Validate graph and return structured errors per node.

        Returns: dict[node_id, dict[field_name, error_message]]

        Raises:
            ValueError: on structural problems (invalid links/pins) or a node
                whose block cannot be resolved; these are not reported per field.
        """
        # Structural issues can't be attributed to a single node field, so a
        # ValueError from the structure check propagates to the caller as is.
        GraphModel._validate_graph_structure(graph)

        # Collect errors per node
        node_errors: dict[str, dict[str, str]] = defaultdict(dict)

        # Resolve each node's block once (nodes with unknown blocks are omitted)
        nodes_block = {
            node.id: block
            for node in graph.nodes
            if (block := get_block(node.block_id)) is not None
        }

        input_links: dict[str, list[Link]] = defaultdict(list)

        for link in graph.links:
            input_links[link.sink_id].append(link)

        # Nodes: required fields are filled or connected and dependencies are satisfied
        for node in graph.nodes:
            if (block := nodes_block.get(node.id)) is None:
                # For invalid blocks, we still raise immediately as this is a structural issue
                raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

            node_input_mask = (
                nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
            )
            provided_inputs = {sanitize_pin_name(name) for name in node.input_default}
            provided_inputs.update(
                sanitize_pin_name(link.sink_name)
                for link in input_links.get(node.id, [])
            )
            if node_input_mask:
                # Inputs supplied through the run-time mask count as provided too
                provided_inputs.update(node_input_mask)
            InputSchema = block.input_schema
            credentials_fields = InputSchema.get_credentials_fields()

            for name in (required_fields := InputSchema.get_required_fields()):
                if (
                    name not in provided_inputs
                    # Checking availability of credentials is done by ExecutionManager
                    and name not in credentials_fields
                    # Validate only I/O nodes, or validate everything when executing
                    and (
                        for_run
                        or block.block_type
                        in [
                            BlockType.INPUT,
                            BlockType.OUTPUT,
                            BlockType.AGENT,
                        ]
                    )
                ):
                    node_errors[node.id][name] = "This field is required"

                if (
                    block.block_type == BlockType.INPUT
                    and (input_key := node.input_default.get("name"))
                    and is_credentials_field_name(input_key)
                ):
                    node_errors[node.id]["name"] = (
                        f"'{input_key}' is a reserved input name: "
                        "'credentials' and `*_credentials` are reserved"
                    )

            # Get input schema properties and check dependencies
            input_fields = InputSchema.model_fields

            def has_value(node: Node, name: str):
                # A field "has a value" if it is set non-blank in the node's
                # defaults, has a schema default, or is set non-blank in the mask.
                return (
                    (
                        name in node.input_default
                        and node.input_default[name] is not None
                        and str(node.input_default[name]).strip() != ""
                    )
                    or (name in input_fields and input_fields[name].default is not None)
                    or (
                        name in node_input_mask
                        and node_input_mask[name] is not None
                        and str(node_input_mask[name]).strip() != ""
                    )
                )

            # Validate dependencies between fields
            for field_name in input_fields.keys():
                field_json_schema = InputSchema.get_field_schema(field_name)

                dependencies: list[str] = []

                # Check regular field dependencies (only pre graph execution)
                if for_run:
                    dependencies.extend(field_json_schema.get("depends_on", []))

                # Require presence of credentials discriminator (always).
                # The `discriminator` is either the name of a sibling field (str),
                # or an object that discriminates between possible types for this field:
                # {"propertyName": prop_name, "mapping": {prop_value: sub_schema}}
                if (
                    discriminator := field_json_schema.get("discriminator")
                ) and isinstance(discriminator, str):
                    dependencies.append(discriminator)

                if not dependencies:
                    continue

                # Check if dependent field has value in input_default
                field_has_value = has_value(node, field_name)
                field_is_required = field_name in required_fields

                # Check for missing dependencies when dependent field is present
                missing_deps = [dep for dep in dependencies if not has_value(node, dep)]
                if missing_deps and (field_has_value or field_is_required):
                    node_errors[node.id][
                        field_name
                    ] = f"Requires {', '.join(missing_deps)} to be set"

        return node_errors

    @staticmethod
    def _validate_graph_structure(graph: BaseGraph):
        """Validate graph structure (links, connections, etc.)"""
        node_map = {v.id: v for v in graph.nodes}

        def is_static_output_block(nid: str) -> bool:
            return node_map[nid].block.static_output

        # Links: links are connected and the connected pin data type are compatible.
        for link in graph.links:
            source = (link.source_id, link.source_name)
            sink = (link.sink_id, link.sink_name)
            prefix = f"Link {source} <-> {sink}"

            for i, (node_id, name) in enumerate([source, sink]):
                node = node_map.get(node_id)
                if not node:
                    raise ValueError(
                        f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
                    )

                block = get_block(node.block_id)
                if not block:
                    # Instantiate each block class once to list its id and name
                    blocks = {
                        b.id: b.name for b in (cls() for cls in get_blocks().values())
                    }
                    raise ValueError(
                        f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
                    )

                sanitized_name = sanitize_pin_name(name)
                vals = node.input_default
                # i == 0 is the link's source (output pin); i == 1 is its sink (input pin).
                # AGENT blocks take their pin names from the configured sub-graph schemas.
                if i == 0:
                    fields = (
                        block.output_schema.get_fields()
                        if block.block_type not in [BlockType.AGENT]
                        else vals.get("output_schema", {}).get("properties", {}).keys()
                    )
                else:
                    fields = (
                        block.input_schema.get_fields()
                        if block.block_type not in [BlockType.AGENT]
                        else vals.get("input_schema", {}).get("properties", {}).keys()
                    )
                if sanitized_name not in fields and not is_tool_pin(name):
                    fields_msg = f"Allowed fields: {fields}"
                    raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")

            if is_static_output_block(link.source_id):
                link.is_static = True  # Each value block output should be static.

    @staticmethod
    def from_db(
        graph: AgentGraph,
        for_export: bool = False,
        sub_graphs: list[AgentGraph] | None = None,
    ) -> "GraphModel":
        """
        Build a `GraphModel` from a Prisma `AgentGraph` record.

        Links are collected from every node's `Input` and `Output` lists and
        de-duplicated via a set (each link appears on both of its endpoints).
        With `for_export=True`, `user_id` is blanked and nodes are stripped of
        non-transferable data.
        """
        return GraphModel(
            id=graph.id,
            user_id=graph.userId if not for_export else "",
            version=graph.version,
            forked_from_id=graph.forkedFromId,
            forked_from_version=graph.forkedFromVersion,
            created_at=graph.createdAt,
            is_active=graph.isActive,
            name=graph.name or "",
            description=graph.description or "",
            instructions=graph.instructions,
            recommended_schedule_cron=graph.recommendedScheduleCron,
            nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
            links=list(
                {
                    Link.from_db(link)
                    for node in graph.Nodes or []
                    for link in (node.Input or []) + (node.Output or [])
                }
            ),
            sub_graphs=[
                GraphModel.from_db(sub_graph, for_export)
                for sub_graph in sub_graphs or []
            ],
        )
|
|
|
|
|
|
class GraphMeta(Graph):
    # Graph metadata for API responses. `nodes` and `links` remain on the
    # instance (the computed schemas are derived from them) but are excluded
    # from serialization.
    user_id: str

    # Easy work-around to prevent exposing nodes and links in the API response
    nodes: list[NodeModel] = Field(default=[], exclude=True)  # type: ignore
    links: list[Link] = Field(default=[], exclude=True)

    @staticmethod
    def from_graph(graph: GraphModel) -> "GraphMeta":
        """Build a `GraphMeta` by re-validating the dumped fields of `graph`."""
        graph_fields = graph.model_dump()
        return GraphMeta(**graph_fields)
|
|
|
|
|
|
class GraphsPaginated(BaseModel):
    """Response schema for paginated graphs."""

    # Graphs on the requested page (metadata only: nodes/links are excluded).
    graphs: list[GraphMeta]
    # Page position and totals for the listing.
    pagination: Pagination
|
|
|
|
|
|
# --------------------- CRUD functions --------------------- #
|
|
|
|
|
|
async def get_node(node_id: str) -> NodeModel:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    # find_unique_or_raise: a missing node surfaces as a Prisma error, not None.
    db_node = await AgentNode.prisma().find_unique_or_raise(
        where={"id": node_id}, include=AGENT_NODE_INCLUDE
    )
    return NodeModel.from_db(db_node)
|
|
|
|
|
|
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    # Connect the given webhook, or disconnect the current one when
    # `webhook_id` is falsy. The data dict stays inline so it is type-checked
    # against the Prisma update input.
    updated_node = await AgentNode.prisma().update(
        where={"id": node_id},
        data=(
            {"Webhook": {"connect": {"id": webhook_id}}}
            if webhook_id
            else {"Webhook": {"disconnect": True}}
        ),
        include=AGENT_NODE_INCLUDE,
    )
    if updated_node:
        return NodeModel.from_db(updated_node)
    raise ValueError(f"Node #{node_id} not found")
|
|
|
|
|
|
async def list_graphs_paginated(
    user_id: str,
    page: int = 1,
    page_size: int = 25,
    filter_by: Literal["active"] | None = "active",
) -> GraphsPaginated:
    """
    Retrieves paginated graph metadata objects.

    Each graph id is listed once (its highest matching version). Graphs that
    fail to convert or serialize are logged and skipped, so a page may hold
    fewer than `page_size` entries.

    Args:
        user_id: The ID of the user that owns the graphs.
        page: Page number (1-based).
        page_size: Number of graphs per page.
        filter_by: Optional filter. "active" restricts the listing to active
            graph versions; None considers all versions.

    Returns:
        GraphsPaginated: Paginated list of graph metadata.
    """
    where_clause: AgentGraphWhereInput = {"userId": user_id}

    if filter_by == "active":
        where_clause["isActive"] = True

    # Get total count of distinct graphs. AgentGraph holds one row per graph
    # version, while the page query below returns one row per graph id
    # (`distinct=["id"]`). A plain `count()` would count every matching version
    # row and inflate the totals whenever several versions match.
    graph_id_groups = await AgentGraph.prisma().group_by(
        by=["id"], where=where_clause
    )
    total_count = len(graph_id_groups)
    total_pages = (total_count + page_size - 1) // page_size

    # Get paginated results
    offset = (page - 1) * page_size
    graphs = await AgentGraph.prisma().find_many(
        where=where_clause,
        distinct=["id"],
        order={"version": "desc"},
        include=AGENT_GRAPH_INCLUDE,
        skip=offset,
        take=page_size,
    )

    graph_models: list[GraphMeta] = []
    for graph in graphs:
        try:
            graph_meta = GraphModel.from_db(graph).meta()
            # Trigger serialization to validate that the graph is well formed
            graph_meta.model_dump()
            graph_models.append(graph_meta)
        except Exception as e:
            logger.error(f"Error processing graph {graph.id}: {e}")
            continue

    return GraphsPaginated(
        graphs=graph_models,
        pagination=Pagination(
            total_items=total_count,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )
|
|
|
|
|
|
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
    """
    Fetch a graph's top-level metadata only (no nodes, links or sub-graphs).

    Uses `version` when given, otherwise the highest version of `graph_id`.
    No ownership check is applied. Returns None if no record matches.
    """
    where_clause: AgentGraphWhereInput = {"id": graph_id}
    if version is not None:
        where_clause["version"] = version

    record = await AgentGraph.prisma().find_first(
        where=where_clause, order={"version": "desc"}
    )
    if record is None:
        return None

    return Graph(
        id=record.id,
        version=record.version,
        is_active=record.isActive,
        name=record.name or "",
        description=record.description or "",
    )
|
|
|
|
|
|
async def get_graph(
    graph_id: str,
    version: int | None,
    user_id: str | None,
    *,
    for_export: bool = False,
    include_subgraphs: bool = False,
    skip_access_check: bool = False,
) -> GraphModel | None:
    """
    Retrieves a graph from the DB.

    The graph is resolved in two steps:
    1. Direct lookup, performed only when `skip_access_check` is set or a
       `user_id` is given. It is restricted to graphs owned by `user_id` unless
       `skip_access_check` is set.
    2. Store fallback, used when step 1 was skipped or found nothing. It returns
       the graph attached to an approved, non-deleted store listing.

    If `version` is not passed, the highest matching version is returned. The
    `isActive` flag is NOT consulted.

    Args:
        graph_id: ID of the graph to fetch.
        version: Exact version to fetch, or `None` for the highest matching one.
        user_id: ID of the requesting user, used for the ownership filter.
        for_export: Passed through to `GraphModel.from_db`; also causes the
            graph's sub-graphs to be loaded.
        include_subgraphs: Load the graph's sub-graphs as well.
        skip_access_check: Drop the ownership filter from the direct lookup.

    Returns:
        The graph, or `None` if no matching record is found.
    """
    graph = None

    # Only search graph directly on owned graph (or access check is skipped)
    if skip_access_check or user_id is not None:
        graph_where_clause: AgentGraphWhereInput = {"id": graph_id}
        if version is not None:
            graph_where_clause["version"] = version
        # When the check is not skipped, `user_id is not None` is already implied
        # by the branch condition; it also narrows the Optional for type checkers.
        if not skip_access_check and user_id is not None:
            graph_where_clause["userId"] = user_id

        graph = await AgentGraph.prisma().find_first(
            where=graph_where_clause,
            include=AGENT_GRAPH_INCLUDE,
            order={"version": "desc"},
        )

    # Use store listed graph to find not owned graph
    if graph is None:
        store_where_clause: StoreListingVersionWhereInput = {
            "agentGraphId": graph_id,
            "submissionStatus": SubmissionStatus.APPROVED,
            "isDeleted": False,
        }
        if version is not None:
            store_where_clause["agentGraphVersion"] = version

        if store_listing := await StoreListingVersion.prisma().find_first(
            where=store_where_clause,
            order={"agentGraphVersion": "desc"},
            include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
        ):
            graph = store_listing.AgentGraph

    if graph is None:
        return None

    if include_subgraphs or for_export:
        sub_graphs = await get_sub_graphs(graph)
        return GraphModel.from_db(
            graph=graph,
            sub_graphs=sub_graphs,
            for_export=for_export,
        )

    return GraphModel.from_db(graph, for_export)
|
|
|
|
|
|
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
    """Fetch several store-listed graphs with a single query.

    Only graphs that have an approved, non-deleted store listing are returned.
    Such graphs are public, so no per-user permission check is performed. Results
    are made distinct per graph ID and ordered by `agentGraphVersion`
    descending, so the latest listed version is picked for each ID.

    Args:
        *graph_ids: Any number of graph IDs to look up.

    Returns:
        A dict mapping each found graph ID to its `GraphModel`. IDs without an
        approved listing are absent. No IDs means `{}` without a DB query.
    """
    if not graph_ids:
        return {}

    listings = await StoreListingVersion.prisma().find_many(
        where={
            "agentGraphId": {"in": [*graph_ids]},
            "submissionStatus": SubmissionStatus.APPROVED,
            "isDeleted": False,
        },
        include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
        distinct=["agentGraphId"],
        order={"agentGraphVersion": "desc"},
    )

    graphs_by_id: dict[str, GraphModel] = {}
    for listing in listings:
        # Skip listings whose graph relation was not loaded.
        if not listing.AgentGraph:
            continue
        graphs_by_id[listing.agentGraphId] = GraphModel.from_db(listing.AgentGraph)
    return graphs_by_id
|
|
|
|
|
|
async def get_graph_as_admin(
    graph_id: str,
    version: int | None = None,
    user_id: str | None = None,
    for_export: bool = False,
) -> GraphModel | None:
    """
    Retrieves a graph from the DB for admin tasks.

    Intentionally parallels `get_graph`, but should only be used for admin
    tasks. The graph row is looked up by ID without an ownership filter. It is
    returned only if it is owned by `user_id`, or if the requested version (the
    found version when `version` is omitted) is published, meaning approved and
    not deleted, in the Marketplace.

    If `version` is not passed, the highest stored version is used. The
    `isActive` flag is NOT consulted.

    NOTE(review): an earlier docstring said this "can return any graph that's
    been submitted", but the marketplace check only admits APPROVED listings.
    Confirm this is intended, e.g. for admins reviewing pending submissions.

    Args:
        graph_id: ID of the graph to fetch.
        version: Exact version to fetch, or `None` for the highest one.
        user_id: ID of the user on whose behalf the admin is acting.
        for_export: Passed through to `GraphModel.from_db`; also causes the
            graph's sub-graphs to be loaded.

    Returns:
        The graph, or `None` if it is not found or not accessible.
    """
    logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
    where_clause: AgentGraphWhereInput = {"id": graph_id}
    if version is not None:
        where_clause["version"] = version

    graph = await AgentGraph.prisma().find_first(
        where=where_clause,
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )

    # For access, the graph must be owned by the user or listed in the store
    if graph is None or (
        graph.userId != user_id
        and not await is_graph_published_in_marketplace(
            graph_id, version or graph.version
        )
    ):
        return None

    if for_export:
        sub_graphs = await get_sub_graphs(graph)
        return GraphModel.from_db(
            graph=graph,
            sub_graphs=sub_graphs,
            for_export=for_export,
        )

    return GraphModel.from_db(graph, for_export)
|
|
|
|
|
|
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
    """
    Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.

    Sub-graphs are discovered through nodes that use `AgentExecutorBlock`, via
    the `graph_id` / `graph_version` stored in their constant input. Discovery is
    breadth-first, with one batched DB query per level of graph depth. IDs that
    have already been collected are not queried again, so reference cycles
    terminate.

    Only sub-graphs owned by the same user as `graph` are included. Each
    sub-graph ID appears at most once; if several versions of the same ID are
    referenced, only one is kept. The root graph itself is not returned.
    """
    root_id = graph.id
    owner_id = graph.userId
    sub_graphs: dict[str, AgentGraph] = {root_id: graph}
    search_graphs = [graph]
    agent_block_id = AgentExecutorBlock().id

    while search_graphs:
        # Collect (graph_id, graph_version) references from AgentExecutorBlock
        # nodes, de-duplicated (dict keeps insertion order) and skipping IDs we
        # already hold, so each DB fetch only asks for unseen sub-graphs.
        pending: dict[tuple[str, int], None] = {}
        for parent in search_graphs:
            for node in parent.Nodes or []:
                if not (node.AgentBlock and node.AgentBlock.id == agent_block_id):
                    continue
                node_input = dict(node.constantInput)
                sub_id = cast(str, node_input.get("graph_id"))
                sub_version = cast(int, node_input.get("graph_version"))
                if sub_id and sub_version and sub_id not in sub_graphs:
                    pending[(sub_id, sub_version)] = None

        if not pending:
            break

        fetched = await AgentGraph.prisma().find_many(
            where={
                "OR": [
                    {
                        "id": sub_id,
                        "version": sub_version,
                        "userId": owner_id,  # Ensure the sub-graph is owned by the same user
                    }
                    for sub_id, sub_version in pending
                ]
            },
            include=AGENT_GRAPH_INCLUDE,
        )

        search_graphs = [g for g in fetched if g.id not in sub_graphs]
        sub_graphs.update({g.id: g for g in search_graphs})

    return [g for g in sub_graphs.values() if g.id != root_id]
|
|
|
|
|
|
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
    """
    List the nodes fed by the outgoing links of `node_id`.

    Returns:
        One `(Link, Node)` pair per link whose source is `node_id`. Links
        whose sink node was not loaded are skipped.
    """
    outgoing = await AgentNodeLink.prisma().find_many(
        where={"agentNodeSourceId": node_id},
        include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
    )

    pairs: list[tuple[Link, Node]] = []
    for link in outgoing:
        sink = link.AgentNodeSink
        if not sink:
            continue
        pairs.append((Link.from_db(link), NodeModel.from_db(sink)))
    return pairs
|
|
|
|
|
|
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
    """
    Mark `version` of a user's graph as active and deactivate its other versions.

    Both updates run inside one transaction. A failure of the second update
    therefore rolls back the first and cannot leave several versions marked
    active.

    Raises:
        Exception: If `graph_id` v`version` does not exist or is not owned by
            `user_id`. No changes are persisted in that case.
    """
    async with transaction() as tx:
        # Activate the requested version if it exists and is owned by the user.
        updated_count = await AgentGraph.prisma(tx).update_many(
            data={"isActive": True},
            where={
                "id": graph_id,
                "version": version,
                "userId": user_id,
            },
        )
        if updated_count == 0:
            raise Exception(
                f"Graph #{graph_id} v{version} not found or not owned by user"
            )

        # Deactivate all other versions.
        await AgentGraph.prisma(tx).update_many(
            data={"isActive": False},
            where={
                "id": graph_id,
                "version": {"not": version},
                "userId": user_id,
                "isActive": True,
            },
        )
|
|
|
|
|
|
async def get_graph_all_versions(
    graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
) -> list[GraphModel]:
    """
    Return up to `limit` versions of a user's graph, newest version first.

    Only rows owned by `user_id` are considered. An unknown or foreign graph
    yields an empty list.
    """
    records = await AgentGraph.prisma().find_many(
        where={"id": graph_id, "userId": user_id},
        order={"version": "desc"},
        include=AGENT_GRAPH_INCLUDE,
        take=limit,
    )
    return list(map(GraphModel.from_db, records))
|
|
|
|
|
|
async def delete_graph(graph_id: str, user_id: str) -> int:
    """
    Delete every stored version of `graph_id` that is owned by `user_id`.

    Returns:
        The number of deleted `AgentGraph` rows (0 when nothing matched).
    """
    deleted = await AgentGraph.prisma().delete_many(
        where={"id": graph_id, "userId": user_id}
    )
    if deleted:
        logger.info(f"Deleted {deleted} graph entries for Graph #{graph_id}")
    return deleted
|
|
|
|
|
|
async def get_graph_settings(user_id: str, graph_id: str) -> GraphSettings:
    """
    Load the settings stored on the user's library entry for a graph.

    The newest (by `agentGraphVersion`) non-deleted, non-archived library entry
    is used. Default `GraphSettings()` is returned when there is no such entry,
    when it has no settings, or when the stored settings fail validation. The
    validation failure is logged as a warning.
    """
    library_entry = await LibraryAgent.prisma().find_first(
        where={
            "userId": user_id,
            "agentGraphId": graph_id,
            "isDeleted": False,
            "isArchived": False,
        },
        order={"agentGraphVersion": "desc"},
    )

    if library_entry and library_entry.settings:
        try:
            return GraphSettings.model_validate(library_entry.settings)
        except Exception:
            logger.warning(
                f"Malformed settings for LibraryAgent user={user_id} graph={graph_id}"
            )

    return GraphSettings()
|
|
|
|
|
|
async def validate_graph_execution_permissions(
    user_id: str, graph_id: str, graph_version: int, is_sub_graph: bool = False
) -> None:
    """
    Ensure that a user is allowed to execute a specific graph version.

    A distinct exception is raised for each kind of failure so that callers
    can handle them differently.

    ## Logic
    Execution is allowed when the graph version is accessible and, unless it is
    run as a sub-agent, the graph is in the user's library:
    - Accessible: the user owns graph `graph_id` v`graph_version`, or that
      version is published (approved, not deleted) in the Marketplace.
    - In library: the user has a non-deleted, non-archived library entry for
      `graph_id` (any version). This is not required when `is_sub_graph` is
      `True`.

    Args:
        user_id: ID of the user requesting execution.
        graph_id: ID of the graph to check.
        graph_version: Version of the graph to check.
        is_sub_graph: Whether the graph is executed as a sub-graph. If `True`,
            the graph isn't required to be in the user's library.

    Raises:
        GraphNotAccessibleError: If the user neither owns the graph version nor
            can find it published in the Marketplace.
        GraphNotInLibraryError: If the graph is not in the user's library
            (e.g. deleted/archived) and this is not a sub-graph execution.
        NotAuthorizedError: Reserved for future execution-specific checks; not
            raised by the current implementation.
    """
    graph_row, library_entry = await asyncio.gather(
        AgentGraph.prisma().find_unique(
            where={"graphVersionId": {"id": graph_id, "version": graph_version}}
        ),
        LibraryAgent.prisma().find_first(
            where={
                "userId": user_id,
                "agentGraphId": graph_id,
                "isDeleted": False,
                "isArchived": False,
            }
        ),
    )

    is_owner = graph_row is not None and graph_row.userId == user_id

    # The marketplace lookup is only awaited for non-owners.
    if not is_owner and not await is_graph_published_in_marketplace(
        graph_id, graph_version
    ):
        raise GraphNotAccessibleError(
            f"You do not have access to graph #{graph_id} v{graph_version}: "
            "it is not owned by you and not available in the Marketplace"
        )

    if library_entry is None and not is_sub_graph:
        raise GraphNotInLibraryError(f"Graph #{graph_id} is not in your library")

    # Further execution-specific checks are planned but not implemented yet:
    # execution credits, suspended/disabled graphs, rate limiting, and
    # organization-level permissions. When added, they should raise
    # NotAuthorizedError for non-library issues.
|
|
|
|
|
|
async def is_graph_published_in_marketplace(graph_id: str, graph_version: int) -> bool:
    """
    Tell whether one specific graph version is live in the Marketplace.

    Args:
        graph_id: ID of the graph.
        graph_version: Exact version of the graph to check.

    Returns:
        `True` if an approved, non-deleted store listing exists for exactly
        this graph ID and version; `False` otherwise.
    """
    approved_listing = await StoreListingVersion.prisma().find_first(
        where={
            "agentGraphId": graph_id,
            "agentGraphVersion": graph_version,
            "isDeleted": False,
            "submissionStatus": SubmissionStatus.APPROVED,
        }
    )
    return approved_listing is not None
|
|
|
|
|
|
async def create_graph(graph: Graph, user_id: str) -> GraphModel:
    """
    Persist `graph` (with its sub-graphs, nodes and links) for `user_id`.

    The rows are written in one transaction. The graph is then read back so
    the caller receives the stored representation.

    Raises:
        ValueError: If the freshly created graph cannot be read back.
    """
    async with transaction() as tx:
        await __create_graph(tx, graph, user_id)

    stored = await get_graph(graph.id, graph.version, user_id=user_id)
    if not stored:
        raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
    return stored
|
|
|
|
|
|
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
    """
    Forks a graph by copying it, including sub-graphs, nodes and links, to a new graph.

    The copy receives fresh IDs owned by `user_id` and a " (copy)" name suffix.
    Its `forked_from_id` / `forked_from_version` point at the source graph.

    Raises:
        ValueError: If the source graph cannot be found for `user_id`.
    """
    forked = await get_graph(graph_id, graph_version, user_id=user_id, for_export=True)
    if not forked:
        raise ValueError(f"Graph {graph_id} v{graph_version} not found")

    # Record the origin *before* the IDs are reassigned below, while `id` and
    # `version` still refer to the source graph.
    forked.forked_from_id = forked.id
    forked.forked_from_version = forked.version
    forked.name = f"{forked.name} (copy)"
    forked.reassign_ids(user_id=user_id, reassign_graph_id=True)
    forked.validate_graph(for_run=False)

    async with transaction() as tx:
        await __create_graph(tx, forked, user_id)

    return forked
|
|
|
|
|
|
async def __create_graph(tx, graph: Graph, user_id: str):
    """
    Insert a graph plus all of its sub-graphs, nodes and links using `tx`.

    Rows are inserted graphs first, then nodes, then links. Nodes reference
    their graph (`agentGraphId`/`agentGraphVersion`) and links reference their
    nodes. Every graph, including sub-graphs, is stored under `user_id`. Link
    rows get freshly generated UUIDs.
    """
    all_graphs = [graph, *graph.sub_graphs]

    graph_rows = [
        AgentGraphCreateInput(
            id=g.id,
            version=g.version,
            name=g.name,
            description=g.description,
            recommendedScheduleCron=g.recommended_schedule_cron,
            isActive=g.is_active,
            userId=user_id,
            forkedFromId=g.forked_from_id,
            forkedFromVersion=g.forked_from_version,
        )
        for g in all_graphs
    ]
    await AgentGraph.prisma(tx).create_many(data=graph_rows)

    node_rows = [
        AgentNodeCreateInput(
            id=node.id,
            agentGraphId=g.id,
            agentGraphVersion=g.version,
            agentBlockId=node.block_id,
            constantInput=SafeJson(node.input_default),
            metadata=SafeJson(node.metadata),
        )
        for g in all_graphs
        for node in g.nodes
    ]
    await AgentNode.prisma(tx).create_many(data=node_rows)

    link_rows = [
        AgentNodeLinkCreateInput(
            id=str(uuid.uuid4()),
            sourceName=link.source_name,
            sinkName=link.sink_name,
            agentNodeSourceId=link.source_id,
            agentNodeSinkId=link.sink_id,
            isStatic=link.is_static,
        )
        for g in all_graphs
        for link in g.links
    ]
    await AgentNodeLink.prisma(tx).create_many(data=link_rows)
|
|
|
|
|
|
# ------------------------ UTILITIES ------------------------ #
|
|
|
|
|
|
def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
    """
    Build an unsaved `GraphModel` from a `Graph`.

    Every node becomes a `NodeModel` stamped with the parent graph's `id` and
    `version`. The model's `created_at` is set to the current UTC time.

    Args:
        creatable_graph: The graph to convert.
        user_id: ID of the user the model is created for.

    Returns:
        The resulting `GraphModel`; nothing is written to the database.
    """
    node_models = [
        NodeModel(
            **node.model_dump(),
            graph_id=creatable_graph.id,
            graph_version=creatable_graph.version,
        )
        for node in creatable_graph.nodes
    ]
    # Reuse every graph-level field except the nodes, which were converted above.
    return GraphModel(
        **creatable_graph.model_dump(exclude={"nodes"}),
        user_id=user_id,
        created_at=datetime.now(tz=timezone.utc),
        nodes=node_models,
    )
|
|
|
|
|
|
async def fix_llm_provider_credentials():
    """
    Fix node credentials with provider `llm`.

    Finds every `AgentNode` whose `constantInput.credentials.provider` is
    `"llm"`. Based on the prefix of the referenced API key, it rewrites the
    provider to `anthropic` (`sk-ant-api03-`), `openai` (`sk-`) or `groq`
    (`gsk_`). The provider is updated both in the user's stored credentials and
    in the node's constant input.

    Nodes are skipped when the referenced credentials are not found, are not of
    type `api_key`, or have an unrecognized key prefix; the latter two cases log
    a warning. If the lookup query fails, the error is logged and nothing is
    changed.
    """
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    from .user import get_user_integrations

    store = IntegrationCredentialsStore()

    broken_nodes = []
    try:
        # Rows are ordered by userId so the loop below can fetch each user's
        # integrations only once.
        broken_nodes = await query_raw_with_schema(
            """
            SELECT graph."userId" user_id,
                node.id node_id,
                node."constantInput" node_preset_input
            FROM {schema_prefix}"AgentNode" node
            LEFT JOIN {schema_prefix}"AgentGraph" graph
            ON node."agentGraphId" = graph.id
            WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
            ORDER BY graph."userId";
            """
        )
        logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
    except Exception as e:
        logger.error(f"Error fixing LLM credential inputs: {e}")

    user_id: str = ""
    user_integrations = None
    for node in broken_nodes:
        if node["user_id"] != user_id:
            # Save queries by only fetching once per user
            user_id = node["user_id"]
            user_integrations = await get_user_integrations(user_id)
        elif not user_integrations:
            # Same user as the previous row, so integrations should already be
            # loaded. This guard presumably also narrows the Optional type for
            # the `.credentials` access below.
            raise RuntimeError(f"Impossible state while processing node {node}")

        node_id: str = node["node_id"]
        node_preset_input: dict = node["node_preset_input"]
        credentials_meta: dict = node_preset_input["credentials"]

        # Resolve the node's credentials reference against the user's stored
        # credentials; skip the node if the referenced credentials are gone.
        credentials = next(
            (
                c
                for c in user_integrations.credentials
                if c.id == credentials_meta["id"]
            ),
            None,
        )
        if not credentials:
            continue
        if credentials.type != "api_key":
            logger.warning(
                f"User {user_id} credentials {credentials.id} with provider 'llm' "
                f"has invalid type '{credentials.type}'"
            )
            continue

        # Order matters: Anthropic keys also start with "sk-", so the more
        # specific prefix must be checked first.
        api_key = credentials.api_key.get_secret_value()
        if api_key.startswith("sk-ant-api03-"):
            credentials.provider = credentials_meta["provider"] = "anthropic"
        elif api_key.startswith("sk-"):
            credentials.provider = credentials_meta["provider"] = "openai"
        elif api_key.startswith("gsk_"):
            credentials.provider = credentials_meta["provider"] = "groq"
        else:
            # NOTE(review): this logs the first 13 characters of a secret key;
            # consider logging fewer characters (or none) — confirm policy.
            logger.warning(
                f"Could not identify provider from key prefix {api_key[:13]}*****"
            )
            continue

        # Persist the corrected provider in both the credential store and the
        # node's constant input (credentials_meta is a view into it).
        await store.update_creds(user_id, credentials)
        await AgentNode.prisma().update(
            where={"id": node_id},
            data={"constantInput": SafeJson(node_preset_input)},
        )
|
|
|
|
|
|
async def migrate_llm_models(migrate_to: LlmModel):
    """
    Update all LLM models in all AI blocks that don't exist in the enum.

    For every block whose input schema has a top-level field annotated exactly
    as `LlmModel`, each `AgentNode` of that block type is checked. If the node's
    stored value for such a field is not a current `LlmModel` value, it is
    overwritten with `migrate_to.value`. Every such field of a block is
    migrated, not just one.

    Note: Only updates top level LlmModel SchemaFields of blocks (won't update
    nested fields, and annotations such as `LlmModel | None` are not matched).
    """
    from pydantic.fields import FieldInfo

    logger.info("Migrating LLM models")

    # Scan all blocks and collect their top-level LlmModel fields:
    # {block_id: [field_name, ...]}. A list is used so that blocks with more
    # than one LlmModel field get all of them migrated.
    llm_model_fields: dict[str, list[str]] = {}
    for block_type in get_blocks().values():
        block = block_type()
        fields: dict[str, FieldInfo] = block.input_schema.model_fields
        field_names = [
            field_name
            for field_name, field in fields.items()
            if field.annotation == LlmModel
        ]
        if field_names:
            llm_model_fields[block.id] = field_names

    # The set of valid values is bound as a single text[] parameter ($5).
    # Splicing a repr()'d tuple into the SQL would produce invalid SQL for a
    # one-member enum ("('x',)") and mis-quote values containing quotes.
    valid_values = [model.value for model in LlmModel]

    query = """
        UPDATE platform."AgentNode"
        SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
        WHERE "agentBlockId" = $3
        AND "constantInput" ? ($4)::text
        AND ("constantInput"->>($4)::text) <> ALL($5::text[])
    """

    # Update each block field
    for block_id, field_names in llm_model_fields.items():
        for field_name in field_names:
            await db.execute_raw(
                query,  # type: ignore
                [field_name],
                migrate_to.value,
                block_id,
                field_name,
                valid_values,
            )
|
|
|
|
|
|
# Simple placeholder class for deleted/missing blocks
class _UnknownBlockBase(Block):
    """
    Stand-in `Block` for a block ID that no longer maps to a known block.

    The class name deliberately does not end with 'Block', so block
    auto-discovery does not pick it up. Instances are disabled, have empty
    input/output schemas, and yield an error output whenever they are run.
    """

    def __init__(self, block_id: str = "00000000-0000-0000-0000-000000000000"):
        # Inert configuration: disabled, empty I/O schemas, no categories,
        # no contributors and no webhook.
        super().__init__(
            id=block_id,
            description=f"Unknown or deleted block (original ID: {block_id})",
            block_type=BlockType.STANDARD,
            input_schema=EmptySchema,
            output_schema=EmptySchema,
            categories=set(),
            contributors=[],
            disabled=True,
            static_output=False,
            webhook_config=None,
        )

    @property
    def name(self):
        return "UnknownBlock"

    async def run(self, input_data, **kwargs):
        """Emit a single "error" output stating that the block no longer exists."""
        yield "error", f"Block {self.id} no longer exists"
|