From 6384780d16580a55d677784c668dcd7bc19ba60b Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:19:19 -0600 Subject: [PATCH] Fixes to llama2 cpu compilation and studio UI, schedulers (#2013) * Fix some issues with defaults Fixes to llama2 cpu compilation (turns off data tiling for old argmax mode) --------- Co-authored-by: Max Dawkins --- .../src/schedulers/__init__.py | 2 +- .../src/schedulers/sd_schedulers.py | 6 +- .../stable_diffusion/src/utils/stable_args.py | 7 ++ apps/stable_diffusion/web/index.py | 18 +++-- apps/stable_diffusion/web/ui/stablelm_ui.py | 16 +++-- .../web/ui/txt2img_sdxl_ui.py | 71 ++++++++++++++++--- apps/stable_diffusion/web/ui/utils.py | 21 +++++- shark/iree_utils/compile_utils.py | 2 +- 8 files changed, 117 insertions(+), 26 deletions(-) diff --git a/apps/stable_diffusion/src/schedulers/__init__.py b/apps/stable_diffusion/src/schedulers/__init__.py index 4e6d8db9..e7864e2d 100644 --- a/apps/stable_diffusion/src/schedulers/__init__.py +++ b/apps/stable_diffusion/src/schedulers/__init__.py @@ -1,7 +1,7 @@ -from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( SharkEulerDiscreteScheduler, ) from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( SharkEulerAncestralDiscreteScheduler, ) +from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py index 544fa1ef..913b15c9 100644 --- a/apps/stable_diffusion/src/schedulers/sd_schedulers.py +++ b/apps/stable_diffusion/src/schedulers/sd_schedulers.py @@ -1,4 +1,5 @@ from diffusers import ( + LCMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, @@ -25,7 +26,6 @@ def get_schedulers(model_id): # set batch_size here, the SHARK schedulers will # compile with batch size = 1 regardless of whether the model # outputs latents of a larger batch size, e.g. SDXL. - # This also goes towards enabling batch size cfg for SD in general. # However, obviously, searching for whether the base model ID # contains "xl" is not very robust. @@ -52,6 +52,10 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) + schedulers["LCMScheduler"] = LCMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) schedulers[ "DPMSolverMultistep" ] = DPMSolverMultistepScheduler.from_pretrained( diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index b251b64f..88434ff5 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -467,6 +467,13 @@ p.add_argument( "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count" "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", ) + +p.add_argument( + "--autogen", + type=bool, + default="False", + help="Only used for a gradio workaround.", +) ############################################################################## # IREE - Vulkan supported flags ############################################################################## diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 7bb7be9c..f301d6d2 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -97,8 +97,6 @@ if __name__ == "__main__": ) return os.path.join(base_path, relative_path) - dark_theme = resource_path("ui/css/sd_dark_theme.css") - from apps.stable_diffusion.web.ui import ( txt2img_web, txt2img_custom_model, @@ -210,6 +208,8 @@ if __name__ == "__main__": outputs, ) + dark_theme = resource_path("ui/css/sd_dark_theme.css") + with gr.Blocks( css=dark_theme, analytics_enabled=False, title="SHARK AI Studio" ) as sd_web: @@ -255,10 +255,10 @@ if __name__ == "__main__": # lora_train_web.render() with gr.TabItem(label="Chat Bot", id=8): stablelm_chat.render() - with gr.TabItem( - label="Generate Sharding Config (Experimental)", id=9 - ): - model_config_web.render() + # with gr.TabItem( + # label="Generate Sharding Config (Experimental)", id=9 + # ): + # model_config_web.render() # with gr.TabItem(label="MultiModal (Experimental)", id=10): # minigpt4_web.render() # with gr.TabItem(label="DocuChat Upload", id=11): @@ -405,6 +405,12 @@ if __name__ == "__main__": [outputgallery_filename], [upscaler_init_image, tabs], ) + register_outputgallery_button( + outputgallery_sendto_txt2img_sdxl, + 0, + [outputgallery_filename], + [txt2img_sdxl_png_info_img, tabs], + ) register_modelmanager_button( modelmanager_sendto_txt2img, 0, diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index f3baa3c5..706ec27a 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -173,6 +173,13 @@ def chat( get_vulkan_target_triple, ) + _extra_args = _extra_args + [ + "--iree-global-opt-enable-quantized-matmul-reassociation", + "--iree-llvmcpu-enable-quantized-matmul-reassociation", + "--iree-opt-const-eval=false", + "--iree-opt-data-tiling=false", + ] + if device == "vulkan": vulkaninfo_list = get_all_vulkan_devices() if vulkan_target_triple == "": @@ -250,10 +257,11 @@ def chat( total_time_ms = 0.001 # In order to avoid divide by zero error prefill_time = 0 is_first = True - for text, msg, exec_time in progress.tqdm( - vicuna_model.generate(prompt, cli=cli), - desc="generating response", - ): + # for text, msg, exec_time in progress.tqdm( + # vicuna_model.generate(prompt, cli=cli), + # desc="generating response", + # ): + for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli): if msg is None: if is_first: prefill_time = exec_time diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index c3a653bd..746428f1 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -226,7 +226,12 @@ def txt2img_sdxl_inf( return generated_imgs, text_output, "" -with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: +theme = gr.themes.Glass( + primary_hue="slate", + secondary_hue="gray", +) + +with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(): @@ -288,17 +293,24 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: ) with gr.Group(elem_id="prompt_box_outer"): + txt2img_sdxl_autogen = gr.Checkbox( + label="Auto-Generate Images", + value=False, + visible=False, + ) prompt = gr.Textbox( label="Prompt", value=args.prompts[0], lines=2, elem_id="prompt_box", + show_copy_button=True, ) negative_prompt = gr.Textbox( label="Negative Prompt", value=args.negative_prompts[0], lines=2, elem_id="negative_prompt_box", + show_copy_button=True, ) with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): @@ -340,6 +352,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: "DDIM", "EulerAncestralDiscrete", "EulerDiscrete", + "LCMScheduler", ], allow_custom_value=False, visible=True, @@ -360,7 +373,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: 512, 1024, value=1024, - step=512, + step=256, label="Height", visible=True, interactive=True, @@ -369,7 +382,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: 512, 1024, value=1024, - step=512, + step=256, label="Width", visible=True, interactive=True, @@ -379,7 +392,6 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: value="fp16", choices=[ "fp16", - "fp32", ], visible=False, ) @@ -427,12 +439,14 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: value=args.batch_size, step=1, label="Batch Size", - interactive=True, + interactive=False, + visible=False, ) repeatable_seeds = gr.Checkbox( args.repeatable_seeds, label="Repeatable Seeds", ) + with gr.Row(): seed = gr.Textbox( value=args.seed, @@ -484,16 +498,20 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: stop_batch = gr.Button("Stop Batch") with gr.Row(): txt2img_sdxl_sendto_img2img = gr.Button( - value="Send To Img2Img" + value="Send To Img2Img", + visible=False, ) txt2img_sdxl_sendto_inpaint = gr.Button( - value="Send To Inpaint" + value="Send To Inpaint", + visible=False, ) txt2img_sdxl_sendto_outpaint = gr.Button( - value="Send To Outpaint" + value="Send To Outpaint", + visible=False, ) txt2img_sdxl_sendto_upscaler = gr.Button( - value="Send To Upscaler" + value="Send To Upscaler", + visible=False, ) kwargs = dict( @@ -523,14 +541,42 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: ], outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status], show_progress="minimal" if args.progress_bar else "none", + queue=True, ) status_kwargs = dict( fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs), inputs=[batch_count, batch_size], outputs=txt2img_sdxl_status, + concurrency_limit=1, ) + def autogen_changed(checked): + if checked: + args.autogen = True + else: + args.autogen = False + + def check_last_input(prompt): + if not prompt.endswith(" "): + return True + elif not args.autogen: + return True + else: + return False + + auto_gen_kwargs = dict( + fn=check_last_input, + inputs=[negative_prompt], + outputs=[txt2img_sdxl_status], + concurrency_limit=1, + ) + + txt2img_sdxl_autogen.change( + fn=autogen_changed, + inputs=[txt2img_sdxl_autogen], + outputs=None, + ) prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( **kwargs @@ -538,7 +584,11 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) stop_batch.click( fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], + cancels=[ + prompt_submit, + neg_prompt_submit, + generate_click, + ], ) txt2img_sdxl_png_info_img.change( @@ -588,6 +638,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: width, height, custom_vae, + txt2img_sdxl_autogen, ], ) lora_weights.change( diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index 3770f982..aecf9366 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -282,6 +282,12 @@ def set_model_default_configs(model_ckpt_or_id, jsonconfig=None): gr.update(), gr.update(), gr.update(), + gr.Checkbox( + label="Auto-Generate", + visible=False, + interactive=False, + value=False, + ), ] @@ -317,7 +323,7 @@ default_configs = { gr.Textbox(label="", interactive=False, value=None, visible=False), gr.Textbox( label="Prompt", - value="role-playing game (RPG) style fantasy, An enchanting image featuring an adorable kitten mage wearing intricate ancient robes, holding an ancient staff, hard at work in her fantastical workshop, magic runes floating in the air", + value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette", ), gr.Slider(0, 10, value=2), gr.Dropdown(value="EulerAncestralDiscrete"), @@ -325,6 +331,9 @@ default_configs = { 512, 512, "madebyollin/sdxl-vae-fp16-fix", + gr.Checkbox( + label="Auto-Generate", visible=False, interactive=True, value=False + ), ], "stabilityai/stable-diffusion-xl-base-1.0": [ gr.Textbox(label="Prompt", interactive=True, visible=True), @@ -332,9 +341,15 @@ default_configs = { 40, "EulerDiscrete", 7.5, - gr.Slider(value=1024, interactive=False), - gr.Slider(value=1024, interactive=False), + gr.Slider(value=768, interactive=True), + gr.Slider(value=768, interactive=True), "madebyollin/sdxl-vae-fp16-fix", + gr.Checkbox( + label="Auto-Generate", + visible=False, + interactive=False, + value=False, + ), ], } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index f30a3095..b5f87378 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -83,7 +83,7 @@ def clean_device_info(raw_device): device_id = int(device_id) if device not in ["rocm", "vulkan"]: - device_id = "" + device_id = None if device in ["rocm", "vulkan"] and device_id == None: device_id = 0 return device, device_id