adjusting back to hooks, forcing them to be last in execution

Author: Jordan
Date: 2023-02-21 01:34:06 -07:00
parent 49c0516602
commit 5529309e73


@@ -38,25 +38,19 @@ LORA_PREFIX_UNET = 'lora_unet'
 LORA_PREFIX_TEXT_ENCODER = 'lora_te'

-def lora_forward(module, input_h, output):
-    if len(loaded_loras) == 0:
-        return output
-
-    lora_name = getattr(module, 'lora_name', None)
-    for lora in applied_loras.values():
-        layer = lora.layers.get(lora_name, None)
-        if layer is None:
-            continue
-        output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
-    return output
-
-def lora_linear_forward(self, input_h):
-    return lora_forward(self, input_h, torch.nn.Linear.forward_before_lora(self, input_h))
-
-def lora_conv2d_forward(self, input_h):
-    return lora_forward(self, input_h, torch.nn.Conv2d.forward_before_lora(self, input_h))
+def lora_forward_hook(name):
+    def lora_forward(module, input_h, output):
+        if len(loaded_loras) == 0:
+            return output
+
+        for lora in applied_loras.values():
+            layer = lora.layers.get(name, None)
+            if layer is None:
+                continue
+            output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
+        return output
+    return lora_forward

 def load_lora(
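For context, a minimal sketch of the hook mechanism the new lora_forward_hook closure relies on. The layer name and the toy Linear below are made up for illustration; only stock torch.nn APIs are assumed. A forward hook receives (module, input, output), and returning a non-None value replaces the module's output, which is where the LoRA delta gets added on top of the original result.

import torch

def make_hook(name):
    # same shape as lora_forward_hook: capture the layer name in a closure
    def hook(module, inputs, output):
        # returning a value from a forward hook replaces the module's output;
        # here it is passed through unchanged instead of adding a LoRA delta
        print(f"hook fired for {name}")
        return output
    return hook

linear = torch.nn.Linear(4, 4)
handle = linear.register_forward_hook(make_hook("lora_te_example"))
_ = linear(torch.randn(1, 4))   # prints: hook fired for lora_te_example
handle.remove()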
@@ -141,6 +135,7 @@ def load_lora(
 class LoraManager:
     loras_to_load: dict[str, float]
+    hooks: list[RemovableHandle]

     def __init__(self, pipe):
         self.lora_path = Path(global_models_dir(), 'lora')
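The RemovableHandle used in the annotation above lives in torch.utils.hooks; the import is presumably already present elsewhere in the file and is shown here only for reference.

# RemovableHandle is what register_forward_hook returns; the manager keeps a
# list of them so every LoRA hook can be detached later.
from torch.utils.hooks import RemovableHandle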
@@ -149,26 +144,19 @@ class LoraManager:
         self.device = torch.device(choose_torch_device())
         self.dtype = pipe.unet.dtype
         self.loras_to_load = {}
-
-        if not hasattr(torch.nn.Linear, 'forward_before_lora'):
-            torch.nn.Linear.forward_before_lora = torch.nn.Linear.forward
-
-        if not hasattr(torch.nn.Conv2d, 'forward_before_lora'):
-            torch.nn.Conv2d.forward_before_lora = torch.nn.Conv2d.forward
-
-        torch.nn.Linear.forward = lora_linear_forward
-        torch.nn.Conv2d.forward = lora_conv2d_forward
+        self.hooks = []

         def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
             mapping = {}
             for name, module in root_module.named_modules():
                 if module.__class__.__name__ in target_replace_modules:
                     for child_name, child_module in module.named_modules():
-                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
+                        layer_type = child_module.__class__.__name__
+                        if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)):
                             lora_name = prefix + '.' + name + '.' + child_name
                             lora_name = lora_name.replace('.', '_')
                             mapping[lora_name] = child_module
-                            module.lora_name = lora_name
+                            self.apply_module_forward(child_module, lora_name)
             return mapping

         self.text_modules = find_modules(
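A rough illustration of the name mangling find_modules performs, with a toy module tree standing in for the real UNet and text encoder, and the path simplified to a single level (the actual code joins prefix, parent name, and child name): the module path is joined to the lora_unet / lora_te prefix and dots are replaced with underscores so it lines up with the key names stored in LoRA weight files.

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

prefix = 'lora_unet'
for name, child in Toy().named_modules():
    if child.__class__.__name__ == "Linear":
        # mirror the diff's key construction for a one-level path
        lora_name = (prefix + '.' + name).replace('.', '_')
        print(lora_name)   # -> lora_unet_to_q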
@@ -182,6 +170,12 @@ class LoraManager:
         loaded_loras[name] = lora
         return lora

+    def apply_module_forward(self, module, lora_name):
+        handle = RemovableHandle(module._forward_hooks)
+        handle.id = 9000
+        module._forward_hooks[handle.id] = lora_forward_hook(lora_name)
+        self.hooks.append(handle)
+
     def apply_lora_model(self, name, mult: float = 1.0):
         path = Path(self.lora_path, name)
         file = Path(path, "pytorch_lora_weights.bin")
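apply_module_forward builds the handle by hand instead of calling register_forward_hook, which lets it choose the key under which the hook is stored in module._forward_hooks (an ordered dict, so where the hook sits relative to others affects when it runs). A hedged sketch of the two styles side by side; the no-op hook and bare Linear are placeholders, and _forward_hooks is a private PyTorch attribute.

import torch
from torch.utils.hooks import RemovableHandle

module = torch.nn.Linear(4, 4)

def noop_hook(mod, inputs, output):
    return output

# manual registration with a hand-picked id, as in the diff's handle.id = 9000
handle = RemovableHandle(module._forward_hooks)
handle.id = 9000
module._forward_hooks[handle.id] = noop_hook

# stock equivalent, where PyTorch assigns the id itself
other = module.register_forward_hook(noop_hook)

_ = module(torch.randn(1, 4))
handle.remove()
other.remove()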
@@ -258,16 +252,14 @@ class LoraManager:
         clear_applied_loras()
         self.loras_to_load = {}

+    def clear_hooks(self):
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+
     def __del__(self):
-        # cleanup overrides
-        if hasattr(torch.nn.Linear, 'forward_before_lora'):
-            torch.nn.Linear.forward = torch.nn.Linear.forward_before_lora
-            del torch.nn.Linear.forward_before_lora
-        if hasattr(torch.nn.Conv2d, 'forward_before_lora'):
-            torch.nn.Conv2d.forward = torch.nn.Conv2d.forward_before_lora
-            del torch.nn.Conv2d.forward_before_lora
+        self.clear_hooks()
         clear_applied_loras()
         clear_loaded_loras()
         del self.text_modules
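The cleanup in clear_hooks and __del__ follows the usual handle-tracking pattern: keep every handle, remove them all at once, and nothing global needs restoring (unlike the old monkey-patched forward). A small sketch, assuming a bare Linear stands in for the pipeline modules:

import torch

layer = torch.nn.Linear(2, 2)
hooks = [layer.register_forward_hook(lambda m, i, o: o)]

def clear_hooks(handles):
    for handle in handles:
        handle.remove()    # detach the hook from its module
    handles.clear()        # drop the now-dead handles

clear_hooks(hooks)
assert len(layer._forward_hooks) == 0   # nothing left to restore or patch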