From d0d91eaeec1ea7480d5c39cd5ae0bff44607eb0d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 4 Sep 2024 14:35:38 +0000 Subject: [PATCH] Fix type errors in sdxl_lora_conversion_utils.py --- .../lora/conversions/sdxl_lora_conversion_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py b/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py index eb31b34777..e3780a7e8a 100644 --- a/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py +++ b/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py @@ -1,10 +1,10 @@ import bisect -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, TypeVar -import torch +T = TypeVar("T") -def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]: """Convert the keys of an SDXL LoRA state_dict to diffusers format. The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in @@ -31,7 +31,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) - stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) stability_unet_keys.sort() - new_state_dict = {} + new_state_dict: dict[str, T] = {} for full_key, value in state_dict.items(): if full_key.startswith("lora_unet_"): search_key = full_key.replace("lora_unet_", "") @@ -66,7 +66,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) - # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" - unet_conversion_map_layer = [] + unet_conversion_map_layer: list[tuple[str, str]] = [] for i in range(3): # num_blocks is 3 in sdxl # loop over downblocks/upblocks @@ -124,7 +124,7 @@ def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: ("skip_connection.", "conv_shortcut."), ] - unet_conversion_map = [] + unet_conversion_map: list[tuple[str, str]] = [] for sd, hf in unet_conversion_map_layer: if "resnets" in hf: for sd_res, hf_res in unet_conversion_map_resnet: