[SD] Add support for a compiled version of the discrete Euler scheduler (#657)
* Add Shark version of Euler scheduler
* Add Shark version of Euler scheduler to web ui
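In short: when --use_compiled_scheduler is set, the pipeline swaps the diffusers EulerDiscreteScheduler for a SHARK-compiled subclass whose scale_model_input and step run as compiled IREE modules, and UNet outputs stay on the device between steps. A minimal sketch of the resulting flow, using only names that appear in this patch:

    from schedulers import SharkEulerDiscreteScheduler

    scheduler = SharkEulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler"
    )
    scheduler.compile()  # builds the compiled scale/step modules
    # In the denoising loop, noise_pred stays an IREE device array
    # (send_to_host=False) and is fed directly into scheduler.step(...).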
@@ -17,6 +17,9 @@ import numpy as np
 from stable_args import args
 from utils import get_shark_model, set_iree_runtime_flags
 from opt_params import get_unet, get_vae, get_clip
+from schedulers import (
+    SharkEulerDiscreteScheduler,
+)
 import time
 import sys
 from shark.iree_utils.compile_utils import dump_isas
@@ -78,6 +81,7 @@ if __name__ == "__main__":
         "CompVis/stable-diffusion-v1-4",
         subfolder="scheduler",
     )
+    cpu_scheduling = True
     if args.version == "v2.1":
         tokenizer = CLIPTokenizer.from_pretrained(
             "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
@@ -93,10 +97,19 @@ if __name__ == "__main__":
             "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
         )

-    scheduler = EulerDiscreteScheduler.from_pretrained(
-        "stabilityai/stable-diffusion-2-1-base",
-        subfolder="scheduler",
-    )
+    if args.use_compiled_scheduler:
+        scheduler = SharkEulerDiscreteScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-2-1-base",
+            subfolder="scheduler",
+        )
+        scheduler.compile()
+        cpu_scheduling = False
+    else:
+        scheduler = EulerDiscreteScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-2-1-base",
+            subfolder="scheduler",
+        )

     start = time.time()

     text_input = tokenizer(
@@ -144,36 +157,42 @@ if __name__ == "__main__":
         print(f"i = {i} t = {t}", end="")
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
         latent_model_input = scheduler.scale_model_input(latents, t)
-        latents_numpy = latent_model_input.detach().numpy()
+        if cpu_scheduling:
+            latent_model_input = latent_model_input.detach().numpy()

         profile_device = start_profiling(file_path="unet.rdc")

         noise_pred = unet.forward(
             (
-                latents_numpy,
+                latent_model_input,
                 timestep,
                 text_embeddings_numpy,
                 guidance_scale,
-            )
+            ),
+            send_to_host=False,
         )

         end_profiling(profile_device)

-        noise_pred = torch.from_numpy(noise_pred)
+        if cpu_scheduling:
+            noise_pred = torch.from_numpy(noise_pred.to_host())
+            latents = scheduler.step(noise_pred, t, latents).prev_sample
+        else:
+            latents = scheduler.step(noise_pred, t, latents)
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
         print(f" ({step_ms}ms)")

-        latents = scheduler.step(noise_pred, t, latents).prev_sample

     avg_ms = 1000 * avg_ms / args.steps
     print(f"Average step time: {avg_ms}ms/it")

     # scale and decode the image latents with vae
     latents = 1 / 0.18215 * latents
-    # latents = latents.
-    latents_numpy = latents.detach().numpy()
+    latents_numpy = latents
+    if cpu_scheduling:
+        latents_numpy = latents.detach().numpy()
     profile_device = start_profiling(file_path="vae.rdc")
     vae_start = time.time()
     image = vae.forward((latents_numpy,))
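The point of send_to_host=False above: the IREE runtime hands back device arrays, and copying one to the host every step forces a synchronization. A small illustration of the two forms (the driver name is an assumption for the sketch; asdevicearray and to_host are the same runtime calls this patch already uses):

    import numpy as np
    import iree.runtime as ireert

    config = ireert.Config("local-task")  # driver name is illustrative
    dev = ireert.asdevicearray(config.device, np.ones((1, 4, 64, 64), np.float32))
    host = dev.to_host()  # explicit device-to-host copy, done once at the end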
shark/examples/shark_inference/stable_diffusion/schedulers.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from utils import compile_through_fx, get_shark_model
from stable_args import args
import torch

SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"

model_input = {
    "euler": {
        "latent": torch.randn(1, 4, 64, 64),
        "output": torch.randn(1, 4, 64, 64),
        "sigma": torch.tensor(1).to(torch.float32),
        "dt": torch.tensor(1).to(torch.float32),
    },
}


class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
    ):
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
        )

    def compile(self):
        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        # Disable bindings fusion to work with moltenVK.
        if sys.platform == "darwin":
            iree_flags.append("-iree-stream-fuse-binding=false")

        if args.import_mlir:
            scaling_model = ScalingModel()
            self.scaling_model = compile_through_fx(
                scaling_model,
                (example_latent, example_sigma),
                model_name="euler_scale_model_input_" + args.precision,
                extra_args=iree_flags,
            )

            step_model = SchedulerStepModel()
            self.step_model = compile_through_fx(
                step_model,
                (example_output, example_sigma, example_latent, example_dt),
                model_name="euler_step_" + args.precision,
                extra_args=iree_flags,
            )
        else:
            self.scaling_model = get_shark_model(
                SCHEDULER_BUCKET,
                "euler_scale_model_input_" + args.precision,
                iree_flags,
            )
            self.step_model = get_shark_model(
                SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
            )

    def scale_model_input(self, sample, timestep):
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        return self.scaling_model.forward(
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(self, noise_pred, timestep, latent):
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        dt = self.sigmas[step_index + 1] - sigma
        return self.step_model.forward(
            (
                noise_pred,
                sigma,
                latent,
                dt,
            ),
            send_to_host=False,
        )
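A note on the math: with epsilon prediction, SchedulerStepModel's derivative term simplifies, since (latent - (latent - sigma * noise_pred)) / sigma == noise_pred, so the update is just latent + noise_pred * dt, an explicit Euler step in sigma space. A quick eager-mode check (plain torch, no compilation):

    import torch

    latent = torch.randn(1, 4, 64, 64)
    noise_pred = torch.randn(1, 4, 64, 64)
    sigma = torch.tensor(0.7)
    dt = torch.tensor(-0.1)  # sigmas decrease over the schedule, so dt is negative

    pred_original_sample = latent - sigma * noise_pred
    derivative = (latent - pred_original_sample) / sigma  # equals noise_pred
    assert torch.allclose(latent + derivative * dt, latent + noise_pred * dt)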
@@ -132,6 +132,13 @@ p.add_argument(
 ### Misc. Debug and Optimization flags
 ##############################################################################

+p.add_argument(
+    "--use_compiled_scheduler",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="use the default scheduler precompiled into the model if available",
+)
+
 p.add_argument(
     "--local_tank_cache",
     default="",
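Because the flag uses argparse.BooleanOptionalAction, argparse also generates a --no-use_compiled_scheduler negation automatically. A hypothetical invocation (the entry-point script name is an assumption; the flags are from this repo):

    python main.py --use_compiled_scheduler --precision=fp16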
@@ -348,21 +348,31 @@ def export_module_to_mlir_file(module, frontend, directory: str):
     return filename


-def get_results(compiled_vm, input, config, frontend="torch"):
+def get_results(
+    compiled_vm, input, config, frontend="torch", send_to_host=True
+):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
     result = compiled_vm(*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
-        for val in result:
-            result_tensors.append(np.copy(np.asarray(val, val.dtype)))
+        if send_to_host:
+            for val in result:
+                result_tensors.append(np.asarray(val, val.dtype))
+        else:
+            for val in result:
+                result_tensors.append(val)
         return result_tensors
     elif isinstance(result, dict):
         data = list(result.items())
-        res = np.array(data, dtype=object)
-        return np.copy(res)
+        if send_to_host:
+            res = np.array(data, dtype=object)
+            return np.copy(res)
+        return data
     else:
-        return result.to_host()
+        if send_to_host:
+            return result.to_host()
+        return result


 def get_iree_runtime_config(device):
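A hedged sketch of the new contract (get_results is this module's function; compiled_vm, x, and config are placeholders):

    # send_to_host=True (default): results are numpy arrays, copied off the device.
    out = get_results(compiled_vm, (x,), config)
    # send_to_host=False: results stay as IREE device arrays; copy explicitly when needed.
    out_dev = get_results(compiled_vm, (x,), config, send_to_host=False)
    host = out_dev.to_host()  # for the single-result (non-tuple) case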
@@ -138,8 +138,8 @@ class SharkInference:
         os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")

     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple):
-        return self.shark_runner.run(inputs)
+    def forward(self, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(inputs, send_to_host)

     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
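Existing callers of SharkInference keep their old behavior, since forward defaults to send_to_host=True; only the stable-diffusion loop opts out. Sketch (unet stands for any SharkInference instance, x for its input array):

    y_np  = unet.forward((x,))                      # numpy, as before
    y_dev = unet.forward((x,), send_to_host=False)  # stays on the device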
@@ -91,10 +91,11 @@ class SharkRunner:
             extra_args=self.extra_args,
         )

-    def run(self, inputs: tuple):
+    def run(self, inputs: tuple, send_to_host=False):
         return get_results(
             self.iree_compilation_module,
             inputs,
             self.iree_config,
             self.mlir_dialect,
             send_to_host,
         )
@@ -114,13 +114,14 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
             with gr.Row():
                 scheduler_key = gr.Dropdown(
                     label="Scheduler",
-                    value="EulerDiscrete",
+                    value="SharkEulerDiscrete",
                     choices=[
                         "DDIM",
                         "PNDM",
                         "LMSDiscrete",
                         "DPMSolverMultistep",
                         "EulerDiscrete",
+                        "SharkEulerDiscrete",
                     ],
                 )
             with gr.Group():
@@ -9,6 +9,9 @@ from diffusers import (
 from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
 from models.stable_diffusion.utils import set_iree_runtime_flags
 from models.stable_diffusion.stable_args import args
+from models.stable_diffusion.schedulers import (
+    SharkEulerDiscreteScheduler,
+)
 from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag


@@ -39,6 +42,11 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
     model_config[args.version],
     subfolder="scheduler",
 )
+schedulers["SharkEulerDiscrete"] = SharkEulerDiscreteScheduler.from_pretrained(
+    model_config[args.version],
+    subfolder="scheduler",
+)
+schedulers["SharkEulerDiscrete"].compile()

 # use tuned unet model in case of rdna3 cards.
 if "rdna3" in get_vulkan_triple_flag():
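Note the design choice here: schedulers["SharkEulerDiscrete"].compile() runs at import time, so the web UI pays scheduler compilation once at startup rather than on the first request. At request time the selection is then just a dict lookup, as in stable_diff_inf below:

    scheduler = schedulers[scheduler_key]                   # e.g. "SharkEulerDiscrete"
    cpu_scheduling = not scheduler_key.startswith("Shark")  # compiled path skips host copies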
@@ -56,6 +56,7 @@ def stable_diff_inf(
         cache_obj["tokenizer"],
     )
     scheduler = schedulers[scheduler_key]
+    cpu_scheduling = not scheduler_key.startswith("Shark")

     start = time.time()
     text_input = tokenizer(
@@ -104,27 +105,35 @@ def stable_diff_inf(

         step_start = time.time()
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
-        latents_model_input = scheduler.scale_model_input(latents, t)
-        latents_numpy = latents_model_input.detach().numpy()
+        latent_model_input = scheduler.scale_model_input(latents, t)
+        if cpu_scheduling:
+            latent_model_input = latent_model_input.detach().numpy()

         noise_pred = unet.forward(
             (
-                latents_numpy,
+                latent_model_input,
                 timestep,
                 text_embeddings_numpy,
                 args.guidance_scale,
-            )
+            ),
+            send_to_host=False,
         )
-        noise_pred = torch.from_numpy(noise_pred)

+        if cpu_scheduling:
+            noise_pred = torch.from_numpy(noise_pred.to_host())
+            latents = scheduler.step(noise_pred, t, latents).prev_sample
+        else:
+            latents = scheduler.step(noise_pred, t, latents)
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
         print(f" \nIteration = {i}, Time = {step_ms}ms")
-        latents = scheduler.step(noise_pred, t, latents)["prev_sample"]

     # scale and decode the image latents with vae
     latents = 1 / 0.18215 * latents
-    latents_numpy = latents.detach().numpy()
+    latents_numpy = latents
+    if cpu_scheduling:
+        latents_numpy = latents.detach().numpy()
     vae_start = time.time()
     image = vae.forward((latents_numpy,))
     vae_end = time.time()
web/models/stable_diffusion/schedulers.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from models.stable_diffusion.utils import compile_through_fx, get_shark_model
from models.stable_diffusion.stable_args import args
import torch

SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"

model_input = {
    "euler": {
        "latent": torch.randn(1, 4, 64, 64),
        "output": torch.randn(1, 4, 64, 64),
        "sigma": torch.tensor(1).to(torch.float32),
        "dt": torch.tensor(1).to(torch.float32),
    },
}


class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
    ):
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
        )

    def compile(self):
        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        # Disable bindings fusion to work with moltenVK.
        if sys.platform == "darwin":
            iree_flags.append("-iree-stream-fuse-binding=false")

        if args.import_mlir:
            scaling_model = ScalingModel()
            self.scaling_model = compile_through_fx(
                scaling_model,
                (example_latent, example_sigma),
                model_name="euler_scale_model_input_" + args.precision,
                extra_args=iree_flags,
            )

            step_model = SchedulerStepModel()
            self.step_model = compile_through_fx(
                step_model,
                (example_output, example_sigma, example_latent, example_dt),
                model_name="euler_step_" + args.precision,
                extra_args=iree_flags,
            )
        else:
            self.scaling_model = get_shark_model(
                SCHEDULER_BUCKET,
                "euler_scale_model_input_" + args.precision,
                iree_flags,
            )
            self.step_model = get_shark_model(
                SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
            )

    def scale_model_input(self, sample, timestep):
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        return self.scaling_model.forward(
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(self, noise_pred, timestep, latent):
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        dt = self.sigmas[step_index + 1] - sigma
        return self.step_model.forward(
            (
                noise_pred,
                sigma,
                latent,
                dt,
            ),
            send_to_host=False,
        )