Fixes to llama2 CPU compilation, studio UI, and schedulers (#2013)

* Fix some issues with defaults

Fixes llama2 CPU compilation (turns off data tiling for the old argmax mode).
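
For context, a minimal sketch of what the data-tiling switch looks like at the compiler boundary. This is not code from this change: `mlir_text` is a placeholder, and only the flags mirror the ones added in the chat path below (`iree.compiler.compile_str` does accept `extra_args`).

from iree.compiler import compile_str

# Placeholder module text; the real path compiles the llama2 program.
flatbuffer = compile_str(
    mlir_text,
    target_backends=["llvm-cpu"],
    extra_args=[
        "--iree-opt-const-eval=false",
        # The fix itself: data tiling off for the old argmax mode on CPU.
        "--iree-opt-data-tiling=false",
    ],
)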

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
Author: Ean Garvey
Date: 2023-12-05 10:19:19 -06:00
Committed by: GitHub
Parent: db0c53ae59
Commit: 6384780d16
8 changed files with 117 additions and 26 deletions

View File

@@ -1,7 +1,7 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers

View File

@@ -1,4 +1,5 @@
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
@@ -25,7 +26,6 @@ def get_schedulers(model_id):
# set batch_size here; the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# This also goes towards enabling batch size cfg for SD in general.
# However, searching for whether the base model ID
# contains "xl" is obviously not very robust.
@@ -52,6 +52,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
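
A quick usage sketch of the new LCM entry, assuming the dict returned by get_schedulers() is keyed as shown above; the model id is one used elsewhere in this change:

from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers

schedulers = get_schedulers("stabilityai/stable-diffusion-xl-base-1.0")
lcm = schedulers["LCMScheduler"]  # a diffusers.LCMScheduler instance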

View File

@@ -467,6 +467,13 @@ p.add_argument(
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
p.add_argument(
"--autogen",
type=bool,
default="False",
help="Only used for a gradio workaround.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
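
One caveat with the new --autogen flag above: argparse's type=bool does not parse booleans, it calls bool() on the raw string, so any non-empty value (including "False") comes out True. A self-contained demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--autogen", type=bool, default=False)

# bool("False") is True, so the flag cannot be switched off from the CLI.
print(p.parse_args(["--autogen", "False"]).autogen)  # True
print(p.parse_args([]).autogen)                      # False (the default)

This is harmless as long as the flag stays a gradio-only workaround, as its help text says.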

View File

@@ -97,8 +97,6 @@ if __name__ == "__main__":
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
@@ -210,6 +208,8 @@ if __name__ == "__main__":
outputs,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
@@ -255,10 +255,10 @@ if __name__ == "__main__":
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
with gr.TabItem(
label="Generate Sharding Config (Experimental)", id=9
):
model_config_web.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
@@ -405,6 +405,12 @@ if __name__ == "__main__":
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_txt2img_sdxl,
0,
[outputgallery_filename],
[txt2img_sdxl_png_info_img, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,
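
register_outputgallery_button itself is not part of this diff; below is a hypothetical sketch consistent with the call sites above, where the parameter names and the gradio-4-style gr.Tabs(selected=...) update are all assumptions:

import gradio as gr

def register_outputgallery_button(button, tab_id, inputs, outputs):
    # On click, forward the selected gallery filename to the target tab's
    # image input and switch the Tabs component over to that tab.
    button.click(
        lambda filename: (filename, gr.Tabs(selected=tab_id)),
        inputs=inputs,
        outputs=outputs,
    )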

View File

@@ -173,6 +173,13 @@ def chat(
get_vulkan_target_triple,
)
_extra_args = _extra_args + [
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
@@ -250,10 +257,11 @@ def chat(
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
# for text, msg, exec_time in progress.tqdm(
# vicuna_model.generate(prompt, cli=cli),
# desc="generating response",
# ):
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
if msg is None:
if is_first:
prefill_time = exec_time
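
Dropping the progress.tqdm wrapper does not break streaming, since chat() is itself a generator and gradio re-renders on every yield. A simplified sketch of that pattern; vicuna_model and the 3-tuple shape come from the hunk above, everything else is assumed:

def chat_stream(prompt):
    # Yield partial text as it arrives; gradio streams each value to the UI.
    for text, msg, exec_time in vicuna_model.generate(prompt, cli=False):
        if msg is None:
            yield text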

View File

@@ -226,7 +226,12 @@ def txt2img_sdxl_inf(
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
theme = gr.themes.Glass(
primary_hue="slate",
secondary_hue="gray",
)
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -288,17 +293,24 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
)
with gr.Group(elem_id="prompt_box_outer"):
txt2img_sdxl_autogen = gr.Checkbox(
label="Auto-Generate Images",
value=False,
visible=False,
)
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
@@ -340,6 +352,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
"DDIM",
"EulerAncestralDiscrete",
"EulerDiscrete",
"LCMScheduler",
],
allow_custom_value=False,
visible=True,
@@ -360,7 +373,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
512,
1024,
value=1024,
step=512,
step=256,
label="Height",
visible=True,
interactive=True,
@@ -369,7 +382,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
512,
1024,
value=1024,
step=512,
step=256,
label="Width",
visible=True,
interactive=True,
@@ -379,7 +392,6 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
@@ -427,12 +439,14 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
interactive=False,
visible=False,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
@@ -484,16 +498,20 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
stop_batch = gr.Button("Stop Batch")
with gr.Row():
txt2img_sdxl_sendto_img2img = gr.Button(
value="Send To Img2Img"
value="Send To Img2Img",
visible=False,
)
txt2img_sdxl_sendto_inpaint = gr.Button(
value="Send To Inpaint"
value="Send To Inpaint",
visible=False,
)
txt2img_sdxl_sendto_outpaint = gr.Button(
value="Send To Outpaint"
value="Send To Outpaint",
visible=False,
)
txt2img_sdxl_sendto_upscaler = gr.Button(
value="Send To Upscaler"
value="Send To Upscaler",
visible=False,
)
kwargs = dict(
@@ -523,14 +541,42 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
queue=True,
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
concurrency_limit=1,
)
def autogen_changed(checked):
if checked:
args.autogen = True
else:
args.autogen = False
def check_last_input(prompt):
if not prompt.endswith(" "):
return True
elif not args.autogen:
return True
else:
return False
auto_gen_kwargs = dict(
fn=check_last_input,
inputs=[negative_prompt],
outputs=[txt2img_sdxl_status],
concurrency_limit=1,
)
txt2img_sdxl_autogen.change(
fn=autogen_changed,
inputs=[txt2img_sdxl_autogen],
outputs=None,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
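
The gating in check_last_input reads inverted at first glance; a small truth-table check (plain asserts, not part of the diff) of what it returns:

args.autogen = True
assert check_last_input("a shark") is True     # no trailing space
assert check_last_input("a shark ") is False   # trailing space + autogen on
args.autogen = False
assert check_last_input("a shark ") is True    # autogen off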
@@ -538,7 +584,11 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
cancels=[
prompt_submit,
neg_prompt_submit,
generate_click,
],
)
txt2img_sdxl_png_info_img.change(
@@ -588,6 +638,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
width,
height,
custom_vae,
txt2img_sdxl_autogen,
],
)
lora_weights.change(

View File

@@ -282,6 +282,12 @@ def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
gr.update(),
gr.update(),
gr.update(),
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
]
@@ -317,7 +323,7 @@ default_configs = {
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="role-playing game (RPG) style fantasy, An enchanting image featuring an adorable kitten mage wearing intricate ancient robes, holding an ancient staff, hard at work in her fantastical workshop, magic runes floating in the air",
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
),
gr.Slider(0, 10, value=2),
gr.Dropdown(value="EulerAncestralDiscrete"),
@@ -325,6 +331,9 @@ default_configs = {
512,
512,
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate", visible=False, interactive=True, value=False
),
],
"stabilityai/stable-diffusion-xl-base-1.0": [
gr.Textbox(label="Prompt", interactive=True, visible=True),
@@ -332,9 +341,15 @@ default_configs = {
40,
"EulerDiscrete",
7.5,
gr.Slider(value=1024, interactive=False),
gr.Slider(value=1024, interactive=False),
gr.Slider(value=768, interactive=True),
gr.Slider(value=768, interactive=True),
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
],
}

View File

@@ -83,7 +83,7 @@ def clean_device_info(raw_device):
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = ""
device_id = None
if device in ["rocm", "vulkan"] and device_id is None:
device_id = 0
return device, device_id
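
Expected behavior after the fix, assuming raw device strings reduce to a bare backend name here (the exact raw_device format is not visible in this hunk):

# Illustrative asserts only; input format is an assumption.
assert clean_device_info("cpu") == ("cpu", None)     # was ("cpu", "")
assert clean_device_info("vulkan") == ("vulkan", 0)  # missing id defaults to 0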