mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(platform/library): Add credentials UX on /library/agents/[id] (#9789)
- Resolves #9771 - ... in a non-persistent way, so it won't work for webhook-triggered agents For webhooks: #9541 ### Changes 🏗️ Frontend: - Add credentials inputs in Library "New run" screen (based on `graph.credentials_input_schema`) - Refactor `CredentialsInput` and `useCredentials` to not rely on XYFlow context - Unsplit lists of saved credentials in `CredentialsProvider` state - Move logic that was being executed at component render to `useEffect` hooks in `CredentialsInput` Backend: - Implement logic to aggregate credentials input requirements to one per provider per graph - Add `BaseGraph.credentials_input_schema` (JSON schema) computed field Underlying added logic: - `BaseGraph._credentials_input_schema` - makes a `BlockSchema` from a graph's aggregated credentials inputs - `BaseGraph.aggregate_credentials_inputs()` - aggregates a graph's nodes' credentials inputs using `CredentialsFieldInfo.combine(..)` - `BlockSchema.get_credentials_fields_info() -> dict[str, CredentialsFieldInfo]` - `CredentialsFieldInfo` model (created from `_CredentialsFieldSchemaExtra`) - Implement logic to inject explicitly passed credentials into graph execution - Add `credentials_inputs` parameter to `execute_graph` endpoint - Add `graph_credentials_input` parameter to `.executor.utils.add_graph_execution(..)` - Implement `.executor.utils.make_node_credentials_input_map(..)` - Amend `.executor.utils.construct_node_execution_input` - Add `GraphExecutionEntry.node_credentials_input_map` attribute - Amend validation to allow injecting credentials - Amend `GraphModel._validate_graph(..)` - Amend `.executor.utils._validate_node_input_credentials` - Add `node_credentials_map` parameter to `ExecutionManager.add_execution(..)` - Amend execution validation to handle side-loaded credentials - Add `GraphExecutionEntry.node_execution_map` attribute - Add mechanism to inject passed credentials into node execution data - Add credentials injection mechanism to node execution queueing logic in `Executor._on_graph_execution(..)` - Replace boilerplate logic in `v1.execute_graph` endpoint with call to existing `.executor.utils.add_graph_execution(..)` - Replace calls to `.server.routers.v1.execute_graph` with `add_graph_execution` Also: - Address tech debt in `GraphModel._validate_gaph(..)` - Fix type checking in `BaseGraph._generate_schema(..)` #### TODO - [ ] ~~Make "Run again" work with credentials in `AgentRunDetailsView`~~ - [ ] Prohibit saving a graph if it has nodes with missing discriminator value for discriminated credentials inputs ### Checklist 📋 #### For code changes: - [ ] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [ ] ...
This commit is contained in:
committed by
GitHub
parent
f16a398a8e
commit
417d7732af
@@ -67,15 +67,15 @@ class AgentExecutorBlock(Block):
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
data=input_data.data,
|
||||
inputs=input_data.data,
|
||||
)
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
|
||||
for event in event_bus.listen(
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_exec_id=graph_exec.id,
|
||||
):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
if event.status in [
|
||||
|
||||
@@ -28,6 +28,7 @@ from backend.util.settings import Config
|
||||
from .model import (
|
||||
ContributorDetails,
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||
return {
|
||||
field_name: CredentialsFieldInfo.model_validate(
|
||||
cls.get_field_schema(field_name), by_alias=True
|
||||
)
|
||||
for field_name in cls.get_credentials_fields().keys()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data # Return as is, by default.
|
||||
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_block(block_id: str) -> Block | None:
|
||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
cls = get_blocks().get(block_id)
|
||||
return cls() if cls else None
|
||||
|
||||
@@ -44,7 +44,7 @@ from .includes import (
|
||||
GRAPH_EXECUTION_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -220,6 +220,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
)
|
||||
for node_exec in self.node_executions
|
||||
],
|
||||
node_credentials_input_map={}, # FIXME
|
||||
)
|
||||
|
||||
|
||||
@@ -361,7 +362,7 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
@@ -388,7 +389,7 @@ async def create_graph_execution(
|
||||
]
|
||||
},
|
||||
)
|
||||
for node_id, node_input in nodes_input
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
userId=user_id,
|
||||
@@ -712,6 +713,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Literal, Optional, Type, cast
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
@@ -13,12 +13,19 @@ from prisma.types import (
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
)
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
@@ -190,14 +197,19 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
) -> dict[str, Any]:
|
||||
schema = []
|
||||
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
|
||||
for type_class, input_default in props:
|
||||
try:
|
||||
schema.append(type_class(**input_default))
|
||||
schema_fields.append(type_class(**input_default))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
|
||||
|
||||
@@ -217,9 +229,93 @@ class BaseGraph(BaseDbModel):
|
||||
**({"description": p.description} if p.description else {}),
|
||||
**({"default": p.value} if p.value is not None else {}),
|
||||
}
|
||||
for p in schema
|
||||
for p in schema_fields
|
||||
},
|
||||
"required": [p.name for p in schema if p.value is None],
|
||||
"required": [p.name for p in schema_fields if p.value is None],
|
||||
}
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
)
|
||||
|
||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
logger.warning(
|
||||
"Multiple combined credentials fields "
|
||||
f"for provider {field.provider} "
|
||||
f"on graph #{self.id} ({self.name}); "
|
||||
f"fields: {field} <> {other_field};"
|
||||
f"keys: {keys} <> {other_keys}."
|
||||
)
|
||||
|
||||
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||
agg_field_key: (
|
||||
CredentialsMetaInput[
|
||||
Literal[tuple(field_info.provider)], # type: ignore
|
||||
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||
],
|
||||
CredentialsField(
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
}
|
||||
|
||||
return create_model(
|
||||
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||
__base__=BlockSchema,
|
||||
**fields, # type: ignore
|
||||
)
|
||||
|
||||
def aggregate_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||
"""
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
)
|
||||
for node in self.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -320,8 +416,6 @@ class GraphModel(Graph):
|
||||
return sanitized_name
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
smart_decision_maker_nodes = set()
|
||||
agent_nodes = set()
|
||||
nodes_block = {
|
||||
node.id: block
|
||||
for node in graph.nodes
|
||||
@@ -332,13 +426,6 @@ class GraphModel(Graph):
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Smart decision maker nodes
|
||||
if block.block_type == BlockType.AI:
|
||||
smart_decision_maker_nodes.add(node.id)
|
||||
# Agent nodes
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
agent_nodes.add(node.id)
|
||||
|
||||
input_links = defaultdict(list)
|
||||
|
||||
for link in graph.links:
|
||||
@@ -353,16 +440,21 @@ class GraphModel(Graph):
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
)
|
||||
for name in block.input_schema.get_required_fields():
|
||||
input_schema = block.input_schema
|
||||
for name in (required_fields := input_schema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
and not (
|
||||
name == "payload"
|
||||
and block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in input_schema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
and (
|
||||
for_run # Skip input completion validation, unless when executing.
|
||||
for_run
|
||||
or block.block_type
|
||||
in [
|
||||
BlockType.INPUT,
|
||||
@@ -375,9 +467,18 @@ class GraphModel(Graph):
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
|
||||
if (
|
||||
block.block_type == BlockType.INPUT
|
||||
and (input_key := node.input_default.get("name"))
|
||||
and is_credentials_field_name(input_key)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Agent input node uses reserved name '{input_key}'; "
|
||||
"'credentials' and `*_credentials` are reserved input names"
|
||||
)
|
||||
|
||||
# Get input schema properties and check dependencies
|
||||
input_schema = block.input_schema.model_fields
|
||||
required_fields = block.input_schema.get_required_fields()
|
||||
input_fields = input_schema.model_fields
|
||||
|
||||
def has_value(name):
|
||||
return (
|
||||
@@ -385,14 +486,21 @@ class GraphModel(Graph):
|
||||
and name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_schema and input_schema[name].default is not None)
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name, field_info in input_schema.items():
|
||||
for field_name, field_info in input_fields.items():
|
||||
# Apply input dependency validation only on run & field with depends_on
|
||||
json_schema_extra = field_info.json_schema_extra or {}
|
||||
dependencies = json_schema_extra.get("depends_on", [])
|
||||
if not for_run or not dependencies:
|
||||
if not (
|
||||
for_run
|
||||
and isinstance(json_schema_extra, dict)
|
||||
and (
|
||||
dependencies := cast(
|
||||
list[str], json_schema_extra.get("depends_on", [])
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
# Check if dependent field has value in input_default
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -12,6 +13,7 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
@@ -300,9 +302,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
field_schema = model.jsonschema()["properties"][field_name]
|
||||
try:
|
||||
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
|
||||
field_schema
|
||||
)
|
||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
@@ -328,14 +328,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
|
||||
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
credentials_provider: list[CP]
|
||||
credentials_scopes: Optional[list[str]] = None
|
||||
credentials_types: list[CT]
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
Rules:
|
||||
- Items can only be combined if they have the same supported credentials types
|
||||
and the same supported providers.
|
||||
- When combining items, the `required_scopes` of the result is a join
|
||||
of the `required_scopes` of the original items.
|
||||
|
||||
Params:
|
||||
*fields: (CredentialsFieldInfo, key) objects to group and combine
|
||||
|
||||
Returns:
|
||||
A sequence of tuples containing combined CredentialsFieldInfo objects and
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
|
||||
for group in grouped_fields.values():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
# Track the keys that were combined
|
||||
combined_keys = {key for key, _ in group}
|
||||
|
||||
# Combine required_scopes from all fields in the group
|
||||
all_scopes = set()
|
||||
for _, field in group:
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
)
|
||||
|
||||
|
||||
def CredentialsField(
|
||||
required_scopes: set[str] = set(),
|
||||
|
||||
@@ -724,10 +724,10 @@ class Executor:
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
return execution_stats, execution_status, error
|
||||
|
||||
exec_data = queue.get()
|
||||
queued_node_exec = queue.get()
|
||||
|
||||
# Avoid parallel execution of the same node.
|
||||
execution = running_executions.get(exec_data.node_id)
|
||||
execution = running_executions.get(queued_node_exec.node_id)
|
||||
if execution and not execution.ready():
|
||||
# TODO (performance improvement):
|
||||
# Wait for the completion of the same node execution is blocking.
|
||||
@@ -736,18 +736,18 @@ class Executor:
|
||||
execution.wait()
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {exec_data.node_exec_id} "
|
||||
f"for node {exec_data.node_id}",
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
node_exec=exec_data,
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
node_exec_id = exec_data.node_exec_id
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
cls.db_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="error",
|
||||
@@ -768,10 +768,23 @@ class Executor:
|
||||
)
|
||||
raise
|
||||
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
# Add credentials input overrides
|
||||
node_id = queued_node_exec.node_id
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
):
|
||||
queued_node_exec.data.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
}
|
||||
)
|
||||
|
||||
# Initiate node execution
|
||||
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
callback=make_exec_callback(exec_data),
|
||||
(queue, queued_node_exec),
|
||||
callback=make_exec_callback(queued_node_exec),
|
||||
)
|
||||
|
||||
# Avoid terminating graph execution when some nodes are still running.
|
||||
|
||||
@@ -70,7 +70,7 @@ def execute_graph(**kwargs):
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
execution_utils.add_graph_execution(
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
inputs=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
@@ -14,15 +14,23 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
GraphExecutionWithNodes,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node, get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
@@ -43,6 +51,11 @@ def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
@@ -50,6 +63,13 @@ def get_execution_queue() -> SyncRabbitMQ:
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_execution_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
@@ -347,7 +367,13 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
return data
|
||||
|
||||
|
||||
def _validate_node_input_credentials(graph: GraphModel, user_id: str):
|
||||
def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
@@ -361,9 +387,22 @@ def _validate_node_input_credentials(graph: GraphModel, user_id: str):
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
if (
|
||||
node_credentials_input_map
|
||||
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
|
||||
and field_name in node_credentials_inputs
|
||||
):
|
||||
credentials_meta = node_credentials_input_map[node.id][field_name]
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Credentials absent for {block.name} node #{node.id} "
|
||||
f"input '{field_name}'"
|
||||
)
|
||||
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
@@ -389,10 +428,46 @@ def _validate_node_input_credentials(graph: GraphModel, user_id: str):
|
||||
)
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, CredentialsMetaInput]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
Params:
|
||||
graph: The graph to be executed.
|
||||
graph_credentials_input: A (graph_input_name, credentials_meta) map.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
|
||||
"""
|
||||
result: dict[str, dict[str, CredentialsMetaInput]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
if graph_input_name not in graph_credentials_input:
|
||||
continue
|
||||
|
||||
# Use passed-in credentials for all compatible node input fields
|
||||
for node_id, node_field_name in compatible_node_fields:
|
||||
if node_id not in result:
|
||||
result[node_id] = {}
|
||||
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
data: BlockInput,
|
||||
graph_inputs: BlockInput,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -404,13 +479,14 @@ def construct_node_execution_input(
|
||||
graph (GraphModel): The graph model to execute.
|
||||
user_id (str): The ID of the user executing the graph.
|
||||
data (BlockInput): The input data for the graph execution.
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True)
|
||||
_validate_node_input_credentials(graph, user_id)
|
||||
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -424,8 +500,8 @@ def construct_node_execution_input(
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in data:
|
||||
input_data = {"value": data[input_name]}
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
@@ -433,11 +509,17 @@ def construct_node_execution_input(
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in data:
|
||||
if webhook_payload_key not in graph_inputs:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": data[webhook_payload_key]}
|
||||
input_data = {"payload": graph_inputs[webhook_payload_key]}
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(node.id)
|
||||
):
|
||||
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
@@ -505,47 +587,128 @@ def create_execution_queue_config() -> RabbitMQConfig:
|
||||
)
|
||||
|
||||
|
||||
def add_graph_execution(
|
||||
async def add_graph_execution_async(
|
||||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: int | None = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionEntry:
|
||||
inputs: BlockInput,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
|
||||
Args:
|
||||
graph_id (str): The ID of the graph to execute.
|
||||
data (BlockInput): The input data for the graph execution.
|
||||
user_id (str): The ID of the user executing the graph.
|
||||
graph_version (int | None): The version of the graph to execute. Defaults to None.
|
||||
preset_id (str | None): The ID of the preset to use. Defaults to None.
|
||||
graph_id: The ID of the graph to execute.
|
||||
user_id: The ID of the user executing the graph.
|
||||
inputs: The input data for the graph execution.
|
||||
preset_id: The ID of the preset to use.
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
""" # noqa
|
||||
execution_event_bus = get_async_execution_event_bus()
|
||||
graph_execution_queue = await get_async_execution_queue()
|
||||
|
||||
graph: GraphModel | None = await get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
graph_exec = await create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
await execution_event_bus.publish(graph_exec)
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
await graph_execution_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
return graph_exec
|
||||
|
||||
|
||||
def add_graph_execution(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: BlockInput,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
|
||||
Args:
|
||||
graph_id: The ID of the graph to execute.
|
||||
user_id: The ID of the user executing the graph.
|
||||
inputs: The input data for the graph execution.
|
||||
preset_id: The ID of the preset to use.
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
""" # noqa
|
||||
graph: GraphModel | None = get_db_client().get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
graph_exec = get_db_client().create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=construct_node_execution_input(graph, user_id, data),
|
||||
user_id=user_id,
|
||||
starting_nodes_input=construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
get_execution_event_bus().publish(graph_exec)
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
get_execution_queue().publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
return graph_exec_entry
|
||||
return graph_exec
|
||||
|
||||
@@ -12,8 +12,8 @@ from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -97,13 +97,13 @@ async def execute_graph(
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await internal_api_routes.execute_graph(
|
||||
graph_exec = await add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
node_input=node_input,
|
||||
user_id=api_key.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
return {"id": graph_exec.graph_exec_id}
|
||||
return {"id": graph_exec.id}
|
||||
except Exception as e:
|
||||
msg = str(e).encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -15,11 +15,11 @@ from backend.data.integrations import (
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -309,7 +309,7 @@ async def webhook_ingress_generic(
|
||||
if not webhook.attached_nodes:
|
||||
return
|
||||
|
||||
executions = []
|
||||
executions: list[Awaitable] = []
|
||||
for node in webhook.attached_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
@@ -317,11 +317,11 @@ async def webhook_ingress_generic(
|
||||
continue
|
||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||
executions.append(
|
||||
internal_api_routes.execute_graph(
|
||||
add_graph_execution_async(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
node_input={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
inputs={f"webhook_{webhook_id}_payload": payload},
|
||||
)
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
|
||||
@@ -159,7 +159,8 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
node_input=node_input or {},
|
||||
inputs=node_input or {},
|
||||
credentials_inputs={},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -41,6 +41,7 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
@@ -592,35 +593,21 @@ async def set_graph_active_version(
|
||||
)
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
credentials_inputs: Annotated[
|
||||
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
|
||||
],
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> ExecuteGraphResponse:
|
||||
graph: graph_db.GraphModel | None = await graph_db.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph_exec = await execution_db.create_graph_execution(
|
||||
graph_exec = await execution_utils.add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=execution_utils.construct_node_execution_input(
|
||||
graph, user_id, node_input
|
||||
),
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
bus = execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
queue = await execution_queue_client()
|
||||
await queue.publish_message(
|
||||
routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec.to_graph_execution_entry().model_dump_json(),
|
||||
exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -207,8 +208,6 @@ async def execute_preset(
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
"""
|
||||
try:
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
@@ -219,17 +218,17 @@ async def execute_preset(
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
|
||||
execution = await internal_api_routes.execute_graph(
|
||||
execution = await add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
node_input=merged_node_input,
|
||||
graph_version=graph_version,
|
||||
user_id=user_id,
|
||||
inputs=merged_node_input,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
|
||||
return {"id": execution.graph_exec_id}
|
||||
return {"id": execution.id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -131,11 +131,7 @@ export default function PrivatePage() {
|
||||
|
||||
const allCredentials = providers
|
||||
? Object.values(providers).flatMap((provider) =>
|
||||
[
|
||||
...provider.savedOAuthCredentials,
|
||||
...provider.savedApiKeys,
|
||||
...provider.savedUserPasswordCredentials,
|
||||
]
|
||||
provider.savedCredentials
|
||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||
.map((credentials) => ({
|
||||
...credentials,
|
||||
|
||||
@@ -178,18 +178,24 @@ export const CustomNode = React.memo(
|
||||
return obj;
|
||||
}, []);
|
||||
|
||||
const setHardcodedValues = (values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
};
|
||||
const setHardcodedValues = useCallback(
|
||||
(values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
isInitialSetup.current = false;
|
||||
setHardcodedValues(fillDefaults(data.hardcodedValues, data.inputSchema));
|
||||
}, []);
|
||||
|
||||
const setErrors = (errors: { [key: string]: string }) => {
|
||||
updateNodeData(id, { errors });
|
||||
};
|
||||
const setErrors = useCallback(
|
||||
(errors: { [key: string]: string }) => {
|
||||
updateNodeData(id, { errors });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
const toggleOutput = (checked: boolean) => {
|
||||
setIsOutputOpen(checked);
|
||||
@@ -340,46 +346,49 @@ export const CustomNode = React.memo(
|
||||
});
|
||||
}
|
||||
};
|
||||
const handleInputChange = (path: string, value: any) => {
|
||||
const keys = parseKeys(path);
|
||||
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
|
||||
let current = newValues;
|
||||
const handleInputChange = useCallback(
|
||||
(path: string, value: any) => {
|
||||
const keys = parseKeys(path);
|
||||
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
|
||||
let current = newValues;
|
||||
|
||||
for (let i = 0; i < keys.length - 1; i++) {
|
||||
const { key: currentKey, index } = keys[i];
|
||||
if (index !== undefined) {
|
||||
if (!current[currentKey]) current[currentKey] = [];
|
||||
if (!current[currentKey][index]) current[currentKey][index] = {};
|
||||
current = current[currentKey][index];
|
||||
} else {
|
||||
if (!current[currentKey]) current[currentKey] = {};
|
||||
current = current[currentKey];
|
||||
for (let i = 0; i < keys.length - 1; i++) {
|
||||
const { key: currentKey, index } = keys[i];
|
||||
if (index !== undefined) {
|
||||
if (!current[currentKey]) current[currentKey] = [];
|
||||
if (!current[currentKey][index]) current[currentKey][index] = {};
|
||||
current = current[currentKey][index];
|
||||
} else {
|
||||
if (!current[currentKey]) current[currentKey] = {};
|
||||
current = current[currentKey];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const lastKey = keys[keys.length - 1];
|
||||
if (lastKey.index !== undefined) {
|
||||
if (!current[lastKey.key]) current[lastKey.key] = [];
|
||||
current[lastKey.key][lastKey.index] = value;
|
||||
} else {
|
||||
current[lastKey.key] = value;
|
||||
}
|
||||
const lastKey = keys[keys.length - 1];
|
||||
if (lastKey.index !== undefined) {
|
||||
if (!current[lastKey.key]) current[lastKey.key] = [];
|
||||
current[lastKey.key][lastKey.index] = value;
|
||||
} else {
|
||||
current[lastKey.key] = value;
|
||||
}
|
||||
|
||||
if (!isInitialSetup.current) {
|
||||
history.push({
|
||||
type: "UPDATE_INPUT",
|
||||
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
|
||||
undo: () => setHardcodedValues(data.hardcodedValues),
|
||||
redo: () => setHardcodedValues(newValues),
|
||||
});
|
||||
}
|
||||
if (!isInitialSetup.current) {
|
||||
history.push({
|
||||
type: "UPDATE_INPUT",
|
||||
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
|
||||
undo: () => setHardcodedValues(data.hardcodedValues),
|
||||
redo: () => setHardcodedValues(newValues),
|
||||
});
|
||||
}
|
||||
|
||||
setHardcodedValues(newValues);
|
||||
const errors = data.errors || {};
|
||||
// Remove error with the same key
|
||||
setNestedProperty(errors, path, null);
|
||||
setErrors({ ...errors });
|
||||
};
|
||||
setHardcodedValues(newValues);
|
||||
const errors = data.errors || {};
|
||||
// Remove error with the same key
|
||||
setNestedProperty(errors, path, null);
|
||||
setErrors({ ...errors });
|
||||
},
|
||||
[data.hardcodedValues, id, setHardcodedValues, data.errors, setErrors],
|
||||
);
|
||||
|
||||
const isInputHandleConnected = (key: string) => {
|
||||
return (
|
||||
@@ -407,28 +416,34 @@ export const CustomNode = React.memo(
|
||||
);
|
||||
};
|
||||
|
||||
const handleInputClick = (key: string) => {
|
||||
console.debug(`Opening modal for key: ${key}`);
|
||||
setActiveKey(key);
|
||||
const value = getValue(key, data.hardcodedValues);
|
||||
setInputModalValue(
|
||||
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
|
||||
);
|
||||
setIsModalOpen(true);
|
||||
};
|
||||
const handleInputClick = useCallback(
|
||||
(key: string) => {
|
||||
console.debug(`Opening modal for key: ${key}`);
|
||||
setActiveKey(key);
|
||||
const value = getValue(key, data.hardcodedValues);
|
||||
setInputModalValue(
|
||||
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
|
||||
);
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[data.hardcodedValues],
|
||||
);
|
||||
|
||||
const handleModalSave = (value: string) => {
|
||||
if (activeKey) {
|
||||
try {
|
||||
const parsedValue = JSON.parse(value);
|
||||
handleInputChange(activeKey, parsedValue);
|
||||
} catch (error) {
|
||||
handleInputChange(activeKey, value);
|
||||
const handleModalSave = useCallback(
|
||||
(value: string) => {
|
||||
if (activeKey) {
|
||||
try {
|
||||
const parsedValue = JSON.parse(value);
|
||||
handleInputChange(activeKey, parsedValue);
|
||||
} catch (error) {
|
||||
handleInputChange(activeKey, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
setIsModalOpen(false);
|
||||
setActiveKey(null);
|
||||
};
|
||||
setIsModalOpen(false);
|
||||
setActiveKey(null);
|
||||
},
|
||||
[activeKey, handleInputChange],
|
||||
);
|
||||
|
||||
const handleOutputClick = () => {
|
||||
setIsOutputModalOpen(true);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
import React, { useCallback, useMemo } from "react";
|
||||
import { isEmpty } from "lodash";
|
||||
import moment from "moment";
|
||||
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
@@ -164,7 +165,8 @@ export default function AgentRunDetailsView({
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
...(["success", "failed", "stopped"].includes(runStatus) &&
|
||||
!graph.has_webhook_trigger
|
||||
!graph.has_webhook_trigger &&
|
||||
isEmpty(graph.credentials_input_schema.required) // TODO: enable re-run with credentials - https://linear.app/autogpt/issue/SECRT-1243
|
||||
? [
|
||||
{
|
||||
label: (
|
||||
@@ -193,6 +195,7 @@ export default function AgentRunDetailsView({
|
||||
stopRun,
|
||||
deleteRun,
|
||||
graph.has_webhook_trigger,
|
||||
graph.credentials_input_schema.properties,
|
||||
agent.can_access_graph,
|
||||
run.graph_id,
|
||||
run.graph_version,
|
||||
|
||||
@@ -6,6 +6,7 @@ import { GraphExecutionID, GraphMeta } from "@/lib/autogpt-server-api";
|
||||
|
||||
import type { ButtonAction } from "@/components/agptui/types";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { CredentialsInput } from "@/components/integrations/credentials-input";
|
||||
import { TypeBasedInput } from "@/components/type-based-input";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
import ActionButtonGroup from "@/components/agptui/action-button-group";
|
||||
@@ -26,19 +27,32 @@ export default function AgentRunDraftView({
|
||||
const toastOnFail = useToastOnFail();
|
||||
|
||||
const agentInputs = graph.input_schema.properties;
|
||||
const agentCredentialsInputs = graph.credentials_input_schema.properties;
|
||||
const [inputValues, setInputValues] = useState<Record<string, any>>({});
|
||||
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
|
||||
{},
|
||||
);
|
||||
const { state, completeStep } = useOnboarding();
|
||||
|
||||
const doRun = useCallback(() => {
|
||||
api
|
||||
.executeGraph(graph.id, graph.version, inputValues)
|
||||
.executeGraph(graph.id, graph.version, inputValues, inputCredentials)
|
||||
.then((newRun) => onRun(newRun.graph_exec_id))
|
||||
.catch(toastOnFail("execute agent"));
|
||||
// Mark run agent onboarding step as completed
|
||||
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
|
||||
completeStep("MARKETPLACE_RUN_AGENT");
|
||||
}
|
||||
}, [api, graph, inputValues, onRun, state]);
|
||||
}, [
|
||||
api,
|
||||
graph,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onRun,
|
||||
toastOnFail,
|
||||
state,
|
||||
completeStep,
|
||||
]);
|
||||
|
||||
const runActions: ButtonAction[] = useMemo(
|
||||
() => [
|
||||
@@ -64,6 +78,26 @@ export default function AgentRunDraftView({
|
||||
<CardTitle className="font-poppins text-lg">Input</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="flex flex-col gap-4">
|
||||
{/* Credentials inputs */}
|
||||
{Object.entries(agentCredentialsInputs).map(
|
||||
([key, inputSubSchema]) => (
|
||||
<CredentialsInput
|
||||
key={key}
|
||||
schema={{ ...inputSubSchema, discriminator: undefined }}
|
||||
selectedCredentials={
|
||||
inputCredentials[key] ?? inputSubSchema.default
|
||||
}
|
||||
onSelectCredentials={(value) =>
|
||||
setInputCredentials((obj) => ({
|
||||
...obj,
|
||||
[key]: value,
|
||||
}))
|
||||
}
|
||||
/>
|
||||
),
|
||||
)}
|
||||
|
||||
{/* Regular inputs */}
|
||||
{Object.entries(agentInputs).map(([key, inputSubSchema]) => (
|
||||
<div key={key} className="flex flex-col space-y-2">
|
||||
<label className="flex items-center gap-1 text-sm font-medium">
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { FC, useEffect, useMemo, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
@@ -16,8 +17,8 @@ import {
|
||||
FaKey,
|
||||
FaHubspot,
|
||||
} from "react-icons/fa";
|
||||
import { FC, useMemo, useState } from "react";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
@@ -106,13 +107,18 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
|
||||
);
|
||||
|
||||
export const CredentialsInput: FC<{
|
||||
selfKey: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
className?: string;
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
|
||||
}> = ({ selfKey, className, selectedCredentials, onSelectCredentials }) => {
|
||||
const api = useBackendAPI();
|
||||
const credentials = useCredentials(selfKey);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({
|
||||
schema,
|
||||
className,
|
||||
selectedCredentials,
|
||||
onSelectCredentials,
|
||||
siblingInputs,
|
||||
}) => {
|
||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||
useState(false);
|
||||
const [
|
||||
@@ -124,20 +130,47 @@ export const CredentialsInput: FC<{
|
||||
useState<AbortController | null>(null);
|
||||
const [oAuthError, setOAuthError] = useState<string | null>(null);
|
||||
|
||||
if (!credentials || credentials.isLoading) {
|
||||
const api = useBackendAPI();
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
// Deselect credentials if they do not exist (e.g. provider was changed)
|
||||
useEffect(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return;
|
||||
if (
|
||||
selectedCredentials &&
|
||||
!credentials.savedCredentials.some((c) => c.id === selectedCredentials.id)
|
||||
) {
|
||||
onSelectCredentials(undefined);
|
||||
}
|
||||
}, [credentials, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
const singleCredential = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return null;
|
||||
|
||||
if (credentials.savedCredentials.length === 1)
|
||||
return credentials.savedCredentials[0];
|
||||
|
||||
return null;
|
||||
}, [credentials]);
|
||||
|
||||
// If only 1 credential is available, auto-select it and hide this input
|
||||
useEffect(() => {
|
||||
if (singleCredential && !selectedCredentials) {
|
||||
onSelectCredentials(singleCredential);
|
||||
}
|
||||
}, [singleCredential, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
if (!credentials || credentials.isLoading || singleCredential) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const {
|
||||
schema,
|
||||
provider,
|
||||
providerName,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
savedApiKeys,
|
||||
savedOAuthCredentials,
|
||||
savedUserPasswordCredentials,
|
||||
savedCredentials,
|
||||
oAuthCallback,
|
||||
} = credentials;
|
||||
|
||||
@@ -235,13 +268,14 @@ export const CredentialsInput: FC<{
|
||||
<>
|
||||
{supportsApiKey && (
|
||||
<APIKeyCredentialsModal
|
||||
credentialsFieldName={selfKey}
|
||||
schema={schema}
|
||||
open={isAPICredentialsModalOpen}
|
||||
onClose={() => setAPICredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(credsMeta) => {
|
||||
onSelectCredentials(credsMeta);
|
||||
setAPICredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
{supportsOAuth2 && (
|
||||
@@ -253,43 +287,34 @@ export const CredentialsInput: FC<{
|
||||
)}
|
||||
{supportsUserPassword && (
|
||||
<UserPasswordCredentialsModal
|
||||
credentialsFieldName={selfKey}
|
||||
schema={schema}
|
||||
open={isUserPasswordCredentialsModalOpen}
|
||||
onClose={() => setUserPasswordCredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onSelectCredentials(creds);
|
||||
setUserPasswordCredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
// Deselect credentials if they do not exist (e.g. provider was changed)
|
||||
if (
|
||||
selectedCredentials &&
|
||||
!savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.concat(savedUserPasswordCredentials)
|
||||
.some((c) => c.id === selectedCredentials.id)
|
||||
) {
|
||||
onSelectCredentials(undefined);
|
||||
}
|
||||
const fieldHeader = (
|
||||
<div className="mb-2 flex gap-1">
|
||||
<span className="text-m green text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
);
|
||||
|
||||
// No saved credentials yet
|
||||
if (
|
||||
savedApiKeys.length === 0 &&
|
||||
savedOAuthCredentials.length === 0 &&
|
||||
savedUserPasswordCredentials.length === 0
|
||||
) {
|
||||
if (savedCredentials.length === 0) {
|
||||
return (
|
||||
<>
|
||||
<div className="mb-2 flex gap-1">
|
||||
<span className="text-m green text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
<div>
|
||||
{fieldHeader}
|
||||
|
||||
<div className={cn("flex flex-row space-x-2", className)}>
|
||||
{supportsOAuth2 && (
|
||||
<Button onClick={handleOAuthLogin}>
|
||||
@@ -314,46 +339,10 @@ export const CredentialsInput: FC<{
|
||||
{oAuthError && (
|
||||
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const getCredentialCounts = () => ({
|
||||
apiKeys: savedApiKeys.length,
|
||||
oauth: savedOAuthCredentials.length,
|
||||
userPass: savedUserPasswordCredentials.length,
|
||||
});
|
||||
|
||||
const getSingleCredential = () => {
|
||||
const counts = getCredentialCounts();
|
||||
const totalCredentials = Object.values(counts).reduce(
|
||||
(sum, count) => sum + count,
|
||||
0,
|
||||
);
|
||||
|
||||
if (totalCredentials !== 1) return null;
|
||||
|
||||
if (counts.apiKeys === 1) return savedApiKeys[0];
|
||||
if (counts.oauth === 1) return savedOAuthCredentials[0];
|
||||
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const singleCredential = getSingleCredential();
|
||||
|
||||
if (singleCredential) {
|
||||
if (!selectedCredentials) {
|
||||
onSelectCredentials({
|
||||
id: singleCredential.id,
|
||||
type: singleCredential.type,
|
||||
provider,
|
||||
title: singleCredential.title,
|
||||
});
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleValueChange(newValue: string) {
|
||||
if (newValue === "sign-in") {
|
||||
// Trigger OAuth2 sign in flow
|
||||
@@ -362,10 +351,7 @@ export const CredentialsInput: FC<{
|
||||
// Open API key dialog
|
||||
setAPICredentialsModalOpen(true);
|
||||
} else {
|
||||
const selectedCreds = savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.concat(savedUserPasswordCredentials)
|
||||
.find((c) => c.id == newValue)!;
|
||||
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
|
||||
|
||||
onSelectCredentials({
|
||||
id: selectedCreds.id,
|
||||
@@ -378,38 +364,40 @@ export const CredentialsInput: FC<{
|
||||
|
||||
// Saved credentials exist
|
||||
return (
|
||||
<>
|
||||
<div className="flex gap-1">
|
||||
<span className="text-m green mb-0 text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
<div>
|
||||
{fieldHeader}
|
||||
|
||||
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder} />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{savedOAuthCredentials.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
{credentials.username}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedApiKeys.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedUserPasswordCredentials.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconUserPlus className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "oauth2")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
{credentials.username}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "api_key")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "user_password")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconUserPlus className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
<SelectSeparator />
|
||||
{supportsOAuth2 && (
|
||||
<SelectItem value="sign-in">
|
||||
@@ -435,17 +423,18 @@ export const CredentialsInput: FC<{
|
||||
{oAuthError && (
|
||||
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
credentialsFieldName: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||
const credentials = useCredentials(credentialsFieldName);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
@@ -466,8 +455,7 @@ export const APIKeyCredentialsModal: FC<{
|
||||
return null;
|
||||
}
|
||||
|
||||
const { schema, provider, providerName, createAPIKeyCredentials } =
|
||||
credentials;
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
@@ -576,12 +564,13 @@ export const APIKeyCredentialsModal: FC<{
|
||||
};
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
credentialsFieldName: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||
const credentials = useCredentials(credentialsFieldName);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
@@ -606,8 +595,7 @@ export const UserPasswordCredentialsModal: FC<{
|
||||
return null;
|
||||
}
|
||||
|
||||
const { schema, provider, providerName, createUserPasswordCredentials } =
|
||||
credentials;
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
|
||||
@@ -68,9 +68,7 @@ type UserPasswordCredentialsCreatable = Omit<
|
||||
export type CredentialsProviderData = {
|
||||
provider: CredentialsProviderName;
|
||||
providerName: string;
|
||||
savedApiKeys: CredentialsMetaResponse[];
|
||||
savedOAuthCredentials: CredentialsMetaResponse[];
|
||||
savedUserPasswordCredentials: CredentialsMetaResponse[];
|
||||
savedCredentials: CredentialsMetaResponse[];
|
||||
oAuthCallback: (
|
||||
code: string,
|
||||
state_token: string,
|
||||
@@ -113,28 +111,12 @@ export default function CredentialsProvider({
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
const updatedProvider = { ...prev[provider] };
|
||||
|
||||
if (credentials.type === "api_key") {
|
||||
updatedProvider.savedApiKeys = [
|
||||
...updatedProvider.savedApiKeys,
|
||||
credentials,
|
||||
];
|
||||
} else if (credentials.type === "oauth2") {
|
||||
updatedProvider.savedOAuthCredentials = [
|
||||
...updatedProvider.savedOAuthCredentials,
|
||||
credentials,
|
||||
];
|
||||
} else if (credentials.type === "user_password") {
|
||||
updatedProvider.savedUserPasswordCredentials = [
|
||||
...updatedProvider.savedUserPasswordCredentials,
|
||||
credentials,
|
||||
];
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[provider]: updatedProvider,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: [...prev[provider].savedCredentials, credentials],
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
@@ -203,21 +185,14 @@ export default function CredentialsProvider({
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
const updatedProvider = { ...prev[provider] };
|
||||
updatedProvider.savedApiKeys = updatedProvider.savedApiKeys.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
updatedProvider.savedOAuthCredentials =
|
||||
updatedProvider.savedOAuthCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
updatedProvider.savedUserPasswordCredentials =
|
||||
updatedProvider.savedUserPasswordCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
return {
|
||||
...prev,
|
||||
[provider]: updatedProvider,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
@@ -233,29 +208,12 @@ export default function CredentialsProvider({
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = {
|
||||
oauthCreds: [],
|
||||
apiKeys: [],
|
||||
userPasswordCreds: [],
|
||||
};
|
||||
}
|
||||
if (cred.type === "oauth2") {
|
||||
acc[cred.provider].oauthCreds.push(cred);
|
||||
} else if (cred.type === "api_key") {
|
||||
acc[cred.provider].apiKeys.push(cred);
|
||||
} else if (cred.type === "user_password") {
|
||||
acc[cred.provider].userPasswordCreds.push(cred);
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<
|
||||
CredentialsProviderName,
|
||||
{
|
||||
oauthCreds: CredentialsMetaResponse[];
|
||||
apiKeys: CredentialsMetaResponse[];
|
||||
userPasswordCreds: CredentialsMetaResponse[];
|
||||
}
|
||||
>,
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
|
||||
setProviders((prev) => ({
|
||||
@@ -265,40 +223,19 @@ export default function CredentialsProvider({
|
||||
provider,
|
||||
{
|
||||
provider,
|
||||
providerName:
|
||||
providerDisplayNames[provider as CredentialsProviderName],
|
||||
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
||||
savedOAuthCredentials:
|
||||
credentialsByProvider[provider]?.oauthCreds ?? [],
|
||||
savedUserPasswordCredentials:
|
||||
credentialsByProvider[provider]?.userPasswordCreds ?? [],
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(
|
||||
provider as CredentialsProviderName,
|
||||
code,
|
||||
state_token,
|
||||
),
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) =>
|
||||
createAPIKeyCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
credentials,
|
||||
),
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) =>
|
||||
createUserPasswordCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
credentials,
|
||||
),
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
id,
|
||||
force,
|
||||
),
|
||||
},
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
|
||||
@@ -6,18 +6,21 @@ import {
|
||||
} from "@/components/ui/popover";
|
||||
import { format } from "date-fns";
|
||||
import { CalendarIcon } from "lucide-react";
|
||||
import { Cross2Icon, Pencil2Icon, PlusIcon } from "@radix-ui/react-icons";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { Node, useNodeId, useNodesData } from "@xyflow/react";
|
||||
import { ConnectionData, CustomNodeData } from "@/components/CustomNode";
|
||||
import { Cross2Icon, Pencil2Icon, PlusIcon } from "@radix-ui/react-icons";
|
||||
import {
|
||||
BlockIORootSchema,
|
||||
BlockIOSubSchema,
|
||||
BlockIOObjectSubSchema,
|
||||
BlockIOKVSubSchema,
|
||||
BlockIOArraySubSchema,
|
||||
BlockIOStringSubSchema,
|
||||
BlockIONumberSubSchema,
|
||||
BlockIOBooleanSubSchema,
|
||||
BlockIOCredentialsSubSchema,
|
||||
BlockIOKVSubSchema,
|
||||
BlockIONumberSubSchema,
|
||||
BlockIOObjectSubSchema,
|
||||
BlockIORootSchema,
|
||||
BlockIOSimpleTypeSubSchema,
|
||||
BlockIOStringSubSchema,
|
||||
BlockIOSubSchema,
|
||||
DataType,
|
||||
determineDataType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
@@ -48,8 +51,7 @@ import {
|
||||
} from "./ui/multiselect";
|
||||
import { LocalValuedInput } from "./ui/input";
|
||||
import NodeHandle from "./NodeHandle";
|
||||
import { ConnectionData } from "./CustomNode";
|
||||
import { CredentialsInput } from "./integrations/credentials-input";
|
||||
import { CredentialsInput } from "@/components/integrations/credentials-input";
|
||||
|
||||
type NodeObjectInputTreeProps = {
|
||||
nodeId: string;
|
||||
@@ -357,6 +359,7 @@ export const NodeGenericInputField: FC<{
|
||||
return (
|
||||
<NodeCredentialsInput
|
||||
selfKey={propKey}
|
||||
schema={propSchema as BlockIOCredentialsSubSchema}
|
||||
value={currentValue}
|
||||
errors={errors}
|
||||
className={className}
|
||||
@@ -697,15 +700,19 @@ const NodeOneOfDiscriminatorField: FC<{
|
||||
|
||||
const NodeCredentialsInput: FC<{
|
||||
selfKey: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
value: any;
|
||||
errors: { [key: string]: string | undefined };
|
||||
handleInputChange: NodeObjectInputTreeProps["handleInputChange"];
|
||||
className?: string;
|
||||
}> = ({ selfKey, value, errors, handleInputChange, className }) => {
|
||||
}> = ({ selfKey, schema, value, errors, handleInputChange, className }) => {
|
||||
const nodeInputValues = useNodesData<Node<CustomNodeData>>(useNodeId()!)?.data
|
||||
.hardcodedValues;
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<CredentialsInput
|
||||
selfKey={selfKey}
|
||||
schema={schema}
|
||||
siblingInputs={nodeInputValues}
|
||||
onSelectCredentials={(credsMeta) =>
|
||||
handleInputChange(selfKey, credsMeta)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import { useContext } from "react";
|
||||
import { CustomNodeData } from "@/components/CustomNode";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { Node, useNodeId, useNodesData } from "@xyflow/react";
|
||||
import { getValue } from "@/lib/utils";
|
||||
|
||||
import {
|
||||
CredentialsProviderData,
|
||||
CredentialsProvidersContext,
|
||||
} from "@/components/integrations/credentials-provider";
|
||||
import { getValue } from "@/lib/utils";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
|
||||
export type CredentialsData =
|
||||
| {
|
||||
@@ -29,59 +28,53 @@ export type CredentialsData =
|
||||
});
|
||||
|
||||
export default function useCredentials(
|
||||
inputFieldName: string,
|
||||
credsInputSchema: BlockIOCredentialsSubSchema,
|
||||
nodeInputValues?: Record<string, any>,
|
||||
): CredentialsData | null {
|
||||
const nodeId = useNodeId();
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
if (!nodeId) {
|
||||
throw new Error("useCredentials must be within a CustomNode");
|
||||
}
|
||||
|
||||
const data = useNodesData<Node<CustomNodeData>>(nodeId)!.data;
|
||||
const credentialsSchema = data.inputSchema.properties[
|
||||
inputFieldName
|
||||
] as BlockIOCredentialsSubSchema;
|
||||
|
||||
const discriminatorValue: CredentialsProviderName | null =
|
||||
(credentialsSchema.discriminator &&
|
||||
credentialsSchema.discriminator_mapping![
|
||||
getValue(credentialsSchema.discriminator, data.hardcodedValues)
|
||||
(credsInputSchema.discriminator &&
|
||||
credsInputSchema.discriminator_mapping![
|
||||
getValue(credsInputSchema.discriminator, nodeInputValues)
|
||||
]) ||
|
||||
null;
|
||||
|
||||
let providerName: CredentialsProviderName;
|
||||
if (credentialsSchema.credentials_provider.length > 1) {
|
||||
if (!credentialsSchema.discriminator) {
|
||||
if (credsInputSchema.credentials_provider.length > 1) {
|
||||
if (!credsInputSchema.discriminator) {
|
||||
throw new Error(
|
||||
"Multi-provider credential input requires discriminator!",
|
||||
);
|
||||
}
|
||||
if (!discriminatorValue) {
|
||||
console.log(
|
||||
`Missing discriminator value from '${credsInputSchema.discriminator}': ` +
|
||||
"hiding credentials input until it is set.",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
providerName = discriminatorValue;
|
||||
} else {
|
||||
providerName = credentialsSchema.credentials_provider[0];
|
||||
providerName = credsInputSchema.credentials_provider[0];
|
||||
}
|
||||
const provider = allProviders ? allProviders[providerName] : null;
|
||||
|
||||
// If block input schema doesn't have credentials, return null
|
||||
if (!credentialsSchema) {
|
||||
if (!credsInputSchema) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const supportsApiKey =
|
||||
credentialsSchema.credentials_types.includes("api_key");
|
||||
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
||||
const supportsApiKey = credsInputSchema.credentials_types.includes("api_key");
|
||||
const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
|
||||
const supportsUserPassword =
|
||||
credentialsSchema.credentials_types.includes("user_password");
|
||||
credsInputSchema.credentials_types.includes("user_password");
|
||||
|
||||
// No provider means maybe it's still loading
|
||||
if (!provider) {
|
||||
// return {
|
||||
// provider: credentialsSchema.credentials_provider,
|
||||
// schema: credentialsSchema,
|
||||
// provider: credsInputSchema.credentials_provider,
|
||||
// schema: credsInputSchema,
|
||||
// supportsApiKey,
|
||||
// supportsOAuth2,
|
||||
// isLoading: true,
|
||||
@@ -90,24 +83,23 @@ export default function useCredentials(
|
||||
}
|
||||
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
const requiredScopes = credentialsSchema.credentials_scopes;
|
||||
const savedOAuthCredentials = requiredScopes
|
||||
? provider.savedOAuthCredentials.filter((c) =>
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
const savedCredentials = requiredScopes
|
||||
? provider.savedCredentials.filter(
|
||||
(c) =>
|
||||
c.type != "oauth2" ||
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
|
||||
)
|
||||
: provider.savedOAuthCredentials;
|
||||
|
||||
const savedUserPasswordCredentials = provider.savedUserPasswordCredentials;
|
||||
: provider.savedCredentials;
|
||||
|
||||
return {
|
||||
...provider,
|
||||
provider: providerName,
|
||||
schema: credentialsSchema,
|
||||
schema: credsInputSchema,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
savedOAuthCredentials,
|
||||
savedUserPasswordCredentials,
|
||||
savedCredentials,
|
||||
isLoading: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ import {
|
||||
UserOnboarding,
|
||||
ReviewSubmissionRequest,
|
||||
SubmissionStatus,
|
||||
CredentialsMetaInput,
|
||||
} from "./types";
|
||||
import { createBrowserClient } from "@supabase/ssr";
|
||||
import getServerSupabase from "../supabase/getServerSupabase";
|
||||
@@ -251,9 +252,13 @@ export default class BackendAPI {
|
||||
executeGraph(
|
||||
id: GraphID,
|
||||
version: number,
|
||||
inputData: { [key: string]: any } = {},
|
||||
inputs: { [key: string]: any } = {},
|
||||
credentials_inputs: { [key: string]: CredentialsMetaInput } = {},
|
||||
): Promise<{ graph_exec_id: GraphExecutionID }> {
|
||||
return this._request("POST", `/graphs/${id}/execute/${version}`, inputData);
|
||||
return this._request("POST", `/graphs/${id}/execute/${version}`, {
|
||||
inputs,
|
||||
credentials_inputs,
|
||||
});
|
||||
}
|
||||
|
||||
getExecutions(): Promise<GraphExecutionMeta[]> {
|
||||
|
||||
@@ -282,6 +282,11 @@ export type GraphMeta = {
|
||||
description: string;
|
||||
input_schema: GraphIOSchema;
|
||||
output_schema: GraphIOSchema;
|
||||
credentials_input_schema: {
|
||||
type: "object";
|
||||
properties: { [key: string]: BlockIOCredentialsSubSchema };
|
||||
required: (keyof GraphMeta["credentials_input_schema"]["properties"])[];
|
||||
};
|
||||
};
|
||||
|
||||
export type GraphID = Brand<string, "GraphID">;
|
||||
@@ -317,6 +322,7 @@ export type GraphUpdateable = Omit<
|
||||
| "links"
|
||||
| "input_schema"
|
||||
| "output_schema"
|
||||
| "credentials_input_schema"
|
||||
| "has_webhook_trigger"
|
||||
> & {
|
||||
version?: number;
|
||||
|
||||
Reference in New Issue
Block a user