mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix extracting loras from legacy blends (#3161)
This commit is contained in:
@@ -12,7 +12,8 @@ from typing import Union, Optional, Any
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
|
||||
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
|
||||
Conjunction
|
||||
from .devices import torch_dtype
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.invoke.globals import Globals
|
||||
@@ -55,22 +56,25 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_prompt: FlattenedPrompt|Blend
|
||||
lora_conditions = None
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
should_use_lora_manager = True
|
||||
lora_weights = positive_conjunction.lora_weights
|
||||
if model.peft_manager:
|
||||
should_use_lora_manager = model.peft_manager.should_use(lora_weights)
|
||||
if not should_use_lora_manager:
|
||||
model.peft_manager.set_loras(lora_weights)
|
||||
if model.lora_manager and should_use_lora_manager:
|
||||
lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
should_use_lora_manager = True
|
||||
lora_weights = positive_conjunction.lora_weights
|
||||
lora_conditions = None
|
||||
if model.peft_manager:
|
||||
should_use_lora_manager = model.peft_manager.should_use(lora_weights)
|
||||
if not should_use_lora_manager:
|
||||
model.peft_manager.set_loras(lora_weights)
|
||||
if model.lora_manager and should_use_lora_manager:
|
||||
lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
@@ -93,12 +97,12 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
|
||||
Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_prompt: FlattenedPrompt|Blend
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
@@ -217,18 +221,26 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
|
||||
print(f'{discarded}\x1b[0m')
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]:
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
loras = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
if len(x.lora_weights)>0:
|
||||
loras.extend(x.lora_weights)
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)],
|
||||
lora_weights = loras)
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
|
||||
Reference in New Issue
Block a user