Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 08:07:59 -05:00)
Compare commits: test/node-... → experiment
2 commits

| Author | SHA1 | Date |
|---|---|---|
| | 2634f0e43a | |
| | 704151e8e3 | |
@@ -108,14 +108,15 @@ class CompelInvocation(BaseInvocation):
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -196,14 +197,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -270,14 +272,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
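The three hunks above make the same change in `CompelInvocation` and `SDXLPromptInvocationBase`: instead of appending the bare textual-inversion model to `ti_list`, each entry becomes a `(name, model)` tuple, so the trigger name found in the prompt travels together with the loaded embedding. The sketch below only illustrates that shape; `load_embedding` and the `KeyError` handler are hypothetical stand-ins for the model manager call and `ModelNotFoundException` in the diff.

```python
# Illustrative sketch of the (trigger name, model) tuple shape that ti_list
# takes after this change. load_embedding() is a hypothetical stand-in for the
# model manager lookup used in the real code.
import re
from typing import Any, Callable, List, Tuple

def collect_ti_list(prompt: str, load_embedding: Callable[[str], Any]) -> List[Tuple[str, Any]]:
    ti_list: List[Tuple[str, Any]] = []
    for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
        name = trigger[1:-1]  # strip the surrounding "<" and ">"
        try:
            ti_list.append((name, load_embedding(name)))
        except KeyError:  # stands in for ModelNotFoundException
            pass  # unknown triggers are skipped, as in the diff
    return ti_list

# Example: embeddings keyed by trigger name
fake_store = {"easynegative": object()}
print(collect_ti_list("a photo of <easynegative> style", fake_store.__getitem__))
```

With this shape, downstream code can unpack `for ti_name, ti in ti_list:` as the `ModelPatcher` hunks below do.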
@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
-            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
             loras = [
                 (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                 for lora in self.clip.loras
@@ -75,20 +74,14 @@ class ONNXPromptInvocation(BaseInvocation):
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
-                    ti_list.append(
-                        # stack.enter_context(
-                        #     context.services.model_manager.get_model(
-                        #         model_name=name,
-                        #         base_model=self.clip.text_encoder.base_model,
-                        #         model_type=ModelType.TextualInversion,
-                        #     )
-                        # )
+                    ti_list.append((
+                        name,
                         context.services.model_manager.get_model(
                             model_name=name,
                             base_model=self.clip.text_encoder.base_model,
                             model_type=ModelType.TextualInversion,
                         ).context.model
-                    )
+                    ))
                 except Exception:
                     # print(e)
                     # import traceback
@@ -562,7 +562,7 @@ class ModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         init_tokens_count = None
         new_tokens_added = None
@@ -572,27 +572,27 @@ class ModelPatcher:
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
             model_embeddings = text_encoder.get_input_embeddings()

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i]
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)

                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
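The `ModelPatcher` hunks widen the annotation to `ti_list: List[Tuple[str, Any]]` and unpack each entry as `(ti_name, ti)`, so the trigger token is derived from the tuple's name rather than a `.name` attribute on the model. For a multi-vector embedding, every vector after the first gets a padded trigger of the form `<name-!pad-N>` (note that the pad suffix in the diff interpolates the enclosing loop variable `i`, which equals the `index` argument at every call site). Below is a standalone sketch of that expansion, not the repository's code:

```python
# Illustrative sketch: how a multi-vector textual-inversion embedding expands
# into one trigger token per vector, mirroring _get_trigger(ti_name, index)
# in the hunk above.
from typing import List

def expand_triggers(ti_name: str, num_vectors: int) -> List[str]:
    triggers = []
    for index in range(num_vectors):
        trigger = ti_name if index == 0 else f"{ti_name}-!pad-{index}"
        triggers.append(f"<{trigger}>")
    return triggers

# A 3-vector embedding named "easynegative" contributes three tokens:
print(expand_triggers("easynegative", 3))
# ['<easynegative>', '<easynegative-!pad-1>', '<easynegative-!pad-2>']
```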
@@ -637,7 +637,6 @@ class ModelPatcher:


 class TextualInversionModel:
-    name: str
     embedding: torch.Tensor # [n, 768]|[n, 1280]

     @classmethod
@@ -651,7 +650,6 @@ class TextualInversionModel:
         file_path = Path(file_path)

         result = cls() # TODO:
-        result.name = file_path.stem # TODO:

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
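With the trigger name now carried in the `(name, model)` tuples, `TextualInversionModel` no longer stores a `name` field, and the checkpoint loader no longer records `file_path.stem` on the result. A toy illustration of the resulting split (not the repository's classes): only tensor data lives on the embedding object, while the display/trigger name is supplied externally.

```python
# Toy illustration: the embedding object carries only the tensor-like data;
# the trigger name travels alongside it in the (name, model) tuples shown earlier.
from dataclasses import dataclass
from typing import List

@dataclass
class ToyTextualInversion:
    embedding: List[List[float]]  # stands in for torch.Tensor [n, 768]|[n, 1280]

ti_list = [("easynegative", ToyTextualInversion(embedding=[[0.0] * 4, [1.0] * 4]))]
for ti_name, ti in ti_list:
    print(ti_name, len(ti.embedding))  # -> easynegative 2
```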
@@ -828,7 +826,7 @@ class ONNXModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: IAIOnnxRuntimeModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         from .models.base import IAIOnnxRuntimeModel

@@ -841,17 +839,17 @@ class ONNXModelPatcher:
             ti_tokenizer = copy.deepcopy(tokenizer)
             ti_manager = TextualInversionManager(ti_tokenizer)

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -861,11 +859,11 @@ class ONNXModelPatcher:
                 axis=0,
             )

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i].detach().numpy()
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)

                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
@@ -210,6 +210,31 @@ class ModelCache(object):

         return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)

+    def clear_one_model(self) -> bool:
+        reserved = self.max_vram_cache_size * GIG
+        vram_in_use = torch.cuda.memory_allocated()
+        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+        smallest_key = None
+        smallest_size = float("inf")
+        for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
+            if not cache_entry.locked and cache_entry.loaded:
+                if cache_entry.size > 0 and cache_entry.size < smallest_size:
+                    smallest_key = model_key
+                    smallest_size = cache_entry.size
+
+        if smallest_key is not None:
+            cache_entry = self._cached_models[smallest_key]
+            self.logger.debug(f"!!!!!!!!!!!Offloading {smallest_key} from {self.execution_device} into {self.storage_device}")
+            with VRAMUsage() as mem:
+                cache_entry.model.to(self.storage_device)
+            self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
+            vram_in_use += mem.vram_used # note vram_used is negative
+            self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        return smallest_key is not None
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load, size_needed):
             """
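The new `clear_one_model()` adds an explicit eviction step: it scans the cached models, picks the smallest entry that is loaded but not locked, moves it to the storage device, and reports whether anything could be freed. Below is a minimal sketch of just the selection policy, using a plain dataclass in place of the real cache entries; the field names `size`/`locked`/`loaded` follow the attributes referenced in the hunk, while `GIG`, `VRAMUsage`, and the logger are left out.

```python
# Minimal sketch of the eviction policy in clear_one_model(): pick the smallest
# unlocked, loaded model. Toy CacheEntry stands in for the real cache records.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class CacheEntry:
    size: int
    locked: bool
    loaded: bool

def pick_eviction_victim(cached: Dict[str, CacheEntry]) -> Optional[str]:
    smallest_key, smallest_size = None, float("inf")
    for key, entry in sorted(cached.items(), key=lambda kv: kv[1].size):
        if not entry.locked and entry.loaded and 0 < entry.size < smallest_size:
            smallest_key, smallest_size = key, entry.size
    return smallest_key

cache = {
    "sdxl-unet": CacheEntry(size=5_000, locked=True, loaded=True),
    "vae": CacheEntry(size=300, locked=False, loaded=True),
    "clip": CacheEntry(size=900, locked=False, loaded=False),
}
print(pick_eviction_victim(cache))  # -> "vae"
```

Returning a key (or a boolean, as the diff does) lets callers loop "free one, retry" until either the operation fits or nothing evictable remains.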
@@ -236,17 +261,48 @@ class ModelCache(object):
             self.cache_entry.lock()

             try:
-                if self.cache.lazy_offloading:
-                    self.cache._offload_unlocked_models(self.size_needed)
-
-                if self.model.device != self.cache.execution_device:
-                    self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
-                    with VRAMUsage() as mem:
-                        self.model.to(self.cache.execution_device) # move into GPU
-                    self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
-
-                self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
-                self.cache._print_cuda_stats()
-
+                self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
+                while True:
+                    try:
+                        with VRAMUsage() as mem:
+                            self.model.to(self.cache.execution_device) # move into GPU
+                        self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
+
+                        self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
+                        self.cache._print_cuda_stats()
+
+                        def my_forward(module, cache, *args, **kwargs):
+                            while True:
+                                try:
+                                    return module._orig_forward(*args, **kwargs)
+                                except:
+                                    if not cache.clear_one_model():
+                                        raise
+
+                        import functools
+                        from diffusers.models.unet_2d_blocks import DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D
+                        from transformers.models.clip.modeling_clip import CLIPEncoderLayer
+                        from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
+
+                        for module_name, module in self.model.named_modules():
+                            if type(module) not in [
+                                DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D, # unet blocks
+                                CLIPEncoderLayer, # CLIPTextTransformer clip
+                                DownEncoderBlock2D, UpDecoderBlock2D, # vae
+                            ]:
+                                continue
+                            # better here filter to only specific model modules
+                            module._orig_forward = module.forward
+                            module.forward = functools.partial(my_forward, module, self.cache)
+
+                        self.model._orig_forward = self.model.forward
+                        self.model.forward = functools.partial(my_forward, self.model, self.cache)
+
+                        break
+
+                    except:
+                        if not self.cache.clear_one_model():
+                            raise
+
             except:
                 self.cache_entry.unlock()
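The `__enter__` hunk replaces the up-front "offload unlocked models, then move into VRAM" path with a retry loop: moving the model to the execution device is attempted inside try/except, and on failure the cache frees one model via `clear_one_model()` and tries again. It also wraps `forward` on selected sub-modules (UNet blocks, CLIP encoder layers, VAE blocks) and on the model itself with `my_forward`, which applies the same retry-on-exception logic at call time. The sketch below isolates that wrapping pattern without torch or diffusers; `retrying_forward`, `Flaky`, and the free-budget list are all hypothetical stand-ins.

```python
# Sketch of the retry idea detached from torch: wrap a callable so that, when
# it raises, we try to free something and call it again; give up only when
# nothing more can be freed. free_one() stands in for cache.clear_one_model().
import functools

def retrying_forward(orig_forward, free_one, *args, **kwargs):
    while True:
        try:
            return orig_forward(*args, **kwargs)
        except Exception:
            if not free_one():
                raise

class Flaky:
    """Toy module whose forward fails a fixed number of times (stand-in for a CUDA OOM)."""
    def __init__(self, failures: int):
        self.failures = failures
    def forward(self, x):
        if self.failures > 0:
            self.failures -= 1
            raise RuntimeError("out of memory")
        return x * 2

budget = [True, True]  # pretend we can free memory twice
free_one = lambda: budget.pop() if budget else False

module = Flaky(failures=2)
module._orig_forward = module.forward  # save the original, as in the diff
module.forward = functools.partial(retrying_forward, module._orig_forward, free_one)
print(module.forward(21))  # succeeds on the third attempt -> 42
```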
@@ -264,10 +320,19 @@ class ModelCache(object):
             if not hasattr(self.model, "to"):
                 return

+            if hasattr(self.model, "_orig_forward"):
+                self.model.forward = self.model._orig_forward
+                delattr(self.model, "_orig_forward")
+
+            for module_name, module in self.model.named_modules():
+                if hasattr(module, "_orig_forward"):
+                    module.forward = module._orig_forward
+                    delattr(module, "_orig_forward")
+
             self.cache_entry.unlock()
-            if not self.cache.lazy_offloading:
-                self.cache._offload_unlocked_models()
-                self.cache._print_cuda_stats()
+            #if not self.cache.lazy_offloading:
+            # self.cache._offload_unlocked_models()
+            # self.cache._print_cuda_stats()

     # TODO: should it be called untrack_model?
     def uncache_model(self, cache_id: str):
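`__exit__` now undoes the wrapping from `__enter__`, restoring any saved `_orig_forward` on the model and its sub-modules, and the eager "offload unlocked models on exit" path is commented out, presumably leaving offloading to `clear_one_model()` when memory pressure actually occurs. A toy round-trip of that save/restore, with no torch involved:

```python
# Hedged companion sketch to the __exit__ hunk: restore the original forward
# that was saved under _orig_forward when the wrapper was installed.
class Toy:
    def forward(self, x):
        return x + 1

def unwrap(obj):
    if hasattr(obj, "_orig_forward"):
        obj.forward = obj._orig_forward
        delattr(obj, "_orig_forward")

t = Toy()
t._orig_forward = t.forward
t.forward = lambda x: t._orig_forward(x) * 10  # pretend this is the retry wrapper
print(t.forward(1))   # 20, wrapped behaviour
unwrap(t)
print(t.forward(1))   # 2, original behaviour restored
```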