mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-18 01:11:20 -05:00
Apply black
This commit is contained in:
@@ -7,21 +7,30 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
def get_placeholder_loop(placeholder_string, embedder, use_bert):
    """Interactively prompt for a replacement placeholder string.

    Keeps asking the user for a new placeholder until one is entered that
    the active tokenizer maps to a single usable token.

    Args:
        placeholder_string: The placeholder that was already taken, shown
            in the first prompt.
        embedder: Model wrapper exposing either ``tknz_fn`` (BERT) or
            ``tokenizer`` (CLIP), depending on ``use_bert``.
        use_bert: When True, tokenize with ``get_bert_token_for_string``;
            otherwise with ``get_clip_token_for_string``.

    Returns:
        A ``(new_placeholder, token)`` tuple once a valid string is given.
    """
    new_placeholder = None

    while True:
        # The first prompt mentions the clashing placeholder; subsequent
        # prompts mention the user's last (rejected) attempt.
        if new_placeholder is None:
            prompt = f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
        else:
            prompt = f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
        new_placeholder = input(prompt)

        if use_bert:
            token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
        else:
            token = get_clip_token_for_string(embedder.tokenizer, new_placeholder)

        # A non-None token means the string tokenizes cleanly; stop looping.
        if token is not None:
            return new_placeholder, token
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(
|
||||
string,
|
||||
@@ -30,7 +39,7 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt"
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
tokens = batch_encoding["input_ids"]
|
||||
@@ -40,6 +49,7 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
if torch.count_nonzero(token) == 3:
|
||||
@@ -49,22 +59,17 @@ def get_bert_token_for_string(tokenizer, string):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=str,
|
||||
default='.',
|
||||
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
|
||||
default=".",
|
||||
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manager_ckpts",
|
||||
type=str,
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Paths to a set of embedding managers to be merged."
|
||||
"--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -75,13 +80,14 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-sd", "--use_bert",
|
||||
"-sd",
|
||||
"--use_bert",
|
||||
action="store_true",
|
||||
help="Flag to denote that we are not merging stable diffusion embeddings"
|
||||
help="Flag to denote that we are not merging stable diffusion embeddings",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
Globals.root=args.root_dir
|
||||
Globals.root = args.root_dir
|
||||
|
||||
if args.use_bert:
|
||||
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
|
||||
|
||||
Reference in New Issue
Block a user