diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index a1d8bbed0a..04f9949560 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,10 +5,10 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Generator, Iterator, List, Optional, Tuple, Type, Union import torch -from diffusers import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from invokeai.app.shared.models import FreeUConfig @@ -146,7 +146,7 @@ class ModelPatcher: cls, text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], clip_skip: int, - ) -> None: + ) -> Generator[None, Any, Any]: skipped_layers = [] try: for _i in range(clip_skip): @@ -164,7 +164,7 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> None: + ) -> Generator[None, Any, Any]: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?