From 25f8ab24aaa7aa59f74d340e9742f19e0aae7626 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 8 Oct 2025 17:19:03 +1100 Subject: [PATCH] tests: fix test for breaking pydantic v2.12 change Fixes a test failure introduced by https://github.com/pydantic/pydantic/pull/11957 TL;DR: "after" model validators should be instance methods, not class methods. Batch model updated to use an instance method, which fixes the failing test. --- .../session_queue/session_queue_common.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index ac6993ffb1..e912753f42 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -1,7 +1,7 @@ import datetime import json from itertools import chain, product -from typing import Generator, Literal, Optional, TypeAlias, Union, cast +from typing import Generator, Literal, Optional, TypeAlias, Union from pydantic import ( AliasChoices, @@ -15,7 +15,6 @@ from pydantic import ( ) from pydantic_core import to_jsonable_python -from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.fields import ImageField from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError from invokeai.app.services.workflow_records.workflow_records_common import ( @@ -137,20 +136,18 @@ class Batch(BaseModel): return v @model_validator(mode="after") - def validate_batch_nodes_and_edges(cls, values): - batch_data_collection = cast(Optional[BatchDataCollection], values.data) - if batch_data_collection is None: - return values - graph = cast(Graph, values.graph) - for batch_data_list in batch_data_collection: + def validate_batch_nodes_and_edges(self): + if self.data is None: + return self + for batch_data_list in self.data: for batch_data in batch_data_list: try: - node = cast(BaseInvocation, graph.get_node(batch_data.node_path)) + node = self.graph.get_node(batch_data.node_path) except NodeNotFoundError: raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph") if batch_data.field_name not in type(node).model_fields: raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}") - return values + return self @field_validator("graph") def validate_graph(cls, v: Graph):