InvokeAI/ldm/modules/peft_manager.py
Lincoln Stein 879c80022e preliminary LoRA support ready for testing
Instructions:

1. Download LoRA .safetensors files of your choice and place in
   `INVOKEAIROOT/loras`. Unlike the draft version of this feature, file
   names may contain underscores and alphanumerics; names with
   arbitrary Unicode characters are not supported.

2. Add `withLora(lora-file-basename,weight)` to your prompt. The
   weight is optional and will default to 1.0. A few examples, assuming
   that a LoRA file named `loras/sushi.safetensors` is present:

```
family sitting at dinner table eating sushi withLora(sushi,0.9)
family sitting at dinner table eating sushi withLora(sushi, 0.75)
family sitting at dinner table eating sushi withLora(sushi)
```

Multiple `withLora()` prompt fragments are allowed. The weight can be
arbitrarily large, but the useful range is roughly 0.5 to 1.0; higher
weights make the LoRA's influence stronger. (A sketch of how these
fragments can be parsed out of a prompt appears after step 3 below.)

In my limited testing, I found it useful to reduce the CFG scale to
avoid oversharpening. I also got better results when running the LoRA
on top of the base model it was trained on.

Don't try to load an SD 1.x-trained LoRA into an SD 2.x model, or vice
versa; you will get a nasty stack trace. This needs to be cleaned up.

3. You can change the location of the `loras` directory by passing the
   `--lora_directory` option to `invokeai`.
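
For the curious, here is a minimal sketch of how a `withLora()` fragment
could be pulled out of a prompt. The regular expression and the
`extract_lora_fragments` helper below are illustrative assumptions, not
InvokeAI's actual prompt parser.

```
import re

# Hypothetical illustration of withLora() parsing -- not InvokeAI's real parser.
# Matches withLora(name) and withLora(name, weight); weight defaults to 1.0.
LORA_RE = re.compile(r"withLora\(\s*([\w.-]+)\s*(?:,\s*([\d.]+)\s*)?\)")

def extract_lora_fragments(prompt: str):
    """Return (prompt_without_fragments, [(lora_name, weight), ...])."""
    loras = [(m.group(1), float(m.group(2) or 1.0)) for m in LORA_RE.finditer(prompt)]
    cleaned = LORA_RE.sub("", prompt).strip()
    return cleaned, loras

print(extract_lora_fragments(
    "family sitting at dinner table eating sushi withLora(sushi, 0.75)"
))
# ('family sitting at dinner table eating sushi', [('sushi', 0.75)])
```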

Documentation can be found in docs/features/LORAS.md.
2023-03-31 00:03:16 -04:00

from peft import LoraModel, LoraConfig, set_peft_model_state_dict
import torch
import json
from pathlib import Path

from ldm.invoke.globals import global_lora_models_dir


class LoraPeftModule:
    """A single LoRA in PEFT format, loaded from a directory containing
    lora_config.json and lora.pt."""

    def __init__(self, lora_dir, multiplier: float = 1.0):
        self.lora_dir = lora_dir
        self.multiplier = multiplier  # stored for later use; not yet applied in this module
        self.config = self.load_config()
        self.checkpoint = self.load_checkpoint()

    def load_config(self):
        lora_config_file = Path(self.lora_dir, 'lora_config.json')
        with open(lora_config_file, "r") as f:
            return json.load(f)

    def load_checkpoint(self):
        return torch.load(Path(self.lora_dir, 'lora.pt'))

    def unet(self, unet):
        # The UNet receives the checkpoint entries that are *not* prefixed
        # with "text_encoder_".
        lora_ds = {
            k: v for k, v in self.checkpoint.items() if "text_encoder_" not in k
        }
        config = LoraConfig(**self.config["peft_config"])
        model = LoraModel(config, unet)
        set_peft_model_state_dict(model, lora_ds)
        return model

    def text_encoder(self, text_encoder):
        # The text encoder receives the "text_encoder_"-prefixed entries,
        # with the prefix stripped.
        lora_ds = {
            k.replace("text_encoder_", ""): v
            for k, v in self.checkpoint.items()
            if "text_encoder_" in k
        }
        config = LoraConfig(**self.config["text_encoder_peft_config"])
        model = LoraModel(config, text_encoder)
        set_peft_model_state_dict(model, lora_ds)
        return model

    def apply(self, pipe, dtype):
        pipe.unet = self.unet(pipe.unet)
        if "text_encoder_peft_config" in self.config:
            pipe.text_encoder = self.text_encoder(pipe.text_encoder)
        if dtype in (torch.float16, torch.bfloat16):
            pipe.unet.half()
            pipe.text_encoder.half()
        return pipe


class PeftManager:
    modules: list[LoraPeftModule]

    def __init__(self):
        self.lora_path = global_lora_models_dir()
        self.modules = []

    def set_loras(self, lora_weights: list):
        if len(lora_weights) > 0:
            for lora in lora_weights:
                self.add(lora.model, lora.weight)

    def add(self, name, multiplier: float = 1.0):
        lora_dir = Path(self.lora_path, name)
        if lora_dir.exists():
            lora_config_file = Path(lora_dir, 'lora_config.json')
            lora_checkpoint = Path(lora_dir, 'lora.pt')
            if lora_config_file.exists() and lora_checkpoint.exists():
                self.modules.append(LoraPeftModule(lora_dir, multiplier))
                return
        print(f">> Failed to load lora {name}")

    def load(self, pipe, dtype):
        if len(self.modules) > 0:
            for module in self.modules:
                pipe = module.apply(pipe, dtype)
        return pipe

    # Simple check to allow the previous (non-PEFT) LoRA functionality:
    # returns True when none of the requested LoRAs is available in PEFT
    # format (a directory holding lora_config.json and lora.pt).
    def should_use(self, lora_weights: list):
        if len(lora_weights) > 0:
            for lora in lora_weights:
                lora_dir = Path(self.lora_path, lora.model)
                if lora_dir.exists():
                    lora_config_file = Path(lora_dir, 'lora_config.json')
                    lora_checkpoint = Path(lora_dir, 'lora.pt')
                    if lora_config_file.exists() and lora_checkpoint.exists():
                        return False
        return True
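
A minimal usage sketch, assuming each LoRA lives in its own directory under
the loras path with a `lora_config.json` (containing a `peft_config` section
and, optionally, a `text_encoder_peft_config` section) and a `lora.pt`
checkpoint. The `LoraWeight` namedtuple and the example pipeline are
stand-ins for whatever objects InvokeAI actually passes around; `set_loras()`
only requires `.model` and `.weight` attributes.

```
from collections import namedtuple

import torch
from diffusers import StableDiffusionPipeline

from ldm.modules.peft_manager import PeftManager

# Stand-in for the parsed withLora() fragments handed to the manager.
LoraWeight = namedtuple("LoraWeight", ["model", "weight"])

# Example pipeline; InvokeAI constructs its own pipelines internally.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

manager = PeftManager()
manager.set_loras([LoraWeight(model="sushi", weight=0.75)])
pipe = manager.load(pipe, torch.float16)  # wraps pipe.unet / pipe.text_encoder with LoraModel
```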