Support Stable Diffusion V2

lvmin
2023-02-15 00:18:05 -08:00
parent a4eea6a59e
commit 2a4424c1c0
4 changed files with 177 additions and 0 deletions


@@ -119,6 +119,8 @@ Do not ask us why we use these three names - this is related to the dark history
Then you need to decide which Stable Diffusion model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official Hugging Face model page](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file "v1-5-pruned.ckpt".
(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2)
Then you need to attach a control net to the SD model. The architecture is
![img](../github_page/sd.png)
@@ -129,6 +131,10 @@ We provide a simple script for you to achieve this easily. If your SD filename i
python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
Or if you are using SD2:
python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".
This is the correct output from my machine:
@@ -177,6 +183,7 @@ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
trainer.fit(model, dataloader)
```
(or "tutorial_train_sd21.py" if you are using SD2)
Thanks to our organized PyTorch Dataset object and the power of pytorch_lightning, the entire code is just super short.
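
For context (this dataset file is not part of this commit): the dataset object used by the training scripts only has to return dicts whose keys match the new YAML config below (`first_stage_key: "jpg"`, `cond_stage_key: "txt"`, `control_key: "hint"`). Here is a minimal sketch of that interface; the folder layout and the line-per-sample `prompt.json` format are illustrative assumptions, not the repo's exact tutorial dataset:

```python
# Sketch of the Dataset interface expected by tutorial_train_sd21.py.
# Paths and the prompt.json format are assumptions for illustration.
import json
import cv2
import numpy as np
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, root='./training/my_dataset/'):
        self.root = root
        with open(root + 'prompt.json', 'rt') as f:
            self.items = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        # OpenCV loads BGR; convert to RGB before normalizing.
        source = cv2.cvtColor(cv2.imread(self.root + item['source']), cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(cv2.imread(self.root + item['target']), cv2.COLOR_BGR2RGB)
        source = source.astype(np.float32) / 255.0          # control image ("hint") in [0, 1]
        target = (target.astype(np.float32) / 127.5) - 1.0  # target image ("jpg") in [-1, 1]
        return dict(jpg=target, txt=item['prompt'], hint=source)
```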

models/cldm_v21.yaml (new file, 85 lines)

@@ -0,0 +1,85 @@
model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
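
A quick note on how this config is consumed (a sketch, not part of the diff): `create_model` in `cldm/model.py` essentially loads the YAML with OmegaConf and instantiates the `target` class with `params` as keyword arguments, and the nested `*_config` blocks are instantiated the same way inside `ControlLDM`. The SD2-specific pieces are the OpenCLIP text encoder (`FrozenOpenCLIPEmbedder` with the penultimate layer), `context_dim: 1024`, and `use_linear_in_transformer: True`; the rest mirrors the SD1.5 config.

```python
# Sketch of how a "target:/params:" config like the one above is typically
# turned into a model in the ldm/cldm codebase; create_model() in
# cldm/model.py does essentially this.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load('./models/cldm_v21.yaml')

# instantiate_from_config() imports config.model.target (cldm.cldm.ControlLDM)
# and calls it with config.model.params as keyword arguments.
model = instantiate_from_config(config.model).cpu()
print(type(model))
```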

tool_add_control_sd21.py (new file, 50 lines)

@@ -0,0 +1,50 @@
import sys
import os

assert len(sys.argv) == 3, 'Args are wrong.'

input_path = sys.argv[1]
output_path = sys.argv[2]

assert os.path.exists(input_path), 'Input model does not exist.'
assert not os.path.exists(output_path), 'Output filename already exists.'
assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'

import torch
from share import *
from cldm.model import create_model


def get_node_name(name, parent_name):
    # Return (True, remainder) if `name` starts with `parent_name`, else (False, '').
    if len(name) <= len(parent_name):
        return False, ''
    p = name[:len(parent_name)]
    if p != parent_name:
        return False, ''
    return True, name[len(parent_name):]


# Build a fresh ControlNet-augmented SD2 model from the new config.
model = create_model(config_path='./models/cldm_v21.yaml')

pretrained_weights = torch.load(input_path)
if 'state_dict' in pretrained_weights:
    pretrained_weights = pretrained_weights['state_dict']

scratch_dict = model.state_dict()

target_dict = {}
for k in scratch_dict.keys():
    is_control, name = get_node_name(k, 'control_')
    if is_control:
        # ControlNet keys ('control_model.*') are initialized from the matching
        # UNet weights ('model.diffusion_model.*') of the pretrained checkpoint.
        copy_k = 'model.diffusion_' + name
    else:
        copy_k = k
    if copy_k in pretrained_weights:
        target_dict[k] = pretrained_weights[copy_k].clone()
    else:
        # Layers with no SD2 counterpart (hint encoder, zero convs) keep the
        # initialization from the scratch model.
        target_dict[k] = scratch_dict[k].clone()
        print(f'These weights are newly added: {k}')

model.load_state_dict(target_dict, strict=True)
torch.save(model.state_dict(), output_path)
print('Done.')
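
To make the key remapping above concrete, here is a tiny standalone illustration (the state-dict key names are typical examples, not copied from a real checkpoint): every `control_model.*` key is looked up under the pretrained `model.diffusion_model.*` name, so the ControlNet branch starts as a copy of the SD2 UNet, while ControlNet-only layers keep their fresh initialization.

```python
# Standalone illustration of the renaming done by tool_add_control_sd21.py.
# Example key names are assumptions modeled on the usual SD state-dict layout.
def get_node_name(name, parent_name):  # redefined so this snippet runs on its own
    if len(name) <= len(parent_name) or name[:len(parent_name)] != parent_name:
        return False, ''
    return True, name[len(parent_name):]


examples = [
    'control_model.input_blocks.1.0.in_layers.0.weight',  # ControlNet copy of a UNet encoder layer
    'model.diffusion_model.out.2.weight',                  # plain UNet key, copied by its own name
    'control_model.zero_convs.0.0.weight',                 # ControlNet-only layer, no SD counterpart
]
for k in examples:
    is_control, name = get_node_name(k, 'control_')
    copy_k = 'model.diffusion_' + name if is_control else k
    print(f'{k}  <-  {copy_k}')
```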

tutorial_train_sd21.py (new file, 35 lines)

@@ -0,0 +1,35 @@
from share import *
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd21_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)
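
A couple of optional tweaks to the Trainer above for smaller GPUs, as a hedged sketch (assuming the pytorch_lightning 1.x API this repo uses; `model`, `dataloader`, and `logger` come from the script above):

```python
# Optional variants of the Trainer above, assuming pytorch_lightning 1.x.
# Reuses model, dataloader and logger defined in tutorial_train_sd21.py.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Periodically write .ckpt files so long runs can be resumed.
checkpoint_cb = ModelCheckpoint(every_n_train_steps=1000)

trainer = pl.Trainer(
    gpus=1,
    precision=16,                # fp16 saves VRAM; the tutorial uses 32
    accumulate_grad_batches=4,   # emulate a 4x larger batch when memory is tight
    callbacks=[logger, checkpoint_cb],
)
trainer.fit(model, dataloader)
```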