Fix type errors in sdxl_lora_conversion_utils.py

This commit is contained in:
Ryan Dick
2024-09-04 14:35:38 +00:00
committed by Kent Keirsey
parent fc380f077f
commit d0d91eaeec

View File

@@ -1,10 +1,10 @@
import bisect
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, TypeVar
import torch
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
@@ -31,7 +31,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
@@ -66,7 +66,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
@@ -124,7 +124,7 @@ def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet: