diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index ee6cd2e4b0..e52d04b4fe 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -366,39 +366,8 @@ class Graph(BaseGraph): @computed_field @property def credentials_input_schema(self) -> dict[str, Any]: - schema = self._credentials_input_schema.jsonschema() - - # Determine which credential fields are required based on credentials_optional metadata graph_credentials_inputs = self.aggregate_credentials_inputs() - required_fields = [] - # Build a map of node_id -> node for quick lookup - all_nodes = {node.id: node for node in self.nodes} - for sub_graph in self.sub_graphs: - for node in sub_graph.nodes: - all_nodes[node.id] = node - - for field_key, ( - _field_info, - node_field_pairs, - ) in graph_credentials_inputs.items(): - # A field is required if ANY node using it has credentials_optional=False - is_required = False - for node_id, _field_name in node_field_pairs: - node = all_nodes.get(node_id) - if node and not node.credentials_optional: - is_required = True - break - - if is_required: - required_fields.append(field_key) - - schema["required"] = required_fields - return schema - - @property - def _credentials_input_schema(self) -> type[BlockSchema]: - graph_credentials_inputs = self.aggregate_credentials_inputs() logger.debug( f"Combined credentials input fields for graph #{self.id} ({self.name}): " f"{graph_credentials_inputs}" @@ -406,8 +375,8 @@ class Graph(BaseGraph): # Warn if same-provider credentials inputs can't be combined (= bad UX) graph_cred_fields = list(graph_credentials_inputs.values()) - for i, (field, keys) in enumerate(graph_cred_fields): - for other_field, other_keys in list(graph_cred_fields)[i + 1 :]: + for i, (field, keys, _) in enumerate(graph_cred_fields): + for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]: if field.provider != other_field.provider: 
continue if ProviderName.HTTP in field.provider: @@ -423,12 +392,23 @@ class Graph(BaseGraph): f"keys: {keys} <> {other_keys}." ) + # Build the Pydantic model for the credentials input schema fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = { agg_field_key: ( - CredentialsMetaInput[ - Literal[tuple(field_info.provider)], # type: ignore - Literal[tuple(field_info.supported_types)], # type: ignore - ], + ( + CMI + if ( + ( + CMI := CredentialsMetaInput[ + Literal[tuple(field_info.provider)], # type: ignore + Literal[tuple(field_info.supported_types)], # type: ignore + ] + ) + or True + ) + and is_required + else CMI | None + ), CredentialsField( required_scopes=set(field_info.required_scopes or []), discriminator=field_info.discriminator, @@ -436,18 +416,22 @@ discriminator_values=field_info.discriminator_values, ), ) - for agg_field_key, (field_info, _) in graph_credentials_inputs.items() + for agg_field_key, ( + field_info, + _, + is_required, + ) in graph_credentials_inputs.items() } return create_model( self.name.replace(" ", "") + "CredentialsInputSchema", __base__=BlockSchema, **fields, # type: ignore - ) + ).jsonschema() def aggregate_credentials_inputs( self, - ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]: + ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]: """ Returns: dict[aggregated_field_key, tuple( @@ -455,13 +439,19 @@ class Graph(BaseGraph): (now includes discriminator_values from matching nodes) set[(node_id, field_name)]: Node credentials fields that are compatible with this aggregated field spec + bool: True if the field is required (any node has credentials_optional=False) )] """ # First collect all credential field data with input defaults - node_credential_data = [] + # Track (field_info, (node_id, field_name)) for each credential field; required-ness is tracked per node in node_required_map + node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = [] + node_required_map:
dict[str, bool] = {} # node_id -> is_required for graph in [self] + self.sub_graphs: for node in graph.nodes: + # Track if this node requires credentials (credentials_optional=False means required) + node_required_map[node.id] = not node.credentials_optional + for ( field_name, field_info, @@ -485,7 +475,21 @@ class Graph(BaseGraph): ) # Combine credential field info (this will merge discriminator_values automatically) - return CredentialsFieldInfo.combine(*node_credential_data) + combined = CredentialsFieldInfo.combine(*node_credential_data) + + # Add is_required flag to each aggregated field + # A field is required if ANY node using it has credentials_optional=False + return { + key: ( + field_info, + node_field_pairs, + any( + node_required_map.get(node_id, True) + for node_id, _ in node_field_pairs + ), + ) + for key, (field_info, node_field_pairs) in combined.items() + } class GraphModel(Graph): @@ -832,16 +836,55 @@ class GraphModel(Graph): ) -class GraphMeta(Graph): - user_id: str +class GraphMeta(BaseModel): + """ + Graph metadata without nodes/links — used for list endpoints. - # Easy work-around to prevent exposing nodes and links in the API response - nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore - links: list[Link] = Field(default=[], exclude=True) + This is a flat, lightweight model (not inheriting from Graph) to avoid recomputing + expensive computed fields. Values are copied from GraphModel. 
+ """ + + id: str + version: int = 1 + user_id: str + is_active: bool = True + name: str + description: str + instructions: str | None = None + recommended_schedule_cron: str | None = None + forked_from_id: str | None = None + forked_from_version: int | None = None + + input_schema: dict[str, Any] = {} + output_schema: dict[str, Any] = {} + credentials_input_schema: dict[str, Any] = {} + has_external_trigger: bool = False + has_human_in_the_loop: bool = False + has_sensitive_action: bool = False + trigger_setup_info: Optional["GraphTriggerInfo"] = None @staticmethod - def from_graph(graph: GraphModel) -> "GraphMeta": - return GraphMeta(**graph.model_dump()) + def from_graph(graph: "GraphModel") -> "GraphMeta": + return GraphMeta( + id=graph.id, + version=graph.version, + user_id=graph.user_id, + is_active=graph.is_active, + name=graph.name, + description=graph.description, + instructions=graph.instructions, + recommended_schedule_cron=graph.recommended_schedule_cron, + forked_from_id=graph.forked_from_id, + forked_from_version=graph.forked_from_version, + # Pass pre-computed values for expensive fields + input_schema=graph.input_schema, + output_schema=graph.output_schema, + has_external_trigger=graph.has_external_trigger, + has_human_in_the_loop=graph.has_human_in_the_loop, + has_sensitive_action=graph.has_sensitive_action, + trigger_setup_info=graph.trigger_setup_info, + credentials_input_schema=graph.credentials_input_schema, + ) class GraphsPaginated(BaseModel): @@ -920,9 +963,9 @@ async def list_graphs_paginated( graph_models: list[GraphMeta] = [] for graph in graphs: try: + # GraphMeta.from_graph() accesses all computed fields on the GraphModel, + # which validates that the graph is well formed (e.g. no unknown block_ids). 
graph_meta = GraphModel.from_db(graph).meta() - # Trigger serialization to validate that the graph is well formed - graph_meta.model_dump() graph_models.append(graph_meta) except Exception as e: logger.error(f"Error processing graph {graph.id}: {e}")