Compare commits

...

2 Commits

Author          SHA1        Message                                                   Date
Sergey Borisov  2634f0e43a  First draft                                               2023-08-01 20:18:22 +03:00
Sergey Borisov  704151e8e3  Provide ti name from model manager, not from ti itself   2023-08-01 18:04:10 +03:00
4 changed files with 103 additions and 44 deletions

View File

@@ -108,14 +108,15 @@ class CompelInvocation(BaseInvocation):
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -196,14 +197,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
@@ -270,14 +272,15 @@ class SDXLPromptInvocationBase:
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
+                ti_list.append((
+                    name,
                     context.services.model_manager.get_model(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
                         context=context,
                     ).context.model
-                )
+                ))
             except ModelNotFoundException:
                 # print(e)
                 # import traceback
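
Note on the shape change in the three hunks above: each ti_list entry becomes a (name, model) tuple, with the name taken from the prompt trigger resolved through the model manager instead of from the loaded TI object itself. A minimal, self-contained sketch of that before/after shape (hypothetical stand-in names, not code from this PR):

# Hypothetical illustration of the old vs. new ti_list shape.
from dataclasses import dataclass

@dataclass
class FakeTI:                         # stand-in for a loaded textual-inversion model
    embedding_rows: int = 2

def load_ti(trigger: str) -> FakeTI:  # stand-in for context.services.model_manager.get_model(...)
    return FakeTI()

triggers = ["easynegative", "bad-hands-5"]   # parsed from "<easynegative> <bad-hands-5>"

# before: bare model objects; the trigger name came from the TI file stem (ti.name)
old_ti_list = [load_ti(t) for t in triggers]

# after: the prompt trigger travels with the model, so tokenizer patching no longer
# depends on whatever filename the embedding happened to have on disk
new_ti_list = [(t, load_ti(t)) for t in triggers]
print(new_ti_list[0][0])  # -> "easynegative"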

View File

@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
                 **self.clip.text_encoder.dict(),
             )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
-            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
             loras = [
                 (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                 for lora in self.clip.loras
@@ -75,20 +74,14 @@ class ONNXPromptInvocation(BaseInvocation):
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
-                    ti_list.append(
-                        # stack.enter_context(
-                        #     context.services.model_manager.get_model(
-                        #         model_name=name,
-                        #         base_model=self.clip.text_encoder.base_model,
-                        #         model_type=ModelType.TextualInversion,
-                        #     )
-                        # )
+                    ti_list.append((
+                        name,
                         context.services.model_manager.get_model(
                             model_name=name,
                             base_model=self.clip.text_encoder.base_model,
                             model_type=ModelType.TextualInversion,
                         ).context.model
-                    )
+                    ))
                 except Exception:
                     # print(e)
                     # import traceback

View File

@@ -562,7 +562,7 @@ class ModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         init_tokens_count = None
         new_tokens_added = None
@@ -572,27 +572,27 @@ class ModelPatcher:
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

-            def _get_trigger(ti, index):
-                trigger = ti.name
+            def _get_trigger(ti_name, index):
+                trigger = ti_name
                 if index > 0:
                     trigger += f"-!pad-{i}"
                 return f"<{trigger}>"

             # modify tokenizer
             new_tokens_added = 0
-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 for i in range(ti.embedding.shape[0]):
-                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

             # modify text_encoder
             text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
             model_embeddings = text_encoder.get_input_embeddings()

-            for ti in ti_list:
+            for ti_name, ti in ti_list:
                 ti_tokens = []
                 for i in range(ti.embedding.shape[0]):
                     embedding = ti.embedding[i]
-                    trigger = _get_trigger(ti, i)
+                    trigger = _get_trigger(ti_name, i)
                     token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                     if token_id == ti_tokenizer.unk_token_id:
@@ -637,7 +637,6 @@ class ModelPatcher:
 class TextualInversionModel:
-    name: str
     embedding: torch.Tensor  # [n, 768]|[n, 1280]

     @classmethod
@@ -651,7 +650,6 @@ class TextualInversionModel:
         file_path = Path(file_path)

         result = cls()  # TODO:
-        result.name = file_path.stem  # TODO:

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -828,7 +826,7 @@ class ONNXModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: IAIOnnxRuntimeModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         from .models.base import IAIOnnxRuntimeModel
@@ -841,17 +839,17 @@ class ONNXModelPatcher:
         ti_tokenizer = copy.deepcopy(tokenizer)
         ti_manager = TextualInversionManager(ti_tokenizer)

-        def _get_trigger(ti, index):
-            trigger = ti.name
+        def _get_trigger(ti_name, index):
+            trigger = ti_name
             if index > 0:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"

         # modify tokenizer
         new_tokens_added = 0
-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             for i in range(ti.embedding.shape[0]):
-                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

         # modify text_encoder
         orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -861,11 +859,11 @@ class ONNXModelPatcher:
             axis=0,
         )

-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             ti_tokens = []
             for i in range(ti.embedding.shape[0]):
                 embedding = ti.embedding[i].detach().numpy()
-                trigger = _get_trigger(ti, i)
+                trigger = _get_trigger(ti_name, i)
                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                 if token_id == ti_tokenizer.unk_token_id:
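
Note on the trigger scheme threaded through the lora.py hunks: the externally supplied name is passed to _get_trigger, which emits one tokenizer token per embedding row, <name> for row 0 and <name-!pad-N> for the extra rows of multi-vector embeddings. A rough, self-contained sketch of that expansion (hypothetical data, not this PR's code):

import torch

def get_trigger(ti_name: str, index: int) -> str:
    # mirrors _get_trigger: row 0 keeps the plain trigger, later rows get a pad suffix
    trigger = ti_name
    if index > 0:
        trigger += f"-!pad-{index}"
    return f"<{trigger}>"

# a hypothetical 3-vector embedding delivered as a (name, model) tuple by the model manager
ti_name, embedding = "easynegative", torch.zeros(3, 768)
tokens = [get_trigger(ti_name, i) for i in range(embedding.shape[0])]
print(tokens)  # ['<easynegative>', '<easynegative-!pad-1>', '<easynegative-!pad-2>']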

View File

@@ -210,6 +210,31 @@ class ModelCache(object):
         return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)

+    def clear_one_model(self) -> bool:
+        reserved = self.max_vram_cache_size * GIG
+        vram_in_use = torch.cuda.memory_allocated()
+        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+
+        smallest_key = None
+        smallest_size = float("inf")
+        for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
+            if not cache_entry.locked and cache_entry.loaded:
+                if cache_entry.size > 0 and cache_entry.size < smallest_size:
+                    smallest_key = model_key
+                    smallest_size = cache_entry.size
+
+        if smallest_key is not None:
+            cache_entry = self._cached_models[smallest_key]
+            self.logger.debug(f"!!!!!!!!!!!Offloading {smallest_key} from {self.execution_device} into {self.storage_device}")
+            with VRAMUsage() as mem:
+                cache_entry.model.to(self.storage_device)
+            self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
+            vram_in_use += mem.vram_used  # note vram_used is negative
+            self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        return smallest_key is not None
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load, size_needed):
             """
@@ -236,17 +261,48 @@ class ModelCache(object):
             self.cache_entry.lock()
             try:
-                if self.cache.lazy_offloading:
-                    self.cache._offload_unlocked_models(self.size_needed)
-
-                if self.model.device != self.cache.execution_device:
-                    self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
-                    with VRAMUsage() as mem:
-                        self.model.to(self.cache.execution_device)  # move into GPU
-                    self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
-
-                self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
-                self.cache._print_cuda_stats()
+                self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
+                while True:
+                    try:
+                        with VRAMUsage() as mem:
+                            self.model.to(self.cache.execution_device)  # move into GPU
+                        self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
+
+                        self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
+                        self.cache._print_cuda_stats()
+
+                        def my_forward(module, cache, *args, **kwargs):
+                            while True:
+                                try:
+                                    return module._orig_forward(*args, **kwargs)
+                                except:
+                                    if not cache.clear_one_model():
+                                        raise
+
+                        import functools
+                        from diffusers.models.unet_2d_blocks import DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D
+                        from transformers.models.clip.modeling_clip import CLIPEncoderLayer
+                        from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
+
+                        for module_name, module in self.model.named_modules():
+                            if type(module) not in [
+                                DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D,  # unet blocks
+                                CLIPEncoderLayer,  # CLIPTextTransformer clip
+                                DownEncoderBlock2D, UpDecoderBlock2D,  # vae
+                            ]:
+                                continue
+                            # better here filter to only specific model modules
+                            module._orig_forward = module.forward
+                            module.forward = functools.partial(my_forward, module, self.cache)
+
+                        self.model._orig_forward = self.model.forward
+                        self.model.forward = functools.partial(my_forward, self.model, self.cache)
+
+                        break
+                    except:
+                        if not self.cache.clear_one_model():
+                            raise
             except:
                 self.cache_entry.unlock()
@@ -264,10 +320,19 @@ class ModelCache(object):
             if not hasattr(self.model, "to"):
                 return

+            if hasattr(self.model, "_orig_forward"):
+                self.model.forward = self.model._orig_forward
+                delattr(self.model, "_orig_forward")
+
+            for module_name, module in self.model.named_modules():
+                if hasattr(module, "_orig_forward"):
+                    module.forward = module._orig_forward
+                    delattr(module, "_orig_forward")
+
             self.cache_entry.unlock()
-            if not self.cache.lazy_offloading:
-                self.cache._offload_unlocked_models()
-                self.cache._print_cuda_stats()
+            #if not self.cache.lazy_offloading:
+            #    self.cache._offload_unlocked_models()
+            #    self.cache._print_cuda_stats()

     # TODO: should it be called untrack_model?
     def uncache_model(self, cache_id: str):
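
The ModelLocker changes above wrap the forward of selected submodules so that a failed forward pass triggers an eviction via clear_one_model and a retry instead of failing the invocation outright, and __exit__ restores the original forwards afterwards. A stripped-down, self-contained sketch of that wrap/retry/unwrap pattern (hypothetical cache object, not this PR's code):

import functools
import torch

class TinyCache:
    def clear_one_model(self) -> bool:
        # stand-in: report whether something could be evicted to free VRAM
        return False

def retrying_forward(module, cache, *args, **kwargs):
    # keep retrying the original forward as long as an eviction frees memory
    while True:
        try:
            return module._orig_forward(*args, **kwargs)
        except torch.cuda.OutOfMemoryError:
            if not cache.clear_one_model():
                raise

def wrap(module: torch.nn.Module, cache: TinyCache) -> None:
    # save the original forward and route calls through the retry wrapper
    module._orig_forward = module.forward
    module.forward = functools.partial(retrying_forward, module, cache)

def unwrap(module: torch.nn.Module) -> None:
    # restore the original forward, mirroring the __exit__ cleanup
    if hasattr(module, "_orig_forward"):
        module.forward = module._orig_forward
        delattr(module, "_orig_forward")

m = torch.nn.Linear(4, 4)
wrap(m, TinyCache())
print(m(torch.zeros(1, 4)).shape)   # forward still works through the wrapper
unwrap(m)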