# Install the required libs
# pip install -U git+https://github.com/huggingface/diffusers.git
# pip install accelerate transformers ftfy
# HuggingFace Token
|
|
# YOUR_TOKEN = "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"
|
|
|
|
|
|
# Import required libraries
import itertools
import math
import os
from typing import List
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
import logging

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
    StableDiffusionSafetyChecker,
)
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference

torch._dynamo.config.verbose = True

from io import BytesIO

from dataclasses import dataclass
from apps.stable_diffusion.src import (
    args,
    get_schedulers,
    set_init_device_flags,
    clear_all,
)
from apps.stable_diffusion.src.utils import update_lora_weight
# Setup the dataset
class LoraDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        size=512,
        repeats=100,
        interpolation="bicubic",
        set="train",
        prompt="myloraprompt",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.prompt = prompt

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL.Image.LINEAR,  # legacy alias; removed in Pillow 10
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if image.mode != "RGB":
            image = image.convert("RGB")

        example["input_ids"] = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        image = image.resize(
            (self.size, self.size), resample=self.interpolation
        )

        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
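# A minimal usage sketch (hypothetical paths and prompt, assuming a folder of
# RGB training images):
#   tok = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
#   ds = LoraDataset(data_root="./my_images", tokenizer=tok, prompt="sks style")
#   sample = ds[0]
#   sample["input_ids"]     # LongTensor of shape [model_max_length]
#   sample["pixel_values"]  # FloatTensor of shape [3, size, size] in [-1, 1]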
def torch_device(device):
    device_tokens = device.split("=>")
    if len(device_tokens) == 1:
        device_str = device_tokens[0].strip()
    else:
        device_str = device_tokens[1].strip()
    device_type_tokens = device_str.split("://")
    if device_type_tokens[0] == "metal":
        device_type_tokens[0] = "vulkan"
    if len(device_type_tokens) > 1:
        return device_type_tokens[0] + ":" + device_type_tokens[1]
    else:
        return device_type_tokens[0]
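# Illustrative inputs/outputs (device strings here are examples, not an
# exhaustive list of what SHARK accepts):
#   torch_device("cpu")                       -> "cpu"
#   torch_device("metal://0")                 -> "vulkan:0"  (metal maps to vulkan)
#   torch_device("AMD Radeon => vulkan://1")  -> "vulkan:1"  (text before "=>" is dropped)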
########## Setting up the model ##########
def lora_train(
    prompt: str,
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    training_images_dir: str,
    lora_save_dir: str,
    use_lora: str,
):
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,
        Config,
    )
    import apps.stable_diffusion.web.utils.global_obj as global_obj
    print(
        "Note: LoRA training is not compatible with the latest torch-mlir branch."
    )
    print(
        "To run LoRA training you'll need to follow this guide for the required "
        "torch-mlir branch: "
        "https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
    )
    torch.manual_seed(seed)

    args.prompts = [prompt]
    args.steps = steps

    # set ckpt_loc and hf_model_id.
    types = (
        ".ckpt",
        ".safetensors",
    )  # the tuple of file types
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either a custom model or a HuggingFace model "
                "ID; both must not be empty.",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model
    args.training_images_dir = training_images_dir
    args.lora_save_dir = lora_save_dir

    args.precision = precision
    args.batch_size = batch_size
    args.max_length = max_length
    args.height = height
    args.width = width
    args.device = torch_device(device)
    args.use_lora = use_lora

    # Load the Stable Diffusion model
    text_encoder = CLIPTextModel.from_pretrained(
        args.hf_model_id, subfolder="text_encoder"
    )
    vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(
        args.hf_model_id, subfolder="unet"
    )

    def freeze_params(params):
        for param in params:
            param.requires_grad = False

    # Freeze everything but LoRA
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    freeze_params(text_encoder.parameters())

    # Move the vae, unet and text encoder to the device
    vae.to(args.device)
    unet.to(args.device)
    text_encoder.to(args.device)
    if use_lora != "":
        update_lora_weight(unet, args.use_lora, "unet")
    else:
        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAXFormersAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
            )

        unet.set_attn_processor(lora_attn_procs)

    # Collect the trainable LoRA parameters; built outside the branch above so
    # the optimizer below also works when pre-trained LoRA weights were loaded.
    lora_layers = AttnProcsLayers(unet.attn_processors)
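    # Each LoRA attention processor augments the frozen attention projections
    # with a trainable low-rank update (roughly W @ x + B @ A @ x, with A and
    # B low-rank), so only these small A/B matrices receive gradients below.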
    class VaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.vae = vae

        def forward(self, input):
            x = self.vae.encode(input, return_dict=False)[0]
            return x

    class UnetModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.unet = unet

        def forward(self, x, y, z):
            return self.unet.forward(x, y, z, return_dict=False)[0]

    shark_vae = VaeModel()
    shark_unet = UnetModel()
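    # These thin wrappers pin return_dict=False so both modules return plain
    # tensors rather than output dataclasses, which keeps the FX graphs that
    # TorchDynamo traces below simple to lower through torch-mlir.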
    ####### Creating our training data ########

    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_id,
        subfolder="tokenizer",
    )

    # Let's create the Dataset and Dataloader
    train_dataset = LoraDataset(
        data_root=args.training_images_dir,
        tokenizer=tokenizer,
        size=vae.sample_size,
        prompt=args.prompts[0],
        repeats=100,
        center_crop=False,
        set="train",
    )

    def create_dataloader(train_batch_size=1):
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=train_batch_size, shuffle=True
        )

    # Create noise_scheduler for training
    noise_scheduler = DDPMScheduler.from_config(
        args.hf_model_id, subfolder="scheduler"
    )
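    # Note: this DDPMScheduler is used only for its training utilities
    # (add_noise and, for v-prediction models, get_velocity) inside train_func;
    # it is unrelated to the inference schedulers selected via get_schedulers.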
    ######## Training ###########

    # Define the hyperparameters for training. If you are not happy with your
    # results, you can tune `learning_rate` and `max_train_steps`.

    # Setting up all training args
    hyperparameters = {
        "learning_rate": 5e-04,
        "scale_lr": True,
        "max_train_steps": steps,
        "train_batch_size": batch_size,
        "gradient_accumulation_steps": 1,
        "gradient_checkpointing": True,
        "mixed_precision": "fp16",
        "seed": 42,
        "output_dir": "sd-concept-output",
    }
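    # With "scale_lr" enabled, the effective learning rate used below is
    # learning_rate * gradient_accumulation_steps * train_batch_size,
    # e.g. 5e-4 * 1 * 2 = 1e-3 for a batch size of 2.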
    # Create the output directory (fail loudly instead of retrying forever if
    # the directory cannot be created).
    cwd = os.getcwd()
    out_dir = os.path.join(cwd, hyperparameters["output_dir"])
    try:
        os.makedirs(out_dir, exist_ok=True)
    except OSError as error:
        print(f"Output directory {out_dir} not created: {error}")
    ###### Torch-MLIR Compilation ######

    def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
        removed_indexes = []
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, (list, tuple)):
                    node_arg = list(node_arg)
                    node_args_len = len(node_arg)
                    for i in range(node_args_len):
                        curr_index = node_args_len - (i + 1)
                        if node_arg[curr_index] is None:
                            removed_indexes.append(curr_index)
                            node_arg.pop(curr_index)
                    node.args = (tuple(node_arg),)
                break

        if len(removed_indexes) > 0:
            fx_g.graph.lint()
            fx_g.graph.eliminate_dead_code()
            fx_g.recompile()
        removed_indexes.sort()
        return removed_indexes
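    # For example, an FX output node carrying (t0, None, t1) is rewritten to
    # (t0, t1) and [1] is returned, so the compiled callable can re-insert
    # None at index 1 to reproduce the original return structure.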
    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
        """
        Replace tuple with tuple element in functions that return one-element tuples.
        Returns true if an unwrapping took place, and false otherwise.
        """
        unwrapped_tuple = False
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    if len(node_arg) == 1:
                        node.args = (node_arg[0],)
                        unwrapped_tuple = True
                        break

        if unwrapped_tuple:
            fx_g.graph.lint()
            fx_g.recompile()
        return unwrapped_tuple
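    # For example, an output of (t0,) becomes t0; the compiled callable below
    # re-wraps the result in a list when this unwrapping happened.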
    def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    return len(node_arg) == 0
        return False
    def transform_fx(fx_g):
        for node in fx_g.graph.nodes:
            if node.op == "call_function":
                # aten.empty should be filled with zeros.
                if node.target in [torch.ops.aten.empty]:
                    with fx_g.graph.inserting_after(node):
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten.zero_,
                            args=(node,),
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        new_node.args = (node,)

        fx_g.graph.lint()
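    # Rationale: aten.empty yields uninitialized memory, which a compiled
    # module cannot reproduce deterministically, so every empty tensor is
    # chased with an in-place zero_ and all of its users are rewired to the
    # zeroed value.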
    @make_simple_dynamo_backend
    def refbackend_torchdynamo_backend(
        fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ):
        # handling usage of empty tensor without initializing
        transform_fx(fx_graph)
        fx_graph.recompile()
        if _returns_nothing(fx_graph):
            return fx_graph
        removed_none_indexes = _remove_nones(fx_graph)
        was_unwrapped = _unwrap_single_tuple_return(fx_graph)

        mlir_module = torch_mlir.compile(
            fx_graph, example_inputs, output_type="linalg-on-tensors"
        )

        bytecode_stream = BytesIO()
        mlir_module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

        shark_module = SharkInference(
            mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
        )
        shark_module.compile()

        def compiled_callable(*inputs):
            inputs = [x.numpy() for x in inputs]
            result = shark_module("forward", inputs)
            if was_unwrapped:
                result = [
                    result,
                ]
            if not isinstance(result, list):
                result = torch.from_numpy(result)
            else:
                result = tuple(torch.from_numpy(x) for x in result)
                result = list(result)
                for removed_index in removed_none_indexes:
                    result.insert(removed_index, None)
                result = tuple(result)
            return result

        return compiled_callable
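    # Backend flow, in short: TorchDynamo hands over an FX graph; its outputs
    # are normalized (empties zero-filled, Nones dropped, one-element tuples
    # unwrapped), the graph is lowered to linalg-on-tensors MLIR, compiled via
    # SharkInference, and wrapped in a callable that round-trips tensors
    # through numpy and restores the original output structure.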
    def predictions(torch_func, jit_func, batchA, batchB):
        res = jit_func(batchA.numpy(), batchB.numpy())
        if res is not None:
            # prediction = torch.from_numpy(res)
            prediction = res
        else:
            prediction = None
        return prediction
    logger = logging.getLogger(__name__)

    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters[
        "gradient_accumulation_steps"
    ]
    learning_rate = hyperparameters["learning_rate"]
    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate
            * gradient_accumulation_steps
            * train_batch_size
            # * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        lora_layers.parameters(),  # only optimize the LoRA layers
        lr=learning_rate,
    )
    # Training function
    def train_func(batch_pixel_values, batch_input_ids):
        # Convert images to latent space
        latents = shark_vae(batch_pixel_values).sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps,
            (bsz,),
            device=latents.device,
        ).long()

        # Add noise to the latents according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch_input_ids)[0]

        # Predict the noise residual
        noise_pred = shark_unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        )

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {noise_scheduler.config.prediction_type}"
            )

        loss = (
            F.mse_loss(noise_pred, target, reduction="none")
            .mean([1, 2, 3])
            .mean()
        )
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        return loss
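    # Each call above is one denoising-score-matching step: minimize
    # E[ || model(z_t, t, c) - target ||^2 ], where z_t is the noised latent,
    # c the prompt embedding, and target is either the injected noise
    # ("epsilon") or the scheduler velocity ("v_prediction").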
    def training_function():
        max_train_steps = hyperparameters["max_train_steps"]
        output_dir = hyperparameters["output_dir"]
        gradient_checkpointing = hyperparameters["gradient_checkpointing"]

        train_dataloader = create_dataloader(train_batch_size)

        # We need to recalculate our total training steps as the size of the
        # training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / gradient_accumulation_steps
        )
        num_train_epochs = math.ceil(
            max_train_steps / num_update_steps_per_epoch
        )

        # Train!
        total_batch_size = (
            train_batch_size
            * gradient_accumulation_steps
            # train_batch_size * accelerator.num_processes * gradient_accumulation_steps
        )

        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(
            f" Instantaneous batch size per device = {train_batch_size}"
        )
        logger.info(
            f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
        )
        logger.info(
            f" Gradient Accumulation steps = {gradient_accumulation_steps}"
        )
        logger.info(f" Total optimization steps = {max_train_steps}")
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            # range(max_train_steps), disable=not accelerator.is_local_main_process
            range(max_train_steps)
        )
        progress_bar.set_description("Steps")
        global_step = 0

        params__ = [
            i for i in text_encoder.get_input_embeddings().parameters()
        ]
        for epoch in range(num_train_epochs):
            unet.train()
            for step, batch in enumerate(train_dataloader):
                dynamo_callable = dynamo.optimize(
                    refbackend_torchdynamo_backend
                )(train_func)
                lam_func = lambda x, y: dynamo_callable(
                    torch.from_numpy(x), torch.from_numpy(y)
                )
                loss = predictions(
                    train_func,
                    lam_func,
                    batch["pixel_values"],
                    batch["input_ids"],
                )

                # Every step runs exactly one optimization step (there is no
                # gradient accumulation wrapper here), so advance the counters
                # directly.
                progress_bar.update(1)
                global_step += 1

                logs = {"loss": loss.detach().item()}
                progress_bar.set_postfix(**logs)

                if global_step >= max_train_steps:
                    break

    training_function()
    # Save the LoRA weights
    unet.save_attn_procs(args.lora_save_dir)

    for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
        if param.grad is not None:
            del param.grad  # free some memory
    torch.cuda.empty_cache()
if __name__ == "__main__":
    if args.clear_all:
        clear_all()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    if len(args.prompts) != 1:
        print("Need exactly one prompt for the LoRA word")
    lora_train(
        args.prompts[0],
        args.height,
        args.width,
        args.training_steps,
        args.guidance_scale,
        args.seed,
        args.batch_count,
        args.batch_size,
        args.scheduler,
        "None",
        args.hf_model_id,
        args.precision,
        args.device,
        args.max_length,
        args.training_images_dir,
        args.lora_save_dir,
        args.use_lora,
    )