diff --git a/docs/train.md b/docs/train.md
index 0fa155b..e1c16f3 100644
--- a/docs/train.md
+++ b/docs/train.md
@@ -119,6 +119,8 @@ Do not ask us why we use these three names - this is related to the dark history
 
 Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file "v1-5-pruned.ckpt".
 
+(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2)
+
 Then you need to attach a control net to the SD model. The architecture is
 
 ![img](../github_page/sd.png)
@@ -129,6 +131,10 @@ We provide a simple script for you to achieve this easily. If your SD filename i
 
     python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
 
+Or if you are using SD2:
+
+    python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
+
 You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".
 
 This is the correct output from my machine:
@@ -177,6 +183,7 @@ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
 trainer.fit(model, dataloader)
 ```
 
+(or "tutorial_train_sd21.py" if you are using SD2)
 
 Thanks to our organized dataset pytorch object and the power of pytorch_lightning, the entire code is just super short.
 
diff --git a/models/cldm_v21.yaml b/models/cldm_v21.yaml
new file mode 100644
index 0000000..fc65193
--- /dev/null
+++ b/models/cldm_v21.yaml
@@ -0,0 +1,85 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    control_key: "hint"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        hint_channels: 3
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
diff --git a/tool_add_control_sd21.py b/tool_add_control_sd21.py
new file mode 100644
index 0000000..7c3ac5f
--- /dev/null
+++ b/tool_add_control_sd21.py
@@ -0,0 +1,50 @@
+import sys
+import os
+
+assert len(sys.argv) == 3, 'Args are wrong.'
+
+input_path = sys.argv[1]
+output_path = sys.argv[2]
+
+assert os.path.exists(input_path), 'Input model does not exist.'
+assert not os.path.exists(output_path), 'Output filename already exists.'
+assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
+
+import torch
+from share import *
+from cldm.model import create_model
+
+
+def get_node_name(name, parent_name):
+    if len(name) <= len(parent_name):
+        return False, ''
+    p = name[:len(parent_name)]
+    if p != parent_name:
+        return False, ''
+    return True, name[len(parent_name):]
+
+
+model = create_model(config_path='./models/cldm_v21.yaml')
+
+pretrained_weights = torch.load(input_path)
+if 'state_dict' in pretrained_weights:
+    pretrained_weights = pretrained_weights['state_dict']
+
+scratch_dict = model.state_dict()
+
+target_dict = {}
+for k in scratch_dict.keys():
+    is_control, name = get_node_name(k, 'control_')
+    if is_control:
+        copy_k = 'model.diffusion_' + name
+    else:
+        copy_k = k
+    if copy_k in pretrained_weights:
+        target_dict[k] = pretrained_weights[copy_k].clone()
+    else:
+        target_dict[k] = scratch_dict[k].clone()
+        print(f'These weights are newly added: {k}')
+
+model.load_state_dict(target_dict, strict=True)
+torch.save(model.state_dict(), output_path)
+print('Done.')
diff --git a/tutorial_train_sd21.py b/tutorial_train_sd21.py
new file mode 100644
index 0000000..8bbc148
--- /dev/null
+++ b/tutorial_train_sd21.py
@@ -0,0 +1,35 @@
+from share import *
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+from tutorial_dataset import MyDataset
+from cldm.logger import ImageLogger
+from cldm.model import create_model, load_state_dict
+
+
+# Configs
+resume_path = './models/control_sd21_ini.ckpt'
+batch_size = 4
+logger_freq = 300
+learning_rate = 1e-5
+sd_locked = True
+only_mid_control = False
+
+
+# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
+model = create_model('./models/cldm_v21.yaml').cpu()
+model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+model.learning_rate = learning_rate
+model.sd_locked = sd_locked
+model.only_mid_control = only_mid_control
+
+
+# Misc
+dataset = MyDataset()
+dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
+logger = ImageLogger(batch_frequency=logger_freq)
+trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
+
+
+# Train!
+trainer.fit(model, dataloader)
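A note on the weight-copying loop in `tool_add_control_sd21.py`: it relies on a naming convention in the combined state dict, where the trainable ControlNet branch lives under `control_model.*` while the frozen SD U-Net lives under `model.diffusion_model.*`. Stripping the `control_` prefix and prepending `model.diffusion_` therefore points each ControlNet tensor at its pretrained counterpart; tensors with no counterpart (the hint encoder and the zero convolutions) keep their scratch initialization and are reported as newly added. Below is a minimal, self-contained sketch of that mapping; the example keys are illustrative, not an exhaustive list of real parameter names.

```python
# Sketch of the key mapping used by tool_add_control_sd21.py.
# The keys in `examples` are illustrative.

def get_node_name(name, parent_name):
    # Returns (True, remainder) if `name` starts with `parent_name`.
    if len(name) <= len(parent_name):
        return False, ''
    if name[:len(parent_name)] != parent_name:
        return False, ''
    return True, name[len(parent_name):]

examples = [
    'control_model.input_blocks.1.0.in_layers.0.weight',  # trainable copy
    'control_model.zero_convs.0.0.weight',                # no pretrained twin
    'model.diffusion_model.out.2.bias',                   # frozen U-Net
]
for k in examples:
    is_control, name = get_node_name(k, 'control_')
    copy_k = 'model.diffusion_' + name if is_control else k
    print(f'{k}\n  <- initialized from {copy_k} (if present in SD weights)')
```

For example, `control_model.input_blocks.1.0.in_layers.0.weight` becomes `model.diffusion_model.input_blocks.1.0.in_layers.0.weight`, which exists in the pretrained SD2.1 checkpoint, while the zero-conv key maps to a name that does not exist, so it stays zero-initialized.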
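After running the conversion, one way to confirm the result is to check that every ControlNet tensor that has a pretrained counterpart is identical to it. This check is not part of the repository; it is a sketch assuming the output path from the command above.

```python
# Sanity-check sketch (not part of the repo): verify that the ControlNet
# branch of the converted checkpoint matches the copied U-Net weights.
import torch

sd = torch.load('./models/control_sd21_ini.ckpt', map_location='cpu')
if 'state_dict' in sd:
    sd = sd['state_dict']

checked = 0
for k, v in sd.items():
    if k.startswith('control_model.'):
        twin = 'model.diffusion_' + k[len('control_'):]
        if twin in sd:  # zero convs / hint encoder have no twin
            assert torch.equal(v, sd[twin]), f'Mismatch at {k}'
            checked += 1
print(f'{checked} ControlNet tensors match their pretrained counterparts.')
```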