tidy(mm): patcher types and import paths

This commit is contained in:
psychedelicious
2025-09-23 18:26:38 +10:00
parent 8a6d5f4f6a
commit 4e2145c6c4

View File

@@ -5,10 +5,10 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Generator, Iterator, List, Optional, Tuple, Type, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
@@ -146,7 +146,7 @@ class ModelPatcher:
cls,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
clip_skip: int,
) -> None:
) -> Generator[None, Any, Any]:
skipped_layers = []
try:
for _i in range(clip_skip):
@@ -164,7 +164,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> None:
) -> Generator[None, Any, Any]:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?