mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 05:15:00 -05:00
327 lines
14 KiB
Python
327 lines
14 KiB
Python
import pyparsing
|
|
import pyparsing as pp
|
|
from pyparsing import original_text_for
|
|
|
|
|
|
class Prompt():
    """
    Ordered container for the fragments that make up one prompt.

    :param parts: The child nodes. Every element must be an instance of
        BaseFragment (or of one of its subclasses). The object is stored
        as given, it is not copied.
    :raises PromptParser.ParsingException: if any element is not a BaseFragment.
    """

    def __init__(self, parts: list):
        for c in parts:
            if not isinstance(c, BaseFragment):
                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
        self.children = parts

    def __repr__(self):
        return f"Prompt:{self.children}"

    def __eq__(self, other):
        # exact type match: a FlattenedPrompt with the same children is not equal
        return type(other) is Prompt and other.children == self.children
|
|
|
|
class FlattenedPrompt():
    """
    A prompt whose children are a flat list of fragment objects.

    :param parts: The children. Each element is either a BaseFragment
        instance, which is kept as-is, or a tuple whose first item is a str
        and whose second item is a float or int, which is converted to a
        Fragment.
    :raises PromptParser.ParsingException: if an element is neither of those.
    """

    def __init__(self, parts: list):
        def reject(part):
            # single place for the error so both failure paths report identically
            raise PromptParser.ParsingException(
                f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")

        # verify type correctness
        parts_converted = []
        for part in parts:
            if isinstance(part, BaseFragment):
                parts_converted.append(part)
            elif type(part) is tuple:
                # upgrade tuples to Fragments; a tuple that is too short is
                # reported as a parsing error rather than a bare IndexError
                if len(part) < 2 \
                        or type(part[0]) is not str \
                        or (type(part[1]) is not float and type(part[1]) is not int):
                    reject(part)
                parts_converted.append(Fragment(part[0], part[1]))
            else:
                reject(part)
        # all looks good
        self.children = parts_converted

    def __repr__(self):
        return f"FlattenedPrompt:{self.children}"

    def __eq__(self, other):
        return type(other) is FlattenedPrompt and other.children == self.children
|
|
|
|
# abstract base class for Fragments
class BaseFragment:
    """
    Common base type for prompt fragments.

    It defines no behaviour of its own and is only used for type checks;
    nothing here prevents it from being instantiated directly.
    """
    pass
|
|
|
|
class Fragment(BaseFragment):
    """
    A piece of prompt text together with a weight.

    :param text: The fragment text; must be exactly a str.
    :param weight: The weight for the text. It is converted to float on
        construction, so int weights compare equal to their float form.
    """

    def __init__(self, text: str, weight: float=1):
        assert(type(text) is str)
        self.text = text
        self.weight = float(weight)

    def __repr__(self):
        return f"Fragment:'{self.text}'@{self.weight}"

    def __eq__(self, other):
        return type(other) is Fragment \
            and other.text == self.text \
            and other.weight == self.weight
|
|
|
|
class CrossAttentionControlledFragment(BaseFragment):
    """
    Marker base class for the cross-attention-control fragment types.

    It adds nothing to BaseFragment; other code uses it in type checks to
    tell these fragments apart from plain text fragments.
    """
    pass
|
|
|
|
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
    """
    Records that `original` is to be swapped for `edited`.

    :param original: The content being replaced.
    :param edited: The replacement content.

    Both arguments are annotated as Fragment but are stored exactly as
    passed in; no validation or conversion happens here.
    """

    def __init__(self, original: Fragment, edited: Fragment):
        self.original = original
        self.edited = edited

    def __repr__(self):
        return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"

    def __eq__(self, other):
        return type(other) is CrossAttentionControlSubstitute \
            and other.original == self.original \
            and other.edited == self.edited
|
|
|
|
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
    """
    Cross-attention-control node wrapping a single fragment.

    :param fragment: The wrapped fragment, stored as given.
    """

    def __init__(self, fragment: Fragment):
        self.fragment = fragment

    def __repr__(self):
        # __repr__ must return a str; returning a (label, fragment) tuple makes
        # repr() raise TypeError, so the fragment is formatted into the label.
        return f"CrossAttentionControlAppend:{self.fragment}"

    def __eq__(self, other):
        return type(other) is CrossAttentionControlAppend \
            and other.fragment == self.fragment
|
|
|
|
|
|
|
|
class Conjunction():
    """
    A list of prompts with one weight each; `type` is always 'AND'.

    :param prompts: The members. Elements that are exactly a Prompt, Blend or
        FlattenedPrompt are kept; anything else is passed to Prompt() to be
        wrapped.
    :param weights: One weight per prompt. When None, every prompt gets 1.0.
        Any other iterable is copied into a list.
    :raises PromptParser.ParsingException: if the number of weights differs
        from the number of prompts.
    """

    def __init__(self, prompts: list, weights: list = None):
        # force everything to be a Prompt
        self.prompts = [x if (type(x) is Prompt
                              or type(x) is Blend
                              or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
        if len(self.weights) != len(self.prompts):
            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
        self.type = 'AND'

    def __repr__(self):
        return f"Conjunction:{self.prompts} | weights {self.weights}"

    def __eq__(self, other):
        return type(other) is Conjunction \
            and other.prompts == self.prompts \
            and other.weights == self.weights
|
|
|
|
|
|
class Blend():
    """
    Several prompts combined with one weight each.

    :param prompts: The prompts to combine. Each element must be exactly a
        Prompt or a FlattenedPrompt.
    :param weights: One weight per prompt, in the same order.
    :param normalize_weights: Stored on the instance unchanged; this class
        does not interpret it.
    :raises PromptParser.ParsingException: if the counts of prompts and
        weights differ, or if an element of `prompts` has the wrong type.
    """

    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
        if len(prompts) != len(weights):
            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
        for c in prompts:
            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
        # Every element is already a Prompt/FlattenedPrompt at this point, so
        # no upcasting is needed; the containers are stored as given.
        self.prompts = prompts
        self.weights = weights
        self.normalize_weights = normalize_weights

    def __repr__(self):
        return f"Blend:{self.prompts} | weights {self.weights}"

    def __eq__(self, other):
        # Only another Blend can be equal. The comparison itself still goes
        # through the textual form, so prompt/weight containers that print
        # identically are treated as equal; normalize_weights is not part of
        # the repr and therefore not part of equality.
        return type(other) is Blend and repr(other) == repr(self)
|
|
|
|
|
|
class PromptParser():
    """
    Turns a prompt string into a Conjunction of FlattenedPrompt and Blend
    objects, using a pyparsing grammar that is built once per instance.

    :param attention_plus_base: Stored on the instance; not read anywhere in
        this class.
    :param attention_minus_base: Stored on the instance; not read anywhere in
        this class.
    """

    class ParsingException(Exception):
        """Raised when a prompt, or a node built while parsing it, is invalid."""
        pass

    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
        self.attention_plus_base = attention_plus_base
        self.attention_minus_base = attention_minus_base

        self.root = self.build_parser_logic()

    def parse(self, prompt: str) -> Conjunction:
        '''
        :param prompt: The prompt string to parse
        :return: a Conjunction whose members have been flattened. An empty or
            whitespace-only prompt yields a Conjunction holding one
            FlattenedPrompt with a single empty fragment of weight 1.0.
        '''
        if not prompt.strip():
            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])

        root = self.root.parse_string(prompt)
        return self.flatten(root[0])

    def flatten(self, root: Conjunction):
        '''
        Rebuild a Conjunction so that every Prompt in it becomes a
        FlattenedPrompt and every Blend contains only flattened prompts.

        :param root: The Conjunction to flatten.
        :return: a new Conjunction with the flattened members and root's weights.
        :raises PromptParser.ParsingException: if a node of an unhandled type
            is encountered.
        '''

        def fuse_fragments(items):
            # Join each fragment onto the previous one (separated by a space)
            # when both have the same weight. Cross-attention-control fragments
            # are never joined and also stop the fragment after them from
            # joining backwards across them.
            result = []
            for x in items:
                if isinstance(x, CrossAttentionControlledFragment):
                    result.append(x)
                    continue
                last_weight = None
                if result and not isinstance(result[-1], CrossAttentionControlledFragment):
                    last_weight = result[-1].weight
                this_text = x.text
                this_weight = x.weight
                if last_weight is not None and last_weight == this_weight:
                    result[-1] = Fragment(result[-1].text + ' ' + this_text, last_weight)
                else:
                    result.append(x)
            return result

        def flatten_internal(node, weight_scale, results):
            # weight_scale is threaded through the recursion but is not applied
            # to any node by the current code.
            # TODO: attention nodes are not handled here yet.
            if isinstance(node, pp.ParseResults):
                for x in node:
                    results = flatten_internal(x, weight_scale, results)
            elif isinstance(node, BaseFragment):
                results.append(node)
            elif type(node) is Blend:
                flattened_subprompts = []
                for prompt in node.prompts:
                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts)
                # keep the blend's normalize_weights flag instead of silently
                # resetting it to the constructor default
                results.append(Blend(prompts=flattened_subprompts,
                                     weights=node.weights,
                                     normalize_weights=node.normalize_weights))
            elif type(node) is Prompt:
                flattened_prompt = []
                for child in node.children:
                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt)
                results.append(FlattenedPrompt(parts=fuse_fragments(flattened_prompt)))
            else:
                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
            return results

        flattened_parts = []
        for part in root.prompts:
            flattened_parts += flatten_internal(part, 1.0, [])
        return Conjunction(flattened_parts, root.weights)

    def build_parser_logic(self):
        '''
        Build the pyparsing grammar for prompts.

        Recognised forms:
          * plain words, each becoming a Fragment
          * (original).swap(edited) -> CrossAttentionControlSubstitute
          * ("promptA", "promptB").blend(a, b) -> Blend
          * ("promptA", "promptB").and(a, b) -> Conjunction (weights optional)
          * one or more blends/prompts in sequence -> implicit Conjunction

        :return: the top-level parser element.
        '''
        lparen = pp.Literal("(").suppress()
        rparen = pp.Literal(")").suppress()
        # accepts int or float notation, always maps to float
        number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))

        prompt_part = pp.Forward()
        word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join(x)))
        word.set_name("word")
        word.set_debug(False)

        original_words = (
            (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
            (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
        ).set_name('original_words')
        edited_words = (
            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
            pp.CharsNotIn(')').set_name('termB').set_debug(False)
        ).set_name('edited_words')
        cross_attention_substitute = original_words + \
                                     pp.Literal(".swap").suppress() + \
                                     lparen + edited_words + rparen
        cross_attention_substitute.set_name('cross_attention_substitute')

        def make_cross_attention_substitute(x):
            # x[0] / x[1] are the matched original and edited tokens
            return CrossAttentionControlSubstitute(x[0], x[1])

        cross_attention_substitute.set_parse_action(make_cross_attention_substitute)

        # simple fragments of text
        # TODO: an attention alternative is not part of the grammar yet
        prompt_part <<= (cross_attention_substitute
                         | word
                         )
        prompt_part.set_debug(False)
        prompt_part.set_name("prompt_part")

        # root prompt definition
        prompt = pp.Group(pp.OneOrMore(prompt_part))\
            .set_parse_action(lambda x: Prompt(x[0]))

        # weighted blend of prompts
        # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
        # int weights.
        # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)

        def make_prompt_from_quoted_string(x):
            # strip the surrounding double quotes before parsing the contents
            x_unquoted = x[0][1:-1]
            if not x_unquoted.strip():
                return Prompt([Fragment('')])
            x_parsed = prompt.parse_string(x_unquoted)
            return x_parsed[0]

        # Work on a copy: pp.dbl_quoted_string is a module-level object shared
        # by every user of pyparsing, so attaching a parse action or name to it
        # directly would leak into all other grammars (and other PromptParser
        # instances).
        quoted_prompt = pp.dbl_quoted_string.copy().set_parse_action(make_prompt_from_quoted_string)
        quoted_prompt.set_name('quoted_prompt')

        blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
        blend_weights = pp.delimited_list(number).set_name('blend_weights')
        blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
                         + pp.Literal(".blend").suppress()
                         + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
        blend.set_debug(False)

        blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))

        conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
        conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
        conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
                                                      + pp.Literal(".and").suppress()
                                                      + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')

        def make_conjunction(x):
            parts_raw = x[0][0]
            # the weights group is optional; default every part to 1.0
            weights = x[0][1] if len(x[0]) > 1 else [1.0]*len(parts_raw)
            return Conjunction(list(parts_raw), weights)

        conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)

        implicit_conjunction = pp.OneOrMore(blend | prompt)
        implicit_conjunction.set_parse_action(lambda x: Conjunction(x))

        conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
        conjunction.set_debug(False)

        # top-level is a conjunction of one or more blends or prompts
        return conjunction
|