mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-02 20:55:09 -05:00
Fix type errors in sdxl_lora_conversion_utils.py
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import bisect
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, TypeVar
|
||||
|
||||
import torch
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
@@ -31,7 +31,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
@@ -66,7 +66,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
@@ -124,7 +124,7 @@ def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
|
||||
Reference in New Issue
Block a user