From 2f68a1a76cdd8367f75b94749ba4dcafe4c6360d Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 9 Aug 2023 09:21:29 -0400
Subject: [PATCH] use Stalker's simplified LoRA vector-length detection code

---
 invokeai/backend/model_management/util.py | 73 ++++-------------------
 1 file changed, 12 insertions(+), 61 deletions(-)

diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py
index ece9c96d4c..f435ab79b6 100644
--- a/invokeai/backend/model_management/util.py
+++ b/invokeai/backend/model_management/util.py
@@ -9,15 +9,11 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     :param checkpoint: The checkpoint
     """
 
-    def _handle_unet_key(key, tensor, checkpoint):
+    def _get_shape_1(key, tensor, checkpoint):
         lora_token_vector_length = None
-        if "_attn2_to_k." not in key and "_attn2_to_v." not in key:
-            return lora_token_vector_length
 
         # check lora/locon
-        if ".lora_up.weight" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".lora_down.weight" in key:
+        if ".lora_down.weight" in key:
             lora_token_vector_length = tensor.shape[1]
 
         # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
@@ -49,65 +45,20 @@ def lora_token_vector_length(checkpoint: dict) -> int:
         return lora_token_vector_length
 
-    def _handle_te_key(key, tensor, checkpoint):
-        lora_token_vector_length = None
-        if "text_model_encoder_layers_" not in key:
-            return lora_token_vector_length
-
-        # skip detect by mlp
-        if "_self_attn_" not in key:
-            return lora_token_vector_length
-
-        # check lora/locon
-        if ".lora_up.weight" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".lora_down.weight" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
-        elif ".hada_w1_a" in key or ".hada_w2_a" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
-        elif ".lokr_" in key:
-            _lokr_key = key.split(".")[0]
-
-            if _lokr_key + ".lokr_w1" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
-            elif _lokr_key + "lokr_w1_b" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
-            else:
-                return lora_token_vector_length  # unknown format
-
-            if _lokr_key + ".lokr_w2" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
-            elif _lokr_key + "lokr_w2_b" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
-            else:
-                return lora_token_vector_length  # unknown format
-
-            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
-
-        elif ".diff" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        return lora_token_vector_length
-
     lora_token_vector_length = None
     lora_te1_length = None
     lora_te2_length = None
     for key, tensor in checkpoint.items():
-        if key.startswith("lora_unet_"):
-            lora_token_vector_length = _handle_unet_key(key, tensor, checkpoint)
-        elif key.startswith("lora_te_"):
-            lora_token_vector_length = _handle_te_key(key, tensor, checkpoint)
-
-        elif key.startswith("lora_te1_"):
-            lora_te1_length = _handle_te_key(key, tensor, checkpoint)
-        elif key.startswith("lora_te2_"):
-            lora_te2_length = _handle_te_key(key, tensor, checkpoint)
+        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_te") and "_self_attn_" in key:
+            tmp_length = _get_shape_1(key, tensor, checkpoint)
+            if key.startswith("lora_te_"):
+                lora_token_vector_length = tmp_length
+            elif key.startswith("lora_te1_"):
+                lora_te1_length = tmp_length
+            elif key.startswith("lora_te2_"):
+                lora_te2_length = tmp_length
 
     if lora_te1_length is not None and lora_te2_length is not None:
         lora_token_vector_length = lora_te1_length + lora_te2_length
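
Note (not part of the patch): a minimal sketch of how the detected token-vector length might be used downstream to guess a LoRA's base model, assuming the usual text-encoder widths (768 for SD-1, 1024 for SD-2, 2048 for SDXL's two encoders combined). The guess_base_model helper and those thresholds are illustrative assumptions, not code from this change.

    from invokeai.backend.model_management.util import lora_token_vector_length

    def guess_base_model(checkpoint: dict) -> str:
        # Map the detected cross-attention context width to a likely base model.
        # Thresholds are assumed CLIP embedding widths, not taken from this patch.
        length = lora_token_vector_length(checkpoint)
        if length == 768:
            return "sd-1"
        elif length == 1024:
            return "sd-2"
        elif length == 2048:
            return "sdxl"
        return "unknown"  # includes the case where detection returned None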