fix extracting loras from legacy blends (#3161)

This commit is contained in:
Lincoln Stein
2023-04-09 16:43:35 -04:00
committed by GitHub

View File

@@ -12,7 +12,8 @@ from typing import Union, Optional, Any
from transformers import CLIPTokenizer
from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
Conjunction
from .devices import torch_dtype
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals
@@ -55,22 +56,25 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
lora_conditions = None
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_prompt = legacy_blend
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
should_use_lora_manager = True
lora_weights = positive_conjunction.lora_weights
if model.peft_manager:
should_use_lora_manager = model.peft_manager.should_use(lora_weights)
if not should_use_lora_manager:
model.peft_manager.set_loras(lora_weights)
if model.lora_manager and should_use_lora_manager:
lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
positive_prompt = positive_conjunction.prompts[0]
should_use_lora_manager = True
lora_weights = positive_conjunction.lora_weights
lora_conditions = None
if model.peft_manager:
should_use_lora_manager = model.peft_manager.should_use(lora_weights)
if not should_use_lora_manager:
model.peft_manager.set_loras(lora_weights)
if model.lora_manager and should_use_lora_manager:
lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
@@ -93,12 +97,12 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
Union[FlattenedPrompt, Blend], FlattenedPrompt):
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_prompt = legacy_blend
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
@@ -217,18 +221,26 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
print(f'{discarded}\x1b[0m')
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]:
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
flattened_prompts = []
weights = []
loras = []
for i, x in enumerate(parsed_conjunctions):
if len(x.prompts)>0:
flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
if len(x.lora_weights)>0:
loras.extend(x.lora_weights)
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)],
lora_weights = loras)
def split_weighted_subprompts(text, skip_normalize=False)->list: