Mirror of https://github.com/lllyasviel/ControlNet.git (synced 2026-01-09 14:08:03 -05:00)
Support Stable Diffusion V2
@@ -119,6 +119,8 @@ Do not ask us why we use these three names - this is related to the dark history
Then you need to decide which Stable Diffusion model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file "v1-5-pruned.ckpt".
(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2)
Then you need to attach a ControlNet to the SD model. The architecture is:

@@ -129,6 +131,10 @@ We provide a simple script for you to achieve this easily. If your SD filename i
python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
Or if you are using SD2:
python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
You may also use other filenames, as long as the command follows the form "python tool_add_control.py input_path output_path".
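For example, with hypothetical filenames of your own choosing:

python tool_add_control.py ./models/my_sd15.ckpt ./models/my_control_ini.ckpt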
This is the correct output from my machine:
@@ -177,6 +183,7 @@ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
trainer.fit(model, dataloader)
```
(or "tutorial_train_sd21.py" if you are using SD2)
Thanks to our organized PyTorch dataset object and the power of pytorch_lightning, the entire code is just super short.
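For orientation, here is a minimal sketch of the interface such a dataset object needs to expose; this is not the repo's tutorial_dataset.py, and the dummy arrays are stand-ins for real data. The dict keys match `first_stage_key`, `cond_stage_key`, and `control_key` in the config below:

```python
# A minimal sketch of the expected dataset interface: each item is a dict
# whose keys match first_stage_key/cond_stage_key/control_key in cldm_v21.yaml.
import numpy as np
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __len__(self):
        return 1000  # placeholder size

    def __getitem__(self, idx):
        target = np.zeros((512, 512, 3), dtype=np.float32)  # image, expected in [-1, 1]
        source = np.zeros((512, 512, 3), dtype=np.float32)  # control hint, expected in [0, 1]
        return dict(jpg=target, txt='a placeholder prompt', hint=source)
```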
models/cldm_v21.yaml (new file, 85 lines)
@@ -0,0 +1,85 @@
model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
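For reference, `create_model` turns this YAML into a model by instantiating each `target` class with its `params`. A minimal sketch of that pattern, assuming `omegaconf` is installed (it mirrors `instantiate_from_config` in the ldm codebase):

```python
# Sketch of how a target/params config block becomes an object.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # "cldm.cldm.ControlLDM" -> import cldm.cldm, then construct ControlLDM(**params)
    module_name, cls_name = config['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get('params', dict()))

config = OmegaConf.load('./models/cldm_v21.yaml')
model = instantiate_from_config(config.model)
```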
tool_add_control_sd21.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import sys
import os

assert len(sys.argv) == 3, 'Args are wrong.'

input_path = sys.argv[1]
output_path = sys.argv[2]

assert os.path.exists(input_path), 'Input model does not exist.'
assert not os.path.exists(output_path), 'Output filename already exists.'
assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'

import torch
from share import *
from cldm.model import create_model


def get_node_name(name, parent_name):
    # Return (True, suffix) if `name` starts with `parent_name`, else (False, '').
    if len(name) <= len(parent_name):
        return False, ''
    p = name[:len(parent_name)]
    if p != parent_name:
        return False, ''
    return True, name[len(parent_name):]


model = create_model(config_path='./models/cldm_v21.yaml')

pretrained_weights = torch.load(input_path)
if 'state_dict' in pretrained_weights:
    pretrained_weights = pretrained_weights['state_dict']

scratch_dict = model.state_dict()

# Initialize every parameter of the ControlLDM: control-branch weights are
# copied from the matching SD U-Net weights, everything else is copied
# directly from the pretrained checkpoint, and any key with no match in the
# checkpoint (e.g. the new hint layers) keeps its fresh initialization.
target_dict = {}
for k in scratch_dict.keys():
    is_control, name = get_node_name(k, 'control_')
    if is_control:
        copy_k = 'model.diffusion_' + name
    else:
        copy_k = k
    if copy_k in pretrained_weights:
        target_dict[k] = pretrained_weights[copy_k].clone()
    else:
        target_dict[k] = scratch_dict[k].clone()
        print(f'These weights are newly added: {k}')

model.load_state_dict(target_dict, strict=True)
torch.save(model.state_dict(), output_path)
print('Done.')
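A quick way to sanity-check the result is to load the saved state dict and confirm the control branch is present (a sketch, assuming the output path used above):

```python
# The saved file is a plain state dict (see torch.save above).
import torch

sd = torch.load('./models/control_sd21_ini.ckpt', map_location='cpu')
control_keys = [k for k in sd if k.startswith('control_model.')]
print(f'{len(control_keys)} tensors in the control branch')
```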
tutorial_train_sd21.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd21_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First load the model on CPU. PyTorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)
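If you have more GPUs or want mixed precision, only the `pl.Trainer` line needs to change. A hypothetical variant, assuming the pytorch_lightning 1.x API this script uses (where `gpus` and an integer `precision` are valid arguments):

```python
# Hypothetical alternative to the Trainer line above: 2 GPUs, fp16.
trainer = pl.Trainer(gpus=2, precision=16, callbacks=[logger])
```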