feat(platform/library): Add credentials UX on /library/agents/[id] (#9789)

- Resolves #9771
- ... in a non-persistent way, so it won't work for webhook-triggered
agents
    For webhooks: #9541

### Changes 🏗️

Frontend:
- Add credentials inputs in Library "New run" screen (based on
`graph.credentials_input_schema`)
- Refactor `CredentialsInput` and `useCredentials` to not rely on XYFlow
context

- Unsplit lists of saved credentials in `CredentialsProvider` state

- Move logic that was being executed at component render to `useEffect`
hooks in `CredentialsInput`

Backend:
- Implement logic to aggregate credentials input requirements to one per
provider per graph
- Add `BaseGraph.credentials_input_schema` (JSON schema) computed field
    Underlying added logic:
- `BaseGraph._credentials_input_schema` - makes a `BlockSchema` from a
graph's aggregated credentials inputs
- `BaseGraph.aggregate_credentials_inputs()` - aggregates a graph's
nodes' credentials inputs using `CredentialsFieldInfo.combine(..)`
- `BlockSchema.get_credentials_fields_info() -> dict[str,
CredentialsFieldInfo]`
- `CredentialsFieldInfo` model (created from
`_CredentialsFieldSchemaExtra`)

- Implement logic to inject explicitly passed credentials into graph
execution
  - Add `credentials_inputs` parameter to `execute_graph` endpoint
- Add `graph_credentials_input` parameter to
`.executor.utils.add_graph_execution(..)`
  - Implement `.executor.utils.make_node_credentials_input_map(..)`
  - Amend `.executor.utils.construct_node_execution_input`
  - Add `GraphExecutionEntry.node_credentials_input_map` attribute
  - Amend validation to allow injecting credentials
    - Amend `GraphModel._validate_graph(..)`
    - Amend `.executor.utils._validate_node_input_credentials`
- Add `node_credentials_map` parameter to
`ExecutionManager.add_execution(..)`
    - Amend execution validation to handle side-loaded credentials
    - Add `GraphExecutionEntry.node_execution_map` attribute
- Add mechanism to inject passed credentials into node execution data
- Add credentials injection mechanism to node execution queueing logic
in `Executor._on_graph_execution(..)`

- Replace boilerplate logic in `v1.execute_graph` endpoint with call to
existing `.executor.utils.add_graph_execution(..)`
- Replace calls to `.server.routers.v1.execute_graph` with
`add_graph_execution`

Also:
- Address tech debt in `GraphModel._validate_graph(..)`
- Fix type checking in `BaseGraph._generate_schema(..)`

#### TODO
- [ ] ~~Make "Run again" work with credentials in
`AgentRunDetailsView`~~
- [ ] Prohibit saving a graph if it has nodes with missing discriminator
value for discriminated credentials inputs

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...
This commit is contained in:
Reinier van der Leer
2025-04-18 16:27:13 +02:00
committed by GitHub
parent f16a398a8e
commit 417d7732af
23 changed files with 777 additions and 434 deletions

View File

@@ -67,15 +67,15 @@ class AgentExecutorBlock(Block):
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
data=input_data.data,
inputs=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_exec_id=graph_exec.id,
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
if event.status in [

View File

@@ -28,6 +28,7 @@ from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
)
def get_block(block_id: str) -> Block | None:
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

View File

@@ -44,7 +44,7 @@ from .includes import (
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
@@ -220,6 +220,7 @@ class GraphExecutionWithNodes(GraphExecution):
)
for node_exec in self.node_executions
],
node_credentials_input_map={}, # FIXME
)
@@ -361,7 +362,7 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -388,7 +389,7 @@ async def create_graph_execution(
]
},
)
for node_id, node_input in nodes_input
for node_id, node_input in starting_nodes_input
]
},
userId=user_id,
@@ -712,6 +713,7 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
class NodeExecutionEntry(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type, cast
from typing import Any, Literal, Optional, cast
import prisma
from prisma import Json
@@ -13,12 +13,19 @@ from prisma.types import (
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -190,14 +197,19 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
) -> dict[str, Any]:
schema = []
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
for type_class, input_default in props:
try:
schema.append(type_class(**input_default))
schema_fields.append(type_class(**input_default))
except Exception as e:
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
@@ -217,9 +229,93 @@ class BaseGraph(BaseDbModel):
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in schema
for p in schema_fields
},
"required": [p.name for p in schema if p.value is None],
"required": [p.name for p in schema_fields if p.value is None],
}
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
)
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
logger.warning(
"Multiple combined credentials fields "
f"for provider {field.provider} "
f"on graph #{self.id} ({self.name}); "
f"fields: {field} <> {other_field};"
f"keys: {keys} <> {other_keys}."
)
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
"""
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
)
for node in self.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
@@ -320,8 +416,6 @@ class GraphModel(Graph):
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
@@ -332,13 +426,6 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
input_links = defaultdict(list)
for link in graph.links:
@@ -353,16 +440,21 @@ class GraphModel(Graph):
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
for name in block.input_schema.get_required_fields():
input_schema = block.input_schema
for name in (required_fields := input_schema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in input_schema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
and (
for_run # Skip input completion validation, unless when executing.
for_run
or block.block_type
in [
BlockType.INPUT,
@@ -375,9 +467,18 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
if (
block.block_type == BlockType.INPUT
and (input_key := node.input_default.get("name"))
and is_credentials_field_name(input_key)
):
raise ValueError(
f"Agent input node uses reserved name '{input_key}'; "
"'credentials' and `*_credentials` are reserved input names"
)
# Get input schema properties and check dependencies
input_schema = block.input_schema.model_fields
required_fields = block.input_schema.get_required_fields()
input_fields = input_schema.model_fields
def has_value(name):
return (
@@ -385,14 +486,21 @@ class GraphModel(Graph):
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_schema and input_schema[name].default is not None)
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name, field_info in input_schema.items():
for field_name, field_info in input_fields.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
dependencies = json_schema_extra.get("depends_on", [])
if not for_run or not dependencies:
if not (
for_run
and isinstance(json_schema_extra, dict)
and (
dependencies := cast(
list[str], json_schema_extra.get("depends_on", [])
)
)
):
continue
# Check if dependent field has value in input_default

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import base64
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
@@ -12,6 +13,7 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
@@ -300,9 +302,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -328,14 +328,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
provider: frozenset[CP] = Field(..., alias="credentials_provider")
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
Rules:
- Items can only be combined if they have the same supported credentials types
and the same supported providers.
- When combining items, the `required_scopes` of the result is a join
of the `required_scopes` of the original items.
Params:
*fields: (CredentialsFieldInfo, key) objects to group and combine
Returns:
A sequence of tuples containing combined CredentialsFieldInfo objects and
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
# Group fields by their provider and supported_types
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
# Track the keys that were combined
combined_keys = {key for key, _ in group}
# Combine required_scopes from all fields in the group
all_scopes = set()
for _, field in group:
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
)
def CredentialsField(
required_scopes: set[str] = set(),

View File

@@ -724,10 +724,10 @@ class Executor:
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
exec_data = queue.get()
queued_node_exec = queue.get()
# Avoid parallel execution of the same node.
execution = running_executions.get(exec_data.node_id)
execution = running_executions.get(queued_node_exec.node_id)
if execution and not execution.ready():
# TODO (performance improvement):
# Wait for the completion of the same node execution is blocking.
@@ -736,18 +736,18 @@ class Executor:
execution.wait()
log_metadata.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
node_exec=queued_node_exec,
execution_count=exec_cost_counter + 1,
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
node_exec_id = exec_data.node_exec_id
node_exec_id = queued_node_exec.node_exec_id
cls.db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="error",
@@ -768,10 +768,23 @@ class Executor:
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
# Add credentials input overrides
node_id = queued_node_exec.node_id
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.data.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
}
)
# Initiate node execution
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
callback=make_exec_callback(exec_data),
(queue, queued_node_exec),
callback=make_exec_callback(queued_node_exec),
)
# Avoid terminating graph execution when some nodes are still running.

View File

@@ -70,7 +70,7 @@ def execute_graph(**kwargs):
log(f"Executing recurring job for graph #{args.graph_id}")
execution_utils.add_graph_execution(
graph_id=args.graph_id,
data=args.input_data,
inputs=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)

View File

@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
@@ -14,15 +14,23 @@ from backend.data.block import (
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus
from backend.data.graph import GraphModel, Node
from backend.data.execution import (
AsyncRedisExecutionEventBus,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
@@ -43,6 +51,11 @@ def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
@@ -50,6 +63,13 @@ def get_execution_queue() -> SyncRabbitMQ:
return client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore
@@ -347,7 +367,13 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
return data
def _validate_node_input_credentials(graph: GraphModel, user_id: str):
def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
@@ -361,9 +387,22 @@ def _validate_node_input_credentials(graph: GraphModel, user_id: str):
continue
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
if (
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
):
credentials_meta = node_credentials_input_map[node.id][field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
raise ValueError(
f"Credentials absent for {block.name} node #{node.id} "
f"input '{field_name}'"
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
@@ -389,10 +428,46 @@ def _validate_node_input_credentials(graph: GraphModel, user_id: str):
)
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
"""
Maps credentials for an execution to the correct nodes.
Params:
graph: The graph to be executed.
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
"""
result: dict[str, dict[str, CredentialsMetaInput]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue
# Use passed-in credentials for all compatible node input fields
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
return result
def construct_node_execution_input(
graph: GraphModel,
user_id: str,
data: BlockInput,
graph_inputs: BlockInput,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
@@ -404,13 +479,14 @@ def construct_node_execution_input(
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
data (BlockInput): The input data for the graph execution.
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
_validate_node_input_credentials(graph, user_id)
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)
nodes_input = []
for node in graph.starting_nodes:
@@ -424,8 +500,8 @@ def construct_node_execution_input(
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
@@ -433,11 +509,17 @@ def construct_node_execution_input(
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in data:
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data = {"payload": graph_inputs[webhook_payload_key]}
# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
input_data, error = validate_exec(node, input_data)
if input_data is None:
@@ -505,47 +587,128 @@ def create_execution_queue_config() -> RabbitMQConfig:
)
def add_graph_execution(
async def add_graph_execution_async(
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id (str): The ID of the graph to execute.
data (BlockInput): The input data for the graph execution.
user_id (str): The ID of the user executing the graph.
graph_version (int | None): The version of the graph to execute. Defaults to None.
preset_id (str | None): The ID of the preset to use. Defaults to None.
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
execution_event_bus = get_async_execution_event_bus()
graph_execution_queue = await get_async_execution_queue()
graph: GraphModel | None = await get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = await create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
await execution_event_bus.publish(graph_exec)
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
await graph_execution_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
return graph_exec
def add_graph_execution(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
graph: GraphModel | None = get_db_client().get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = get_db_client().create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
nodes_input=construct_node_execution_input(graph, user_id, data),
user_id=user_id,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
get_execution_event_bus().publish(graph_exec)
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
get_execution_queue().publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
return graph_exec_entry
return graph_exec

View File

@@ -12,8 +12,8 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.server.routers import v1 as internal_api_routes
from backend.util.settings import Settings
settings = Settings()
@@ -97,13 +97,13 @@ async def execute_graph(
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = await internal_api_routes.execute_graph(
graph_exec = await add_graph_execution_async(
graph_id=graph_id,
node_input=node_input,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)

View File

@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -15,11 +15,11 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.routers import v1 as internal_api_routes
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.settings import Settings
@@ -309,7 +309,7 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return
executions = []
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
@@ -317,11 +317,11 @@ async def webhook_ingress_generic(
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executions.append(
internal_api_routes.execute_graph(
add_graph_execution_async(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
node_input={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
asyncio.gather(*executions)

View File

@@ -159,7 +159,8 @@ class AgentServer(backend.util.service.AppProcess):
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input or {},
inputs=node_input or {},
credentials_inputs={},
)
@staticmethod

View File

@@ -41,6 +41,7 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
@@ -592,35 +593,21 @@ async def set_graph_active_version(
)
async def execute_graph(
graph_id: str,
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
],
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
graph: graph_db.GraphModel | None = await graph_db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph_exec = await execution_db.create_graph_execution(
graph_exec = await execution_utils.add_graph_execution_async(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=execution_utils.construct_node_execution_input(
graph, user_id, node_input
),
user_id=user_id,
inputs=inputs,
preset_id=preset_id,
)
bus = execution_event_bus()
await bus.publish(graph_exec)
queue = await execution_queue_client()
await queue.publish_message(
routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec.to_graph_execution_entry().model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, status
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.executor.utils import add_graph_execution_async
logger = logging.getLogger(__name__)
@@ -207,8 +208,6 @@ async def execute_preset(
HTTPException: If the preset is not found or an error occurs while executing the preset.
"""
try:
from backend.server.routers import v1 as internal_api_routes
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
@@ -219,17 +218,17 @@ async def execute_preset(
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = await internal_api_routes.execute_graph(
execution = await add_graph_execution_async(
graph_id=graph_id,
node_input=merged_node_input,
graph_version=graph_version,
user_id=user_id,
inputs=merged_node_input,
preset_id=preset_id,
graph_version=graph_version,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
return {"id": execution.id}
except HTTPException:
raise
except Exception as e:

View File

@@ -131,11 +131,7 @@ export default function PrivatePage() {
const allCredentials = providers
? Object.values(providers).flatMap((provider) =>
[
...provider.savedOAuthCredentials,
...provider.savedApiKeys,
...provider.savedUserPasswordCredentials,
]
provider.savedCredentials
.filter((cred) => !hiddenCredentials.includes(cred.id))
.map((credentials) => ({
...credentials,

View File

@@ -178,18 +178,24 @@ export const CustomNode = React.memo(
return obj;
}, []);
const setHardcodedValues = (values: any) => {
updateNodeData(id, { hardcodedValues: values });
};
const setHardcodedValues = useCallback(
(values: any) => {
updateNodeData(id, { hardcodedValues: values });
},
[id, updateNodeData],
);
useEffect(() => {
isInitialSetup.current = false;
setHardcodedValues(fillDefaults(data.hardcodedValues, data.inputSchema));
}, []);
const setErrors = (errors: { [key: string]: string }) => {
updateNodeData(id, { errors });
};
const setErrors = useCallback(
(errors: { [key: string]: string }) => {
updateNodeData(id, { errors });
},
[id, updateNodeData],
);
const toggleOutput = (checked: boolean) => {
setIsOutputOpen(checked);
@@ -340,46 +346,49 @@ export const CustomNode = React.memo(
});
}
};
const handleInputChange = (path: string, value: any) => {
const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
let current = newValues;
const handleInputChange = useCallback(
(path: string, value: any) => {
const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
let current = newValues;
for (let i = 0; i < keys.length - 1; i++) {
const { key: currentKey, index } = keys[i];
if (index !== undefined) {
if (!current[currentKey]) current[currentKey] = [];
if (!current[currentKey][index]) current[currentKey][index] = {};
current = current[currentKey][index];
} else {
if (!current[currentKey]) current[currentKey] = {};
current = current[currentKey];
for (let i = 0; i < keys.length - 1; i++) {
const { key: currentKey, index } = keys[i];
if (index !== undefined) {
if (!current[currentKey]) current[currentKey] = [];
if (!current[currentKey][index]) current[currentKey][index] = {};
current = current[currentKey][index];
} else {
if (!current[currentKey]) current[currentKey] = {};
current = current[currentKey];
}
}
}
const lastKey = keys[keys.length - 1];
if (lastKey.index !== undefined) {
if (!current[lastKey.key]) current[lastKey.key] = [];
current[lastKey.key][lastKey.index] = value;
} else {
current[lastKey.key] = value;
}
const lastKey = keys[keys.length - 1];
if (lastKey.index !== undefined) {
if (!current[lastKey.key]) current[lastKey.key] = [];
current[lastKey.key][lastKey.index] = value;
} else {
current[lastKey.key] = value;
}
if (!isInitialSetup.current) {
history.push({
type: "UPDATE_INPUT",
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
undo: () => setHardcodedValues(data.hardcodedValues),
redo: () => setHardcodedValues(newValues),
});
}
if (!isInitialSetup.current) {
history.push({
type: "UPDATE_INPUT",
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
undo: () => setHardcodedValues(data.hardcodedValues),
redo: () => setHardcodedValues(newValues),
});
}
setHardcodedValues(newValues);
const errors = data.errors || {};
// Remove error with the same key
setNestedProperty(errors, path, null);
setErrors({ ...errors });
};
setHardcodedValues(newValues);
const errors = data.errors || {};
// Remove error with the same key
setNestedProperty(errors, path, null);
setErrors({ ...errors });
},
[data.hardcodedValues, id, setHardcodedValues, data.errors, setErrors],
);
const isInputHandleConnected = (key: string) => {
return (
@@ -407,28 +416,34 @@ export const CustomNode = React.memo(
);
};
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
setIsModalOpen(true);
};
const handleInputClick = useCallback(
(key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
setIsModalOpen(true);
},
[data.hardcodedValues],
);
const handleModalSave = (value: string) => {
if (activeKey) {
try {
const parsedValue = JSON.parse(value);
handleInputChange(activeKey, parsedValue);
} catch (error) {
handleInputChange(activeKey, value);
const handleModalSave = useCallback(
(value: string) => {
if (activeKey) {
try {
const parsedValue = JSON.parse(value);
handleInputChange(activeKey, parsedValue);
} catch (error) {
handleInputChange(activeKey, value);
}
}
}
setIsModalOpen(false);
setActiveKey(null);
};
setIsModalOpen(false);
setActiveKey(null);
},
[activeKey, handleInputChange],
);
const handleOutputClick = () => {
setIsOutputModalOpen(true);

View File

@@ -1,5 +1,6 @@
"use client";
import React, { useCallback, useMemo } from "react";
import { isEmpty } from "lodash";
import moment from "moment";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
@@ -164,7 +165,8 @@ export default function AgentRunDetailsView({
] satisfies ButtonAction[])
: []),
...(["success", "failed", "stopped"].includes(runStatus) &&
!graph.has_webhook_trigger
!graph.has_webhook_trigger &&
isEmpty(graph.credentials_input_schema.required) // TODO: enable re-run with credentials - https://linear.app/autogpt/issue/SECRT-1243
? [
{
label: (
@@ -193,6 +195,7 @@ export default function AgentRunDetailsView({
stopRun,
deleteRun,
graph.has_webhook_trigger,
graph.credentials_input_schema.properties,
agent.can_access_graph,
run.graph_id,
run.graph_version,

View File

@@ -6,6 +6,7 @@ import { GraphExecutionID, GraphMeta } from "@/lib/autogpt-server-api";
import type { ButtonAction } from "@/components/agptui/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { CredentialsInput } from "@/components/integrations/credentials-input";
import { TypeBasedInput } from "@/components/type-based-input";
import { useToastOnFail } from "@/components/ui/use-toast";
import ActionButtonGroup from "@/components/agptui/action-button-group";
@@ -26,19 +27,32 @@ export default function AgentRunDraftView({
const toastOnFail = useToastOnFail();
const agentInputs = graph.input_schema.properties;
const agentCredentialsInputs = graph.credentials_input_schema.properties;
const [inputValues, setInputValues] = useState<Record<string, any>>({});
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
{},
);
const { state, completeStep } = useOnboarding();
const doRun = useCallback(() => {
api
.executeGraph(graph.id, graph.version, inputValues)
.executeGraph(graph.id, graph.version, inputValues, inputCredentials)
.then((newRun) => onRun(newRun.graph_exec_id))
.catch(toastOnFail("execute agent"));
// Mark run agent onboarding step as completed
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
completeStep("MARKETPLACE_RUN_AGENT");
}
}, [api, graph, inputValues, onRun, state]);
}, [
api,
graph,
inputValues,
inputCredentials,
onRun,
toastOnFail,
state,
completeStep,
]);
const runActions: ButtonAction[] = useMemo(
() => [
@@ -64,6 +78,26 @@ export default function AgentRunDraftView({
<CardTitle className="font-poppins text-lg">Input</CardTitle>
</CardHeader>
<CardContent className="flex flex-col gap-4">
{/* Credentials inputs */}
{Object.entries(agentCredentialsInputs).map(
([key, inputSubSchema]) => (
<CredentialsInput
key={key}
schema={{ ...inputSubSchema, discriminator: undefined }}
selectedCredentials={
inputCredentials[key] ?? inputSubSchema.default
}
onSelectCredentials={(value) =>
setInputCredentials((obj) => ({
...obj,
[key]: value,
}))
}
/>
),
)}
{/* Regular inputs */}
{Object.entries(agentInputs).map(([key, inputSubSchema]) => (
<div key={key} className="flex flex-col space-y-2">
<label className="flex items-center gap-1 text-sm font-medium">

View File

@@ -1,5 +1,6 @@
import { FC, useEffect, useMemo, useState } from "react";
import { z } from "zod";
import { beautifyString, cn } from "@/lib/utils";
import { cn } from "@/lib/utils";
import { useForm } from "react-hook-form";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
@@ -16,8 +17,8 @@ import {
FaKey,
FaHubspot,
} from "react-icons/fa";
import { FC, useMemo, useState } from "react";
import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
CredentialsProviderName,
} from "@/lib/autogpt-server-api/types";
@@ -106,13 +107,18 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
);
export const CredentialsInput: FC<{
selfKey: string;
schema: BlockIOCredentialsSubSchema;
className?: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
}> = ({ selfKey, className, selectedCredentials, onSelectCredentials }) => {
const api = useBackendAPI();
const credentials = useCredentials(selfKey);
siblingInputs?: Record<string, any>;
}> = ({
schema,
className,
selectedCredentials,
onSelectCredentials,
siblingInputs,
}) => {
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
useState(false);
const [
@@ -124,20 +130,47 @@ export const CredentialsInput: FC<{
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);
if (!credentials || credentials.isLoading) {
const api = useBackendAPI();
const credentials = useCredentials(schema, siblingInputs);
// Deselect credentials if they do not exist (e.g. provider was changed)
useEffect(() => {
if (!credentials || !("savedCredentials" in credentials)) return;
if (
selectedCredentials &&
!credentials.savedCredentials.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}
}, [credentials, selectedCredentials, onSelectCredentials]);
const singleCredential = useMemo(() => {
if (!credentials || !("savedCredentials" in credentials)) return null;
if (credentials.savedCredentials.length === 1)
return credentials.savedCredentials[0];
return null;
}, [credentials]);
// If only 1 credential is available, auto-select it and hide this input
useEffect(() => {
if (singleCredential && !selectedCredentials) {
onSelectCredentials(singleCredential);
}
}, [singleCredential, selectedCredentials, onSelectCredentials]);
if (!credentials || credentials.isLoading || singleCredential) {
return null;
}
const {
schema,
provider,
providerName,
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
savedApiKeys,
savedOAuthCredentials,
savedUserPasswordCredentials,
savedCredentials,
oAuthCallback,
} = credentials;
@@ -235,13 +268,14 @@ export const CredentialsInput: FC<{
<>
{supportsApiKey && (
<APIKeyCredentialsModal
credentialsFieldName={selfKey}
schema={schema}
open={isAPICredentialsModalOpen}
onClose={() => setAPICredentialsModalOpen(false)}
onCredentialsCreate={(credsMeta) => {
onSelectCredentials(credsMeta);
setAPICredentialsModalOpen(false);
}}
siblingInputs={siblingInputs}
/>
)}
{supportsOAuth2 && (
@@ -253,43 +287,34 @@ export const CredentialsInput: FC<{
)}
{supportsUserPassword && (
<UserPasswordCredentialsModal
credentialsFieldName={selfKey}
schema={schema}
open={isUserPasswordCredentialsModalOpen}
onClose={() => setUserPasswordCredentialsModalOpen(false)}
onCredentialsCreate={(creds) => {
onSelectCredentials(creds);
setUserPasswordCredentialsModalOpen(false);
}}
siblingInputs={siblingInputs}
/>
)}
</>
);
// Deselect credentials if they do not exist (e.g. provider was changed)
if (
selectedCredentials &&
!savedApiKeys
.concat(savedOAuthCredentials)
.concat(savedUserPasswordCredentials)
.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}
const fieldHeader = (
<div className="mb-2 flex gap-1">
<span className="text-m green text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
);
// No saved credentials yet
if (
savedApiKeys.length === 0 &&
savedOAuthCredentials.length === 0 &&
savedUserPasswordCredentials.length === 0
) {
if (savedCredentials.length === 0) {
return (
<>
<div className="mb-2 flex gap-1">
<span className="text-m green text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
<div>
{fieldHeader}
<div className={cn("flex flex-row space-x-2", className)}>
{supportsOAuth2 && (
<Button onClick={handleOAuthLogin}>
@@ -314,46 +339,10 @@ export const CredentialsInput: FC<{
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
</div>
);
}
const getCredentialCounts = () => ({
apiKeys: savedApiKeys.length,
oauth: savedOAuthCredentials.length,
userPass: savedUserPasswordCredentials.length,
});
const getSingleCredential = () => {
const counts = getCredentialCounts();
const totalCredentials = Object.values(counts).reduce(
(sum, count) => sum + count,
0,
);
if (totalCredentials !== 1) return null;
if (counts.apiKeys === 1) return savedApiKeys[0];
if (counts.oauth === 1) return savedOAuthCredentials[0];
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
return null;
};
const singleCredential = getSingleCredential();
if (singleCredential) {
if (!selectedCredentials) {
onSelectCredentials({
id: singleCredential.id,
type: singleCredential.type,
provider,
title: singleCredential.title,
});
}
return null;
}
function handleValueChange(newValue: string) {
if (newValue === "sign-in") {
// Trigger OAuth2 sign in flow
@@ -362,10 +351,7 @@ export const CredentialsInput: FC<{
// Open API key dialog
setAPICredentialsModalOpen(true);
} else {
const selectedCreds = savedApiKeys
.concat(savedOAuthCredentials)
.concat(savedUserPasswordCredentials)
.find((c) => c.id == newValue)!;
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
onSelectCredentials({
id: selectedCreds.id,
@@ -378,38 +364,40 @@ export const CredentialsInput: FC<{
// Saved credentials exist
return (
<>
<div className="flex gap-1">
<span className="text-m green mb-0 text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
<div>
{fieldHeader}
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder} />
</SelectTrigger>
<SelectContent className="nodrag">
{savedOAuthCredentials.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
{credentials.username}
</SelectItem>
))}
{savedApiKeys.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconKey className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedUserPasswordCredentials.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconUserPlus className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "oauth2")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
{credentials.username}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "api_key")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconKey className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "user_password")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconUserPlus className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
<SelectSeparator />
{supportsOAuth2 && (
<SelectItem value="sign-in">
@@ -435,17 +423,18 @@ export const CredentialsInput: FC<{
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
</div>
);
};
export const APIKeyCredentialsModal: FC<{
credentialsFieldName: string;
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
const credentials = useCredentials(credentialsFieldName);
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
apiKey: z.string().min(1, "API Key is required"),
@@ -466,8 +455,7 @@ export const APIKeyCredentialsModal: FC<{
return null;
}
const { schema, provider, providerName, createAPIKeyCredentials } =
credentials;
const { provider, providerName, createAPIKeyCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const expiresAt = values.expiresAt
@@ -576,12 +564,13 @@ export const APIKeyCredentialsModal: FC<{
};
export const UserPasswordCredentialsModal: FC<{
credentialsFieldName: string;
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
const credentials = useCredentials(credentialsFieldName);
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
username: z.string().min(1, "Username is required"),
@@ -606,8 +595,7 @@ export const UserPasswordCredentialsModal: FC<{
return null;
}
const { schema, provider, providerName, createUserPasswordCredentials } =
credentials;
const { provider, providerName, createUserPasswordCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const newCredentials = await createUserPasswordCredentials({

View File

@@ -68,9 +68,7 @@ type UserPasswordCredentialsCreatable = Omit<
export type CredentialsProviderData = {
provider: CredentialsProviderName;
providerName: string;
savedApiKeys: CredentialsMetaResponse[];
savedOAuthCredentials: CredentialsMetaResponse[];
savedUserPasswordCredentials: CredentialsMetaResponse[];
savedCredentials: CredentialsMetaResponse[];
oAuthCallback: (
code: string,
state_token: string,
@@ -113,28 +111,12 @@ export default function CredentialsProvider({
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
const updatedProvider = { ...prev[provider] };
if (credentials.type === "api_key") {
updatedProvider.savedApiKeys = [
...updatedProvider.savedApiKeys,
credentials,
];
} else if (credentials.type === "oauth2") {
updatedProvider.savedOAuthCredentials = [
...updatedProvider.savedOAuthCredentials,
credentials,
];
} else if (credentials.type === "user_password") {
updatedProvider.savedUserPasswordCredentials = [
...updatedProvider.savedUserPasswordCredentials,
credentials,
];
}
return {
...prev,
[provider]: updatedProvider,
[provider]: {
...prev[provider],
savedCredentials: [...prev[provider].savedCredentials, credentials],
},
};
});
},
@@ -203,21 +185,14 @@ export default function CredentialsProvider({
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
const updatedProvider = { ...prev[provider] };
updatedProvider.savedApiKeys = updatedProvider.savedApiKeys.filter(
(cred) => cred.id !== id,
);
updatedProvider.savedOAuthCredentials =
updatedProvider.savedOAuthCredentials.filter(
(cred) => cred.id !== id,
);
updatedProvider.savedUserPasswordCredentials =
updatedProvider.savedUserPasswordCredentials.filter(
(cred) => cred.id !== id,
);
return {
...prev,
[provider]: updatedProvider,
[provider]: {
...prev[provider],
savedCredentials: prev[provider].savedCredentials.filter(
(cred) => cred.id !== id,
),
},
};
});
return result;
@@ -233,29 +208,12 @@ export default function CredentialsProvider({
const credentialsByProvider = response.reduce(
(acc, cred) => {
if (!acc[cred.provider]) {
acc[cred.provider] = {
oauthCreds: [],
apiKeys: [],
userPasswordCreds: [],
};
}
if (cred.type === "oauth2") {
acc[cred.provider].oauthCreds.push(cred);
} else if (cred.type === "api_key") {
acc[cred.provider].apiKeys.push(cred);
} else if (cred.type === "user_password") {
acc[cred.provider].userPasswordCreds.push(cred);
acc[cred.provider] = [];
}
acc[cred.provider].push(cred);
return acc;
},
{} as Record<
CredentialsProviderName,
{
oauthCreds: CredentialsMetaResponse[];
apiKeys: CredentialsMetaResponse[];
userPasswordCreds: CredentialsMetaResponse[];
}
>,
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
);
setProviders((prev) => ({
@@ -265,40 +223,19 @@ export default function CredentialsProvider({
provider,
{
provider,
providerName:
providerDisplayNames[provider as CredentialsProviderName],
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
savedOAuthCredentials:
credentialsByProvider[provider]?.oauthCreds ?? [],
savedUserPasswordCredentials:
credentialsByProvider[provider]?.userPasswordCreds ?? [],
providerName: providerDisplayNames[provider],
savedCredentials: credentialsByProvider[provider] ?? [],
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(
provider as CredentialsProviderName,
code,
state_token,
),
oAuthCallback(provider, code, state_token),
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) =>
createAPIKeyCredentials(
provider as CredentialsProviderName,
credentials,
),
) => createAPIKeyCredentials(provider, credentials),
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) =>
createUserPasswordCredentials(
provider as CredentialsProviderName,
credentials,
),
) => createUserPasswordCredentials(provider, credentials),
deleteCredentials: (id: string, force: boolean = false) =>
deleteCredentials(
provider as CredentialsProviderName,
id,
force,
),
},
deleteCredentials(provider, id, force),
} satisfies CredentialsProviderData,
]),
),
}));

View File

@@ -6,18 +6,21 @@ import {
} from "@/components/ui/popover";
import { format } from "date-fns";
import { CalendarIcon } from "lucide-react";
import { Cross2Icon, Pencil2Icon, PlusIcon } from "@radix-ui/react-icons";
import { beautifyString, cn } from "@/lib/utils";
import { Node, useNodeId, useNodesData } from "@xyflow/react";
import { ConnectionData, CustomNodeData } from "@/components/CustomNode";
import { Cross2Icon, Pencil2Icon, PlusIcon } from "@radix-ui/react-icons";
import {
BlockIORootSchema,
BlockIOSubSchema,
BlockIOObjectSubSchema,
BlockIOKVSubSchema,
BlockIOArraySubSchema,
BlockIOStringSubSchema,
BlockIONumberSubSchema,
BlockIOBooleanSubSchema,
BlockIOCredentialsSubSchema,
BlockIOKVSubSchema,
BlockIONumberSubSchema,
BlockIOObjectSubSchema,
BlockIORootSchema,
BlockIOSimpleTypeSubSchema,
BlockIOStringSubSchema,
BlockIOSubSchema,
DataType,
determineDataType,
} from "@/lib/autogpt-server-api/types";
@@ -48,8 +51,7 @@ import {
} from "./ui/multiselect";
import { LocalValuedInput } from "./ui/input";
import NodeHandle from "./NodeHandle";
import { ConnectionData } from "./CustomNode";
import { CredentialsInput } from "./integrations/credentials-input";
import { CredentialsInput } from "@/components/integrations/credentials-input";
type NodeObjectInputTreeProps = {
nodeId: string;
@@ -357,6 +359,7 @@ export const NodeGenericInputField: FC<{
return (
<NodeCredentialsInput
selfKey={propKey}
schema={propSchema as BlockIOCredentialsSubSchema}
value={currentValue}
errors={errors}
className={className}
@@ -697,15 +700,19 @@ const NodeOneOfDiscriminatorField: FC<{
const NodeCredentialsInput: FC<{
selfKey: string;
schema: BlockIOCredentialsSubSchema;
value: any;
errors: { [key: string]: string | undefined };
handleInputChange: NodeObjectInputTreeProps["handleInputChange"];
className?: string;
}> = ({ selfKey, value, errors, handleInputChange, className }) => {
}> = ({ selfKey, schema, value, errors, handleInputChange, className }) => {
const nodeInputValues = useNodesData<Node<CustomNodeData>>(useNodeId()!)?.data
.hardcodedValues;
return (
<div className={cn("flex flex-col", className)}>
<CredentialsInput
selfKey={selfKey}
schema={schema}
siblingInputs={nodeInputValues}
onSelectCredentials={(credsMeta) =>
handleInputChange(selfKey, credsMeta)
}

View File

@@ -1,15 +1,14 @@
import { useContext } from "react";
import { CustomNodeData } from "@/components/CustomNode";
import {
BlockIOCredentialsSubSchema,
CredentialsProviderName,
} from "@/lib/autogpt-server-api";
import { Node, useNodeId, useNodesData } from "@xyflow/react";
import { getValue } from "@/lib/utils";
import {
CredentialsProviderData,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { getValue } from "@/lib/utils";
import {
BlockIOCredentialsSubSchema,
CredentialsProviderName,
} from "@/lib/autogpt-server-api";
export type CredentialsData =
| {
@@ -29,59 +28,53 @@ export type CredentialsData =
});
export default function useCredentials(
inputFieldName: string,
credsInputSchema: BlockIOCredentialsSubSchema,
nodeInputValues?: Record<string, any>,
): CredentialsData | null {
const nodeId = useNodeId();
const allProviders = useContext(CredentialsProvidersContext);
if (!nodeId) {
throw new Error("useCredentials must be within a CustomNode");
}
const data = useNodesData<Node<CustomNodeData>>(nodeId)!.data;
const credentialsSchema = data.inputSchema.properties[
inputFieldName
] as BlockIOCredentialsSubSchema;
const discriminatorValue: CredentialsProviderName | null =
(credentialsSchema.discriminator &&
credentialsSchema.discriminator_mapping![
getValue(credentialsSchema.discriminator, data.hardcodedValues)
(credsInputSchema.discriminator &&
credsInputSchema.discriminator_mapping![
getValue(credsInputSchema.discriminator, nodeInputValues)
]) ||
null;
let providerName: CredentialsProviderName;
if (credentialsSchema.credentials_provider.length > 1) {
if (!credentialsSchema.discriminator) {
if (credsInputSchema.credentials_provider.length > 1) {
if (!credsInputSchema.discriminator) {
throw new Error(
"Multi-provider credential input requires discriminator!",
);
}
if (!discriminatorValue) {
console.log(
`Missing discriminator value from '${credsInputSchema.discriminator}': ` +
"hiding credentials input until it is set.",
);
return null;
}
providerName = discriminatorValue;
} else {
providerName = credentialsSchema.credentials_provider[0];
providerName = credsInputSchema.credentials_provider[0];
}
const provider = allProviders ? allProviders[providerName] : null;
// If block input schema doesn't have credentials, return null
if (!credentialsSchema) {
if (!credsInputSchema) {
return null;
}
const supportsApiKey =
credentialsSchema.credentials_types.includes("api_key");
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
const supportsApiKey = credsInputSchema.credentials_types.includes("api_key");
const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
const supportsUserPassword =
credentialsSchema.credentials_types.includes("user_password");
credsInputSchema.credentials_types.includes("user_password");
// No provider means maybe it's still loading
if (!provider) {
// return {
// provider: credentialsSchema.credentials_provider,
// schema: credentialsSchema,
// provider: credsInputSchema.credentials_provider,
// schema: credsInputSchema,
// supportsApiKey,
// supportsOAuth2,
// isLoading: true,
@@ -90,24 +83,23 @@ export default function useCredentials(
}
// Filter by OAuth credentials that have sufficient scopes for this block
const requiredScopes = credentialsSchema.credentials_scopes;
const savedOAuthCredentials = requiredScopes
? provider.savedOAuthCredentials.filter((c) =>
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
const requiredScopes = credsInputSchema.credentials_scopes;
const savedCredentials = requiredScopes
? provider.savedCredentials.filter(
(c) =>
c.type != "oauth2" ||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
)
: provider.savedOAuthCredentials;
const savedUserPasswordCredentials = provider.savedUserPasswordCredentials;
: provider.savedCredentials;
return {
...provider,
provider: providerName,
schema: credentialsSchema,
schema: credsInputSchema,
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
savedOAuthCredentials,
savedUserPasswordCredentials,
savedCredentials,
isLoading: false,
};
}

View File

@@ -53,6 +53,7 @@ import {
UserOnboarding,
ReviewSubmissionRequest,
SubmissionStatus,
CredentialsMetaInput,
} from "./types";
import { createBrowserClient } from "@supabase/ssr";
import getServerSupabase from "../supabase/getServerSupabase";
@@ -251,9 +252,13 @@ export default class BackendAPI {
executeGraph(
id: GraphID,
version: number,
inputData: { [key: string]: any } = {},
inputs: { [key: string]: any } = {},
credentials_inputs: { [key: string]: CredentialsMetaInput } = {},
): Promise<{ graph_exec_id: GraphExecutionID }> {
return this._request("POST", `/graphs/${id}/execute/${version}`, inputData);
return this._request("POST", `/graphs/${id}/execute/${version}`, {
inputs,
credentials_inputs,
});
}
getExecutions(): Promise<GraphExecutionMeta[]> {

View File

@@ -282,6 +282,11 @@ export type GraphMeta = {
description: string;
input_schema: GraphIOSchema;
output_schema: GraphIOSchema;
credentials_input_schema: {
type: "object";
properties: { [key: string]: BlockIOCredentialsSubSchema };
required: (keyof GraphMeta["credentials_input_schema"]["properties"])[];
};
};
export type GraphID = Brand<string, "GraphID">;
@@ -317,6 +322,7 @@ export type GraphUpdateable = Omit<
| "links"
| "input_schema"
| "output_schema"
| "credentials_input_schema"
| "has_webhook_trigger"
> & {
version?: number;