Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 21:38:04 -05:00)
[WEB] Cache model parameters (#452)
This commit caches some of the model parameters to reduce the response time of the SHARK web UI.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
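In practice the change replaces the separate load_vmfb/save_vmfb options with a single cache flag, moves construction of the diffusers schedulers and the CLIP tokenizer/text encoder to module import time, and streams progress text to the UI instead of round-tripping it through a log file. A minimal sketch of the import-time caching pattern, reusing the constructors the diff adds below (handle_request is an illustrative stand-in for stable_diff_inf, not code from this commit):

from diffusers import PNDMScheduler
from transformers import CLIPTokenizer, CLIPTextModel

# Built once, when the module is imported.
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

cache_obj = dict()
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14"
)
cache_obj["text_encoder"] = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14"
)

def handle_request(prompt: str, scheduler_name: str):
    # Per-request work is now a cheap dictionary lookup instead of a
    # from_pretrained() download/deserialize on every call.
    scheduler = schedulers[scheduler_name]
    tokenizer = cache_obj["tokenizer"]
    text_encoder = cache_obj["text_encoder"]
    ...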
@@ -35,6 +35,6 @@ pip install --pre torch-mlir torch torchvision --extra-index-url https://downloa
 pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
 Write-Host "Building SHARK..."
 pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
-pip install diffusers transformers scipy
+pip install diffusers transformers scipy gradio
 Write-Host "Build and installation completed successfully"
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"
web/index.py (35 lines changed)
@@ -16,12 +16,20 @@ with gr.Blocks() as shark_web:
 with gr.Row():
     with gr.Group():
         with gr.Column(scale=1):
-            img = Image.open("./Nod_logo.png")
-            gr.Image(value=img, show_label=False, interactive=False).style(
-                height=80, width=150
-            )
+            nod_logo = Image.open("./logos/Nod_logo.png")
+            gr.Image(
+                value=nod_logo, show_label=False, interactive=False
+            ).style(height=80, width=150)
+        with gr.Column(scale=1):
+            gr.Label(value="Shark Models Demo.")
+            logo2 = Image.open("./logos/other_logo.png")
+            gr.Image(
+                value=logo2,
+                show_label=False,
+                interactive=False,
+                visible=False,
+            ).style(height=80, width=150)
     with gr.Column(scale=1):
         gr.Label(value="Ultra fast Stable Diffusion")

 with gr.Tabs():
     # with gr.TabItem("ResNet50"):
@@ -136,9 +144,7 @@ with gr.Blocks() as shark_web:
 ) = (
     device
 ) = (
-    load_vmfb
-) = (
-    save_vmfb
+    cache
 ) = (
     iree_vulkan_target_triple
 ) = (
@@ -160,6 +166,7 @@ with gr.Blocks() as shark_web:
 prompt = gr.Textbox(
     label="Prompt",
     value="a photograph of an astronaut riding a horse",
+    lines=5,
 )
 ex = gr.Examples(
     examples=examples,
@@ -219,11 +226,10 @@ with gr.Blocks() as shark_web:
     value="42", max_lines=1, label="Seed"
 )
 with gr.Row():
-    load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
-    save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
+    cache = gr.Checkbox(label="Cache", value=True)
     debug = gr.Checkbox(label="DEBUG", value=False)
     live_preview = gr.Checkbox(
-        label="live preview", value=False
+        label="Live Preview", value=False
     )
 iree_vulkan_target_triple = gr.Textbox(
     value="",
@@ -231,7 +237,7 @@ with gr.Blocks() as shark_web:
     label="IREE VULKAN TARGET TRIPLE",
     visible=False,
 )
-stable_diffusion = gr.Button("Generate image from prompt")
+stable_diffusion = gr.Button("Generate Image")
 with gr.Column(scale=1, min_width=600):
     generated_img = gr.Image(type="pil", shape=(100, 100))
     std_output = gr.Textbox(
@@ -261,8 +267,7 @@ with gr.Blocks() as shark_web:
         seed,
         precision,
         device,
-        load_vmfb,
-        save_vmfb,
+        cache,
         iree_vulkan_target_triple,
         live_preview,
     ],
@@ -270,4 +275,4 @@ with gr.Blocks() as shark_web:
 )

 shark_web.queue()
-shark_web.launch(share=True, server_port=8080, enable_queue=True)
+shark_web.launch(server_port=8080, enable_queue=True)
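Side note on the launch changes: stable_diff_inf is a generator (it yields intermediate images and log text), and Gradio only streams generator output when the queue is enabled, hence shark_web.queue() together with enable_queue=True; dropping share=True keeps the server local-only. A self-contained sketch of that wiring (toy handler, not the app's real code):

import time
import gradio as gr

def slow_task(prompt):
    # Stand-in for stable_diff_inf: yield partial results so the UI
    # updates while work is still in progress.
    for step in range(3):
        time.sleep(1)
        yield f"{prompt}: step {step} done"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(slow_task, inputs=inp, outputs=out)

demo.queue()  # required for generator (streaming) handlers
demo.launch(server_port=8080)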
[binary image changed: Before 33 KiB, After 33 KiB]
web/logos/other_logo.png: new binary file (33 KiB), not shown
@@ -22,7 +22,6 @@ UNET_FP32 = "unet_fp32"
 IREE_EXTRA_ARGS = []
 args = None
 DEBUG = False
-

 class Arguments:
@@ -39,8 +38,7 @@ class Arguments:
         seed: int,
         precision: str,
         device: str,
-        load_vmfb: bool,
-        save_vmfb: bool,
+        cache: bool,
         iree_vulkan_target_triple: str,
         live_preview: bool,
         import_mlir: bool = False,
@@ -57,8 +55,7 @@ class Arguments:
         self.seed = seed
         self.precision = precision
         self.device = device
-        self.load_vmfb = load_vmfb
-        self.save_vmfb = save_vmfb
+        self.cache = cache
         self.iree_vulkan_target_triple = iree_vulkan_target_triple
         self.live_preview = live_preview
         self.import_mlir = import_mlir
@@ -101,6 +98,37 @@ def get_models():
     return None, None


+schedulers = dict()
+# set scheduler value
+schedulers["PNDM"] = PNDMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["LMS"] = LMSDiscreteScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    num_train_timesteps=1000,
+)
+schedulers["DDIM"] = DDIMScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    clip_sample=False,
+    set_alpha_to_one=False,
+)
+
+cache_obj = dict()
+cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
+cache_obj["text_encoder"] = CLIPTextModel.from_pretrained(
+    "openai/clip-vit-large-patch14"
+)
+
+
 def stable_diff_inf(
     prompt: str,
     scheduler: str,
@@ -113,21 +141,17 @@ def stable_diff_inf(
     seed: str,
     precision: str,
     device: str,
-    load_vmfb: bool,
-    save_vmfb: bool,
+    cache: bool,
     iree_vulkan_target_triple: str,
     live_preview: bool,
 ):

     global IREE_EXTRA_ARGS
     global args
     global DEBUG
+    global schedulers
+    global cache_obj

-    output_loc = f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{precision}_{device}.jpg"
     DEBUG = False
     log_write = open(r"logs/stable_diffusion_log.txt", "w")
     if log_write:
         DEBUG = True
+    output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg"

     # set seed value
     if seed == "":
@@ -138,34 +162,7 @@ def stable_diff_inf(
     except ValueError:
         seed = hash(seed)

-    # set scheduler value
-    if scheduler == "PNDM":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "LMS":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            num_train_timesteps=1000,
-        )
-    elif scheduler == "DDIM":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
-    else:
-        raise Exception(
-            f"Does not support scheduler with name {args.scheduler}."
-        )
-
+    scheduler = schedulers[scheduler]
     args = Arguments(
         prompt,
         scheduler,
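Worth noting: the if/elif chain becomes a dictionary dispatch, so an unknown scheduler name now raises KeyError instead of the old custom Exception. A guarded lookup (a hypothetical variant, not part of this commit) would preserve the old error message:

try:
    scheduler = schedulers[scheduler]
except KeyError:
    # Mirrors the error text removed above.
    raise Exception(f"Does not support scheduler with name {scheduler}.")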
@@ -178,8 +175,7 @@ def stable_diff_inf(
         seed,
         precision,
         device,
-        load_vmfb,
-        save_vmfb,
+        cache,
         iree_vulkan_target_triple,
         live_preview,
     )
@@ -194,11 +190,8 @@ def stable_diff_inf(
     ) # Seed generator to create the inital latent noise

     vae, unet = get_models()
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    text_encoder = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14"
-    )
-
+    tokenizer = cache_obj["tokenizer"]
+    text_encoder = cache_obj["text_encoder"]
     text_input = tokenizer(
         [args.prompt] * batch_size,
         padding="max_length",
@@ -233,10 +226,10 @@ def stable_diff_inf(
     avg_ms = 0
     out_img = None
+    text_output = ""
     for i, t in tqdm(enumerate(scheduler.timesteps)):
-
         if DEBUG:
             log_write.write(f"\ni = {i} t = {t} ")
+            text_output = text_output + f"\ni = {i} t = {t} "
         step_start = time.time()
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
         latents_numpy = latents.detach().numpy()
@@ -249,8 +242,7 @@ def stable_diff_inf(
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
         if DEBUG:
             log_write.write(f"time={step_ms}ms")
+            text_output = text_output + f"time={step_ms}ms"
         latents = scheduler.step(noise_pred, i, latents)["prev_sample"]

         if live_preview:
@@ -263,7 +255,7 @@ def stable_diff_inf(
             images = (image * 255).round().astype("uint8")
             pil_images = [Image.fromarray(image) for image in images]
             out_img = pil_images[0]
-            yield out_img, ""
+            yield out_img, text_output

     # scale and decode the image latents with vae
     if not live_preview:
@@ -277,14 +269,8 @@ def stable_diff_inf(
         out_img = pil_images[0]

     avg_ms = 1000 * avg_ms / args.steps
     if DEBUG:
         log_write.write(f"\nAverage step time: {avg_ms}ms/it")
+        text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"

-    # save the output image with the prompt name.
     out_img.save(os.path.join(output_loc))
     log_write.close()
-
-    std_output = ""
-    with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
-        std_output = log_read.read()
-
-    yield out_img, std_output
+    yield out_img, text_output
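The effect of the text_output changes above: the UI text now comes from a string accumulated in memory and yielded directly, instead of being written to logs/stable_diffusion_log.txt and read back at the end (the log file is still written for debugging). In miniature, the old and new shapes (illustrative only, not literal code from the diff):

# before: write lines to a file, then re-read the file to get UI text
log_write.write(line)
with open("logs/stable_diffusion_log.txt", "r") as f:
    yield out_img, f.read()

# after: accumulate in memory and yield as you go
text_output = text_output + line
yield out_img, text_output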
@@ -7,27 +7,16 @@ import os
 def _compile_module(args, shark_module, model_name, extra_args=[]):
-    if args.load_vmfb or args.save_vmfb:
-        extended_name = "{}_{}".format(model_name, args.device)
+    extended_name = "{}_{}".format(model_name, args.device)
+    if args.cache:
         vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
-        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
+        if os.path.isfile(vmfb_path):
             print("Loading flatbuffer from {}".format(vmfb_path))
             shark_module.load_module(vmfb_path)
         else:
-            if args.save_vmfb:
-                print("Saving to {}".format(vmfb_path))
-            else:
-                print(
-                    "No vmfb found. Compiling and saving to {}".format(
-                        vmfb_path
-                    )
-                )
-            path = shark_module.save_module(
-                os.getcwd(), extended_name, extra_args
-            )
-            shark_module.load_module(path)
+            print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
+            path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
+            shark_module.load_module(path)
     else:
         shark_module.compile(extra_args)
     return shark_module
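For reference, the rewritten function is a plain load-or-compile disk cache keyed on "{model_name}_{device}.vmfb". A standalone sketch mirroring the new control flow (hypothetical helper name; the shark_module methods are the ones used in the diff, indentation approximate):

import os

def cached_compile(args, shark_module, model_name, extra_args=[]):
    name = "{}_{}".format(model_name, args.device)
    if args.cache:
        vmfb_path = os.path.join(os.getcwd(), name + ".vmfb")
        if os.path.isfile(vmfb_path):
            # Cache hit: load the previously compiled flatbuffer from disk.
            shark_module.load_module(vmfb_path)
        else:
            # Cache miss: compile once, persist, then load.
            path = shark_module.save_module(os.getcwd(), name, extra_args)
            shark_module.load_module(path)
    else:
        # Caching disabled: compile in memory for this process only.
        shark_module.compile(extra_args)
    return shark_module

One tradeoff of folding load_vmfb/save_vmfb into a single cache flag: there is no longer a save-only path, so a stale .vmfb on disk must be deleted manually to force recompilation.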