General cleanup/documentation.

This commit is contained in:
Ryan Dick
2024-09-09 21:52:05 +00:00
parent 940269e60a
commit 8f3c09348d
4 changed files with 8 additions and 4 deletions

View File

@@ -202,6 +202,7 @@ def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str,
"""Groups the keys in the state dict by layer."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
for key in state_dict:
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])

View File

@@ -23,10 +23,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
for k in state_dict.keys():
if not re.match(FLUX_KOHYA_KEY_REGEX, k):
return False
return True
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:

View File

@@ -1,3 +1,6 @@
# A sample state dict in the Diffusers FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/200255/hands-xl-sd-15-flux1-dev?modelVersionId=781855
state_dict_keys = [
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight",

View File

@@ -1,3 +1,6 @@
# A sample state dict in the Kohya FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/159333/pokemon-trainer-sprite-pixelart?modelVersionId=779247
state_dict_keys = [
"lora_unet_double_blocks_0_img_attn_proj.alpha",
"lora_unet_double_blocks_0_img_attn_proj.lora_down.weight",