further cleanup and implement wrapper

This commit is contained in:
Jordan
2023-02-21 20:42:40 -07:00
parent 686f6ef8d6
commit af3543a8c7

View File

@@ -22,26 +22,112 @@ class LoRALayer:
self.scale = alpha / rank
class LoRAModuleWrapper:
unet: UNet2DConditionModel
text_encoder: CLIPTextModel
def __init__(self, unet, text_encoder):
self.unet = unet
self.text_encoder = text_encoder
self.hooks = []
self.text_modules = None
self.unet_modules = None
self.applied_loras = {}
self.loaded_loras = {}
self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
self.LORA_PREFIX_UNET = 'lora_unet'
self.LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
mapping = {}
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
layer_type = child_module.__class__.__name__
if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
mapping[lora_name] = child_module
self.apply_module_forward(child_module, lora_name)
return mapping
if self.text_modules is None:
self.text_modules = find_modules(
self.LORA_PREFIX_TEXT_ENCODER,
text_encoder,
self.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
if self.unet_modules is None:
self.unet_modules = find_modules(
self.LORA_PREFIX_UNET,
unet,
self.UNET_TARGET_REPLACE_MODULE
)
def lora_forward_hook(self, name):
wrapper = self
def lora_forward(module, input_h, output):
if len(wrapper.loaded_loras) == 0:
return output
for lora in wrapper.applied_loras.values():
layer = lora.layers.get(name, None)
if layer is None:
continue
output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
return output
return lora_forward
def apply_module_forward(self, module, name):
handle = module.register_forward_hook(self.lora_forward_hook(name))
self.hooks.append(handle)
def clear_hooks(self):
for hook in self.hooks:
hook.remove()
self.hooks.clear()
def clear_applied_loras(self):
self.applied_loras.clear()
def clear_loaded_loras(self):
self.loaded_loras.clear()
def __del__(self):
self.clear_hooks()
self.clear_applied_loras()
self.clear_loaded_loras()
del self.text_modules
del self.unet_modules
del self.hooks
class LoRA:
name: str
layers: dict[str, LoRALayer]
device: torch.device
dtype: torch.dtype
wrapper: LoRAModuleWrapper
multiplier: float
def __init__(self, name: str, device, dtype, multiplier=1.0):
def __init__(self, name: str, device, dtype, wrapper, multiplier=1.0):
self.name = name
self.layers = {}
self.multiplier = multiplier
self.device = device
self.dtype = dtype
self.wrapper = wrapper
self.rank = None
self.alpha = None
def load_from_dict(self,
state_dict,
text_modules: dict[str, torch.nn.Module],
unet_modules: dict[str, torch.nn.Module]):
def load_from_dict(self, state_dict):
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
@@ -50,8 +136,8 @@ class LoRA:
self.alpha = value.item()
continue
if stem.startswith(LORA_PREFIX_TEXT_ENCODER):
wrapped = text_modules.get(stem, None)
if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
wrapped = self.wrapper.text_modules.get(stem, None)
if wrapped is None:
print(f">> Missing layer: {stem}")
continue
@@ -60,8 +146,8 @@ class LoRA:
self.rank = value.shape[0]
self.load_lora_layer(stem, leaf, value, wrapped)
continue
elif stem.startswith(LORA_PREFIX_UNET):
wrapped = unet_modules.get(stem, None)
elif stem.startswith(self.wrapper.LORA_PREFIX_UNET):
wrapped = self.wrapper.unet_modules.get(stem, None)
if wrapped is None:
print(f">> Missing layer: {stem}")
continue
@@ -101,11 +187,6 @@ class LoRA:
return
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
re_digits = re.compile(r"\d+")
re_unet_transformer_attn_blocks = re.compile(
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
@@ -196,8 +277,7 @@ def convert_key_to_diffusers(key):
def load_lora_attn(
name: str,
path_file: Path,
unet: UNet2DConditionModel,
text_encoder: CLIPTextModel,
wrapper: LoRAModuleWrapper,
multiplier=1.0
):
print(f">> Loading lora {name} from {path_file}")
@@ -207,10 +287,10 @@ def load_lora_attn(
checkpoint = torch.load(path_file, map_location='cpu')
for key in list(checkpoint.keys()):
if key.startswith(LORA_PREFIX_UNET):
if key.startswith(wrapper.LORA_PREFIX_UNET):
# convert unet keys
checkpoint[convert_key_to_diffusers(key)] = checkpoint.pop(key)
elif key.startswith(LORA_PREFIX_UNET):
elif key.startswith(wrapper.LORA_PREFIX_UNET):
# convert text encoder keys (not yet supported)
# state_dict[convert_key_to_diffusers(key)] = state_dict.pop(key)
checkpoint.pop(key)
@@ -218,44 +298,8 @@ def load_lora_attn(
# remove invalid key
checkpoint.pop(key)
unet.load_attn_procs(checkpoint)
# text_encoder.load_attn_procs(checkpoint)
def lora_forward_hook(name):
def lora_forward(module, input_h, output):
if len(loaded_loras) == 0:
return output
for lora in applied_loras.values():
layer = lora.layers.get(name, None)
if layer is None:
continue
output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
return output
return lora_forward
def load_lora(
name: str,
path_file: Path,
device: torch.device,
dtype: torch.dtype,
text_modules: dict[str, torch.nn.Module],
unet_modules: dict[str, torch.nn.Module],
multiplier=1.0
):
print(f">> Loading lora {name} from {path_file}")
if path_file.suffix == '.safetensors':
checkpoint = load_file(path_file.absolute().as_posix(), device='cpu')
else:
checkpoint = torch.load(path_file, map_location='cpu')
lora = LoRA(name, device, dtype, multiplier)
lora.load_from_dict(checkpoint, text_modules, unet_modules)
return lora
wrapper.unet.load_attn_procs(checkpoint)
# wrapper.text_encoder.load_attn_procs(checkpoint)
class LoraManager:
@@ -269,37 +313,23 @@ class LoraManager:
self.device = torch.device(choose_torch_device())
self.dtype = pipe.unet.dtype
self.loras_to_load = {}
self.hooks = []
self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
mapping = {}
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
layer_type = child_module.__class__.__name__
if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
mapping[lora_name] = child_module
self.apply_module_forward(child_module, lora_name)
return mapping
self.text_modules = find_modules(
LORA_PREFIX_TEXT_ENCODER, self.text_encoder, TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.unet_modules = find_modules(
LORA_PREFIX_UNET, self.unet, UNET_TARGET_REPLACE_MODULE)
def _load_lora(self, name, path_file, multiplier: float = 1.0):
def load_lora_module(self, name, path_file, multiplier: float = 1.0):
# can be used instead to load through diffusers, once enough support is added
# lora = load_lora_attn(name, path_file, self.unet, self.text_encoder, multiplier)
lora = load_lora(name, path_file, self.device, self.dtype, self.text_modules, self.unet_modules, multiplier)
loaded_loras[name] = lora
return lora
# lora = load_lora_attn(name, path_file, self.wrapper, multiplier)
def apply_module_forward(self, module, lora_name):
handle = module.register_forward_hook(lora_forward_hook(lora_name))
self.hooks.append(handle)
print(f">> Loading lora {name} from {path_file}")
if path_file.suffix == '.safetensors':
checkpoint = load_file(path_file.absolute().as_posix(), device='cpu')
else:
checkpoint = torch.load(path_file, map_location='cpu')
lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier)
lora.load_from_dict(checkpoint)
self.wrapper.loaded_loras[name] = lora
return lora
def apply_lora_model(self, name, mult: float = 1.0):
path = Path(self.lora_path, name)
@@ -318,31 +348,30 @@ class LoraManager:
print(f">> Unable to find lora: {name}")
return
lora = loaded_loras.get(name, None)
lora = self.wrapper.loaded_loras.get(name, None)
if lora is None:
lora = self._load_lora(name, path_file, mult)
lora = self.load_lora_module(name, path_file, mult)
lora.multiplier = mult
applied_loras[name] = lora
self.wrapper.applied_loras[name] = lora
def load_lora(self):
print(self.loras_to_load)
for name, multiplier in self.loras_to_load.items():
self.apply_lora_model(name, multiplier)
# unload any lora's not defined by loras_to_load
for name in list(applied_loras.keys()):
for name in list(self.wrapper.applied_loras.keys()):
if name not in self.loras_to_load:
self.unload_applied_lora(name)
@staticmethod
def unload_applied_lora(lora_name: str):
if lora_name in applied_loras:
del applied_loras[lora_name]
def unload_applied_lora(self, lora_name: str):
if lora_name in self.wrapper.applied_loras:
del self.wrapper.applied_loras[lora_name]
@staticmethod
def unload_lora(lora_name: str):
if lora_name in loaded_loras:
del loaded_loras[lora_name]
def unload_lora(self, lora_name: str):
if lora_name in self.wrapper.loaded_loras:
del self.wrapper.loaded_loras[lora_name]
# Define a lora to be loaded
# Can be used to define a lora to be loaded outside of prompts
@@ -350,8 +379,8 @@ class LoraManager:
self.loras_to_load[name] = multiplier
# update the multiplier if the lora was already loaded
if name in loaded_loras:
loaded_loras[name].multiplier = multiplier
if name in self.wrapper.loaded_loras:
self.wrapper.loaded_loras[name].multiplier = multiplier
# Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
# Multiplier should be a value between 0.0 and 1.0
@@ -374,31 +403,8 @@ class LoraManager:
return re.sub(lora_match, "", prompt)
def clear_loras(self):
clear_applied_loras()
self.wrapper.clear_applied_loras()
self.loras_to_load = {}
def clear_hooks(self):
for hook in self.hooks:
hook.remove()
self.hooks.clear()
def __del__(self):
self.clear_hooks()
clear_applied_loras()
clear_loaded_loras()
del self.text_modules
del self.unet_modules
del self.loras_to_load
applied_loras = {}
loaded_loras = {}
def clear_applied_loras():
applied_loras.clear()
def clear_loaded_loras():
loaded_loras.clear()