[SD] Add support for a compiled version of the discrete Euler scheduler (#657)

* Add Shark version of euler scheduler

* Add Shark version of euler scheduler to web ui
This commit is contained in:
Quinn Dawkins
2022-12-17 22:25:43 -05:00
committed by GitHub
parent ffef1681e3
commit 2bc6de650d
10 changed files with 345 additions and 28 deletions

View File

@@ -17,6 +17,9 @@ import numpy as np
from stable_args import args
from utils import get_shark_model, set_iree_runtime_flags
from opt_params import get_unet, get_vae, get_clip
from schedulers import (
SharkEulerDiscreteScheduler,
)
import time
import sys
from shark.iree_utils.compile_utils import dump_isas
@@ -78,6 +81,7 @@ if __name__ == "__main__":
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
cpu_scheduling = True
if args.version == "v2.1":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
@@ -93,10 +97,19 @@ if __name__ == "__main__":
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
scheduler = EulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
if args.use_compiled_scheduler:
scheduler = SharkEulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
scheduler.compile()
cpu_scheduling = False
else:
scheduler = EulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
start = time.time()
text_input = tokenizer(
@@ -144,36 +157,42 @@ if __name__ == "__main__":
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
latents_numpy = latent_model_input.detach().numpy()
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet.forward(
(
latents_numpy,
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
)
),
send_to_host=False,
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = scheduler.step(noise_pred, t, latents).prev_sample
else:
latents = scheduler.step(noise_pred, t, latents)
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
print(f" ({step_ms}ms)")
latents = scheduler.step(noise_pred, t, latents).prev_sample
avg_ms = 1000 * avg_ms / args.steps
print(f"Average step time: {avg_ms}ms/it")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# latents = latents.
latents_numpy = latents.detach().numpy()
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
image = vae.forward((latents_numpy,))

View File

@@ -0,0 +1,131 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from utils import compile_through_fx, get_shark_model
from stable_args import args
import torch
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
model_input = {
"euler": {
"latent": torch.randn(1, 4, 64, 64),
"output": torch.randn(1, 4, 64, 64),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    """Euler discrete scheduler whose per-step math runs as compiled
    SHARK/IREE modules instead of eager PyTorch.

    The interface mirrors ``EulerDiscreteScheduler``.  Call
    :meth:`compile` once before sampling; afterwards
    :meth:`scale_model_input` and :meth:`step` dispatch to the compiled
    modules and return device-resident results (``send_to_host=False``),
    so callers must not treat them as host tensors.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
    ):
        # Forward by keyword: positional forwarding silently breaks if
        # diffusers inserts a new parameter between these in a future
        # release.
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            prediction_type=prediction_type,
        )

    def compile(self):
        """Build or fetch the compiled scaling and step modules.

        With ``--import_mlir`` the two small torch modules below are
        traced and compiled through FX; otherwise prebuilt artifacts
        are downloaded from ``SCHEDULER_BUCKET``.  Sets
        ``self.scaling_model`` and ``self.step_model``.
        """
        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            # Euler input scaling: latent / sqrt(sigma^2 + 1).
            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepModel(torch.nn.Module):
            # One explicit Euler step for epsilon prediction.
            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        # Disable bindings fusion to work with moltenVK.
        if sys.platform == "darwin":
            iree_flags.append("-iree-stream-fuse-binding=false")
        if args.import_mlir:
            self.scaling_model = compile_through_fx(
                ScalingModel(),
                (example_latent, example_sigma),
                model_name="euler_scale_model_input_" + args.precision,
                extra_args=iree_flags,
            )
            self.step_model = compile_through_fx(
                SchedulerStepModel(),
                (example_output, example_sigma, example_latent, example_dt),
                model_name="euler_step_" + args.precision,
                extra_args=iree_flags,
            )
        else:
            self.scaling_model = get_shark_model(
                SCHEDULER_BUCKET,
                "euler_scale_model_input_" + args.precision,
                iree_flags,
            )
            self.step_model = get_shark_model(
                SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
            )

    def _sigma_index(self, timestep):
        # Position of `timestep` in self.timesteps; indexes self.sigmas.
        return (self.timesteps == timestep).nonzero().item()

    def scale_model_input(self, sample, timestep):
        """Scale `sample` by the sigma for `timestep` on-device."""
        sigma = self.sigmas[self._sigma_index(timestep)]
        return self.scaling_model.forward(
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(self, noise_pred, timestep, latent):
        """Advance `latent` by one Euler step; result stays on device."""
        step_index = self._sigma_index(timestep)
        sigma = self.sigmas[step_index]
        # NOTE(review): relies on self.sigmas carrying a trailing final
        # sigma so step_index + 1 is valid on the last timestep — this
        # matches diffusers' set_timesteps, but confirm on upgrade.
        dt = self.sigmas[step_index + 1] - sigma
        return self.step_model.forward(
            (
                noise_pred,
                sigma,
                latent,
                dt,
            ),
            send_to_host=False,
        )

View File

@@ -132,6 +132,13 @@ p.add_argument(
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=False,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",

View File

@@ -348,21 +348,31 @@ def export_module_to_mlir_file(module, frontend, directory: str):
return filename
def get_results(compiled_vm, input, config, frontend="torch"):
def get_results(
compiled_vm, input, config, frontend="torch", send_to_host=True
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result_tensors = []
if isinstance(result, tuple):
for val in result:
result_tensors.append(np.copy(np.asarray(val, val.dtype)))
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
res = np.array(data, dtype=object)
return np.copy(res)
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
return result.to_host()
if send_to_host:
return result.to_host()
return result
def get_iree_runtime_config(device):

View File

@@ -138,8 +138,8 @@ class SharkInference:
os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
# inputs are considered to be tuple of np.array.
def forward(self, inputs: tuple):
return self.shark_runner.run(inputs)
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run(inputs, send_to_host)
# Captures the static input information from the mlir_module.
# TODO(pashu123): Generate the input information for dynamic shapes.

View File

@@ -91,10 +91,11 @@ class SharkRunner:
extra_args=self.extra_args,
)
def run(self, inputs: tuple):
def run(self, inputs: tuple, send_to_host=False):
return get_results(
self.iree_compilation_module,
inputs,
self.iree_config,
self.mlir_dialect,
send_to_host,
)

View File

@@ -114,13 +114,14 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row():
scheduler_key = gr.Dropdown(
label="Scheduler",
value="EulerDiscrete",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"SharkEulerDiscrete",
],
)
with gr.Group():

View File

@@ -9,6 +9,9 @@ from diffusers import (
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import set_iree_runtime_flags
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.schedulers import (
SharkEulerDiscreteScheduler,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
@@ -39,6 +42,11 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"] = SharkEulerDiscreteScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
# use tuned unet model in case of rdna3 cards.
if "rdna3" in get_vulkan_triple_flag():

View File

@@ -56,6 +56,7 @@ def stable_diff_inf(
cache_obj["tokenizer"],
)
scheduler = schedulers[scheduler_key]
cpu_scheduling = not scheduler_key.startswith("Shark")
start = time.time()
text_input = tokenizer(
@@ -104,27 +105,35 @@ def stable_diff_inf(
step_start = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latents_model_input = scheduler.scale_model_input(latents, t)
latents_numpy = latents_model_input.detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
noise_pred = unet.forward(
(
latents_numpy,
latent_model_input,
timestep,
text_embeddings_numpy,
args.guidance_scale,
)
),
send_to_host=False,
)
noise_pred = torch.from_numpy(noise_pred)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = scheduler.step(noise_pred, t, latents).prev_sample
else:
latents = scheduler.step(noise_pred, t, latents)
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
print(f" \nIteration = {i}, Time = {step_ms}ms")
latents = scheduler.step(noise_pred, t, latents)["prev_sample"]
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
vae_start = time.time()
image = vae.forward((latents_numpy,))
vae_end = time.time()

View File

@@ -0,0 +1,131 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from models.stable_diffusion.utils import compile_through_fx, get_shark_model
from models.stable_diffusion.stable_args import args
import torch
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
model_input = {
"euler": {
"latent": torch.randn(1, 4, 64, 64),
"output": torch.randn(1, 4, 64, 64),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    """Euler discrete scheduler whose per-step math runs as compiled
    SHARK/IREE modules instead of eager PyTorch (web-UI copy).

    The interface mirrors ``EulerDiscreteScheduler``.  Call
    :meth:`compile` once before sampling; afterwards
    :meth:`scale_model_input` and :meth:`step` dispatch to the compiled
    modules and return device-resident results (``send_to_host=False``),
    so callers must not treat them as host tensors.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
    ):
        # Forward by keyword: positional forwarding silently breaks if
        # diffusers inserts a new parameter between these in a future
        # release.
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            prediction_type=prediction_type,
        )

    def compile(self):
        """Build or fetch the compiled scaling and step modules.

        With ``--import_mlir`` the two small torch modules below are
        traced and compiled through FX; otherwise prebuilt artifacts
        are downloaded from ``SCHEDULER_BUCKET``.  Sets
        ``self.scaling_model`` and ``self.step_model``.
        """
        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            # Euler input scaling: latent / sqrt(sigma^2 + 1).
            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepModel(torch.nn.Module):
            # One explicit Euler step for epsilon prediction.
            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        # Disable bindings fusion to work with moltenVK.
        if sys.platform == "darwin":
            iree_flags.append("-iree-stream-fuse-binding=false")
        if args.import_mlir:
            self.scaling_model = compile_through_fx(
                ScalingModel(),
                (example_latent, example_sigma),
                model_name="euler_scale_model_input_" + args.precision,
                extra_args=iree_flags,
            )
            self.step_model = compile_through_fx(
                SchedulerStepModel(),
                (example_output, example_sigma, example_latent, example_dt),
                model_name="euler_step_" + args.precision,
                extra_args=iree_flags,
            )
        else:
            self.scaling_model = get_shark_model(
                SCHEDULER_BUCKET,
                "euler_scale_model_input_" + args.precision,
                iree_flags,
            )
            self.step_model = get_shark_model(
                SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
            )

    def _sigma_index(self, timestep):
        # Position of `timestep` in self.timesteps; indexes self.sigmas.
        return (self.timesteps == timestep).nonzero().item()

    def scale_model_input(self, sample, timestep):
        """Scale `sample` by the sigma for `timestep` on-device."""
        sigma = self.sigmas[self._sigma_index(timestep)]
        return self.scaling_model.forward(
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(self, noise_pred, timestep, latent):
        """Advance `latent` by one Euler step; result stays on device."""
        step_index = self._sigma_index(timestep)
        sigma = self.sigmas[step_index]
        # NOTE(review): relies on self.sigmas carrying a trailing final
        # sigma so step_index + 1 is valid on the last timestep — this
        # matches diffusers' set_timesteps, but confirm on upgrade.
        dt = self.sigmas[step_index + 1] - sigma
        return self.step_model.forward(
            (
                noise_pred,
                sigma,
                latent,
                dt,
            ),
            send_to_host=False,
        )