[WEB] Cache model parameters (#452)

This commit caches some of the model parameters to reduce the response
time of SHARK web.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Gaurav Shukla
2022-11-01 00:25:10 +05:30
committed by GitHub
parent 25931d48a3
commit 1939376d72
6 changed files with 75 additions and 95 deletions
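Reviewer note, in outline: the schedulers and the CLIP tokenizer/text encoder were previously constructed inside the request handler on every generation; the diff below hoists them to module scope so they are built once at import time and merely looked up per request. A minimal sketch of that pattern, using the same diffusers/transformers calls the diff itself uses (serve_request is a hypothetical stand-in for stable_diff_inf):

from diffusers import PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# Built once at import time, reused by every request.
schedulers = {
    "PNDM": PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    ),
}
cache_obj = {
    "tokenizer": CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
    "text_encoder": CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14"),
}

def serve_request(scheduler_name):
    # Per-request cost drops to dict lookups instead of object
    # construction plus a from_pretrained() load.
    scheduler = schedulers[scheduler_name]
    tokenizer = cache_obj["tokenizer"]
    text_encoder = cache_obj["text_encoder"]
    return scheduler, tokenizer, text_encoder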

View File

@@ -35,6 +35,6 @@ pip install --pre torch-mlir torch torchvision --extra-index-url https://downloa
 pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
 Write-Host "Building SHARK..."
 pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
-pip install diffusers transformers scipy
+pip install diffusers transformers scipy gradio
 Write-Host "Build and installation completed successfully"
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -16,12 +16,20 @@ with gr.Blocks() as shark_web:
     with gr.Row():
         with gr.Group():
             with gr.Column(scale=1):
-                img = Image.open("./Nod_logo.png")
-                gr.Image(value=img, show_label=False, interactive=False).style(
-                    height=80, width=150
-                )
+                nod_logo = Image.open("./logos/Nod_logo.png")
+                gr.Image(
+                    value=nod_logo, show_label=False, interactive=False
+                ).style(height=80, width=150)
             with gr.Column(scale=1):
-                gr.Label(value="Shark Models Demo.")
+                logo2 = Image.open("./logos/other_logo.png")
+                gr.Image(
+                    value=logo2,
+                    show_label=False,
+                    interactive=False,
+                    visible=False,
+                ).style(height=80, width=150)
+            with gr.Column(scale=1):
+                gr.Label(value="Ultra fast Stable Diffusion")
     with gr.Tabs():
         # with gr.TabItem("ResNet50"):
@@ -136,9 +144,7 @@ with gr.Blocks() as shark_web:
     ) = (
         device
     ) = (
-        load_vmfb
-    ) = (
-        save_vmfb
+        cache
     ) = (
         iree_vulkan_target_triple
     ) = (
@@ -160,6 +166,7 @@ with gr.Blocks() as shark_web:
                     prompt = gr.Textbox(
                         label="Prompt",
                         value="a photograph of an astronaut riding a horse",
+                        lines=5,
                     )
                     ex = gr.Examples(
                         examples=examples,
@@ -219,11 +226,10 @@ with gr.Blocks() as shark_web:
value="42", max_lines=1, label="Seed"
)
with gr.Row():
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
cache = gr.Checkbox(label="Cache", value=True)
debug = gr.Checkbox(label="DEBUG", value=False)
live_preview = gr.Checkbox(
label="live preview", value=False
label="Live Preview", value=False
)
iree_vulkan_target_triple = gr.Textbox(
value="",
@@ -231,7 +237,7 @@ with gr.Blocks() as shark_web:
label="IREE VULKAN TARGET TRIPLE",
visible=False,
)
stable_diffusion = gr.Button("Generate image from prompt")
stable_diffusion = gr.Button("Generate Image")
with gr.Column(scale=1, min_width=600):
generated_img = gr.Image(type="pil", shape=(100, 100))
std_output = gr.Textbox(
@@ -261,8 +267,7 @@ with gr.Blocks() as shark_web:
                 seed,
                 precision,
                 device,
-                load_vmfb,
-                save_vmfb,
+                cache,
                 iree_vulkan_target_triple,
                 live_preview,
             ],
@@ -270,4 +275,4 @@ with gr.Blocks() as shark_web:
         )
 shark_web.queue()
-shark_web.launch(share=True, server_port=8080, enable_queue=True)
+shark_web.launch(server_port=8080, enable_queue=True)
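For context, the two vmfb checkboxes collapse into a single Cache checkbox that is passed straight through the click handler's inputs list. A minimal, self-contained sketch of that wiring (fake_inf is a hypothetical stand-in for stable_diff_inf):

import gradio as gr

def fake_inf(prompt, cache):
    # stand-in for stable_diff_inf; returns (image, log text)
    return None, f"cache={cache}: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt", lines=5)
    cache = gr.Checkbox(label="Cache", value=True)
    stable_diffusion = gr.Button("Generate Image")
    generated_img = gr.Image(type="pil")
    std_output = gr.Textbox(label="Std Output")
    stable_diffusion.click(
        fake_inf,
        inputs=[prompt, cache],
        outputs=[generated_img, std_output],
    )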

View File

(binary image file changed — 33 KiB before and after; preview not shown)

BIN web/logos/other_logo.png (new file, 33 KiB; preview not shown)

View File

@@ -22,7 +22,6 @@ UNET_FP32 = "unet_fp32"
 IREE_EXTRA_ARGS = []
 args = None
-DEBUG = False
 class Arguments:
@@ -39,8 +38,7 @@ class Arguments:
         seed: int,
         precision: str,
         device: str,
-        load_vmfb: bool,
-        save_vmfb: bool,
+        cache: bool,
         iree_vulkan_target_triple: str,
         live_preview: bool,
         import_mlir: bool = False,
@@ -57,8 +55,7 @@ class Arguments:
         self.seed = seed
         self.precision = precision
         self.device = device
-        self.load_vmfb = load_vmfb
-        self.save_vmfb = save_vmfb
+        self.cache = cache
         self.iree_vulkan_target_triple = iree_vulkan_target_triple
         self.live_preview = live_preview
         self.import_mlir = import_mlir
@@ -101,6 +98,37 @@ def get_models():
     return None, None
+schedulers = dict()
+# set scheduler value
+schedulers["PNDM"] = PNDMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["LMS"] = LMSDiscreteScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["DDIM"] = DDIMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    clip_sample=False,
+    set_alpha_to_one=False,
+)
+cache_obj = dict()
+cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
+cache_obj["text_encoder"] = CLIPTextModel.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
 def stable_diff_inf(
     prompt: str,
     scheduler: str,
@@ -113,21 +141,17 @@ def stable_diff_inf(
     seed: str,
     precision: str,
     device: str,
-    load_vmfb: bool,
-    save_vmfb: bool,
+    cache: bool,
     iree_vulkan_target_triple: str,
     live_preview: bool,
 ):
     global IREE_EXTRA_ARGS
     global args
-    global DEBUG
+    global schedulers
+    global cache_obj
-    output_loc = f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{precision}_{device}.jpg"
-    DEBUG = False
-    log_write = open(r"logs/stable_diffusion_log.txt", "w")
-    if log_write:
-        DEBUG = True
+    output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg"
     # set seed value
     if seed == "":
@@ -138,34 +162,7 @@ def stable_diff_inf(
         except ValueError:
             seed = hash(seed)
     # set scheduler value
-    if scheduler == "PNDM":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "LMS":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "DDIM":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
-    else:
-        raise Exception(
-            f"Does not support scheduler with name {args.scheduler}."
-        )
+    scheduler = schedulers[scheduler]
     args = Arguments(
         prompt,
         scheduler,
@@ -178,8 +175,7 @@ def stable_diff_inf(
         seed,
         precision,
         device,
-        load_vmfb,
-        save_vmfb,
+        cache,
         iree_vulkan_target_triple,
         live_preview,
     )
@@ -194,11 +190,8 @@ def stable_diff_inf(
     ) # Seed generator to create the inital latent noise
     vae, unet = get_models()
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    text_encoder = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14"
-    )
+    tokenizer = cache_obj["tokenizer"]
+    text_encoder = cache_obj["text_encoder"]
     text_input = tokenizer(
         [args.prompt] * batch_size,
         padding="max_length",
@@ -233,10 +226,10 @@ def stable_diff_inf(
     avg_ms = 0
     out_img = None
+    text_output = ""
     for i, t in tqdm(enumerate(scheduler.timesteps)):
-        if DEBUG:
-            log_write.write(f"\ni = {i} t = {t} ")
+        text_output = text_output + f"\ni = {i} t = {t} "
         step_start = time.time()
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
         latents_numpy = latents.detach().numpy()
@@ -249,8 +242,7 @@ def stable_diff_inf(
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
-        if DEBUG:
-            log_write.write(f"time={step_ms}ms")
+        text_output = text_output + f"time={step_ms}ms"
         latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
         if live_preview:
@@ -263,7 +255,7 @@ def stable_diff_inf(
             images = (image * 255).round().astype("uint8")
             pil_images = [Image.fromarray(image) for image in images]
             out_img = pil_images[0]
-            yield out_img, ""
+            yield out_img, text_output
     # scale and decode the image latents with vae
     if not live_preview:
@@ -277,14 +269,8 @@ def stable_diff_inf(
         out_img = pil_images[0]
     avg_ms = 1000 * avg_ms / args.steps
-    if DEBUG:
-        log_write.write(f"\nAverage step time: {avg_ms}ms/it")
+    text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"
     # save the output image with the prompt name.
     out_img.save(os.path.join(output_loc))
-    log_write.close()
-    std_output = ""
-    with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
-        std_output = log_read.read()
-    yield out_img, std_output
+    yield out_img, text_output
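These hunks also replace the DEBUG/log-file round trip (write to logs/stable_diffusion_log.txt, then re-read it at the end) with an in-memory text_output string that each yield streams to the UI alongside the image. Roughly this generator shape, as a sketch with the denoising step elided:

import time

def stable_diff_inf_sketch(steps):
    text_output = ""
    out_img = None
    for i in range(steps):
        step_start = time.time()
        # ... one denoising step would run here and update out_img ...
        step_ms = int((time.time() - step_start) * 1000)
        text_output = text_output + f"\ni = {i} time={step_ms}ms"
        yield out_img, text_output  # live preview: image plus log so far
    yield out_img, text_output      # final image plus the full log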

View File

@@ -7,27 +7,16 @@ import os
 def _compile_module(args, shark_module, model_name, extra_args=[]):
-    if args.load_vmfb or args.save_vmfb:
-        extended_name = "{}_{}".format(model_name, args.device)
+    extended_name = "{}_{}".format(model_name, args.device)
+    if args.cache:
         vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
-        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
+        if os.path.isfile(vmfb_path):
             print("Loading flatbuffer from {}".format(vmfb_path))
             shark_module.load_module(vmfb_path)
         else:
-            if args.save_vmfb:
-                print("Saving to {}".format(vmfb_path))
-            else:
-                print(
-                    "No vmfb found. Compiling and saving to {}".format(
-                        vmfb_path
-                    )
-                )
-            path = shark_module.save_module(
-                os.getcwd(), extended_name, extra_args
-            )
-            shark_module.load_module(path)
-    else:
-        shark_module.compile(extra_args)
-    return shark_module
+            print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
+            path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
+            shark_module.load_module(path)
     return shark_module
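Pieced back together from the hunk above, the rewritten helper reads roughly as follows — a reconstruction, so the exact whitespace (and whether any non-cache compile fallback survives) should be checked against the file itself. The .vmfb on disk acts as the cache, keyed by "{model_name}_{device}":

import os

def _compile_module(args, shark_module, model_name, extra_args=[]):
    extended_name = "{}_{}".format(model_name, args.device)
    if args.cache:
        vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
        if os.path.isfile(vmfb_path):
            # cache hit: reuse the previously compiled flatbuffer
            print("Loading flatbuffer from {}".format(vmfb_path))
            shark_module.load_module(vmfb_path)
        else:
            # cache miss: compile once, save the .vmfb, then load it
            print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
            path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
            shark_module.load_module(path)
    return shark_module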