fix(backend): Prevent computing credentials_input_schema multiple times

This commit is contained in:
Reinier van der Leer
2026-02-05 22:47:02 +01:00
parent 3ca2387631
commit bb0bc45528

View File

@@ -366,39 +366,8 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
@@ -406,8 +375,8 @@ class Graph(BaseGraph):
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
for i, (field, keys, _) in enumerate(graph_cred_fields):
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
if ProviderName.HTTP in field.provider:
@@ -423,12 +392,23 @@ class Graph(BaseGraph):
f"keys: {keys} <> {other_keys}."
)
# Build the Pydantic model for the credentials input schema
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
(
CMI
if (
(
CMI := CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
]
)
or True
)
and is_required
else CMI | None
),
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
@@ -436,18 +416,22 @@ class Graph(BaseGraph):
discriminator_values=field_info.discriminator_values,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
for agg_field_key, (
field_info,
_,
is_required,
) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
).jsonschema()
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
"""
Returns:
dict[aggregated_field_key, tuple(
@@ -455,13 +439,19 @@ class Graph(BaseGraph):
(now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
bool: True if the field is required (any node has credentials_optional=False)
)]
"""
# First collect all credential field data with input defaults
node_credential_data = []
# Track (field_info, (node_id, field_name), is_required) for each credential field
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
node_required_map: dict[str, bool] = {} # node_id -> is_required
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
# Track if this node requires credentials (credentials_optional=False means required)
node_required_map[node.id] = not node.credentials_optional
for (
field_name,
field_info,
@@ -485,7 +475,21 @@ class Graph(BaseGraph):
)
# Combine credential field info (this will merge discriminator_values automatically)
return CredentialsFieldInfo.combine(*node_credential_data)
combined = CredentialsFieldInfo.combine(*node_credential_data)
# Add is_required flag to each aggregated field
# A field is required if ANY node using it has credentials_optional=False
return {
key: (
field_info,
node_field_pairs,
any(
node_required_map.get(node_id, True)
for node_id, _ in node_field_pairs
),
)
for key, (field_info, node_field_pairs) in combined.items()
}
class GraphModel(Graph):
@@ -832,16 +836,55 @@ class GraphModel(Graph):
)
class GraphMeta(Graph):
user_id: str
class GraphMeta(BaseModel):
"""
Graph metadata without nodes/links — used for list endpoints.
# Easy work-around to prevent exposing nodes and links in the API response
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
links: list[Link] = Field(default=[], exclude=True)
This is a flat, lightweight model (not inheriting from Graph) to avoid recomputing
expensive computed fields. Values are copied from GraphModel.
"""
id: str
version: int = 1
user_id: str
is_active: bool = True
name: str
description: str
instructions: str | None = None
recommended_schedule_cron: str | None = None
forked_from_id: str | None = None
forked_from_version: int | None = None
input_schema: dict[str, Any] = {}
output_schema: dict[str, Any] = {}
credentials_input_schema: dict[str, Any] = {}
has_external_trigger: bool = False
has_human_in_the_loop: bool = False
has_sensitive_action: bool = False
trigger_setup_info: Optional["GraphTriggerInfo"] = None
@staticmethod
def from_graph(graph: GraphModel) -> "GraphMeta":
return GraphMeta(**graph.model_dump())
def from_graph(graph: "GraphModel") -> "GraphMeta":
return GraphMeta(
id=graph.id,
version=graph.version,
user_id=graph.user_id,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
instructions=graph.instructions,
recommended_schedule_cron=graph.recommended_schedule_cron,
forked_from_id=graph.forked_from_id,
forked_from_version=graph.forked_from_version,
# Pass pre-computed values for expensive fields
input_schema=graph.input_schema,
output_schema=graph.output_schema,
has_external_trigger=graph.has_external_trigger,
has_human_in_the_loop=graph.has_human_in_the_loop,
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
credentials_input_schema=graph.credentials_input_schema,
)
class GraphsPaginated(BaseModel):
@@ -920,9 +963,9 @@ async def list_graphs_paginated(
graph_models: list[GraphMeta] = []
for graph in graphs:
try:
# GraphMeta.from_graph() accesses all computed fields on the GraphModel,
# which validates that the graph is well formed (e.g. no unknown block_ids).
graph_meta = GraphModel.from_db(graph).meta()
# Trigger serialization to validate that the graph is well formed
graph_meta.model_dump()
graph_models.append(graph_meta)
except Exception as e:
logger.error(f"Error processing graph {graph.id}: {e}")