mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
Fixes to llama2 cpu compilation and studio UI, schedulers (#2013)
* Fix some issues with defaults
* Fixes to llama2 cpu compilation (turns off data tiling for old argmax mode)

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
@@ -1,7 +1,7 @@
-from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
 from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
     SharkEulerDiscreteScheduler,
 )
 from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
     SharkEulerAncestralDiscreteScheduler,
 )
+from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers

@@ -1,4 +1,5 @@
 from diffusers import (
+    LCMScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     DDPMScheduler,

@@ -25,7 +26,6 @@ def get_schedulers(model_id):
     # set batch_size here, the SHARK schedulers will
     # compile with batch size = 1 regardless of whether the model
     # outputs latents of a larger batch size, e.g. SDXL.
     # This also goes towards enabling batch size cfg for SD in general.
     # However, obviously, searching for whether the base model ID
     # contains "xl" is not very robust.
-

@@ -52,6 +52,10 @@ def get_schedulers(model_id):
         model_id,
         subfolder="scheduler",
     )
+    schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
+    )
     schedulers[
         "DPMSolverMultistep"
     ] = DPMSolverMultistepScheduler.from_pretrained(

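For context, `get_schedulers` builds a name-to-scheduler map that the UI dropdowns index into by name (the "LCMScheduler" entry added above matches the dropdown choice added later in this commit). A minimal sketch of the pattern, assuming only what the hunk shows; `LCMScheduler.from_pretrained` with the model repo's `scheduler` subfolder is the standard diffusers API, while the wrapper function below is illustrative:

from diffusers import LCMScheduler

def get_schedulers_sketch(model_id):
    # Illustrative stand-in for the real get_schedulers: build the
    # name -> scheduler-instance registry once per base model.
    schedulers = {}
    # from_pretrained loads the scheduler config shipped in the
    # model repo's "scheduler" subfolder.
    schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    return schedulers

# Hypothetical lookup, mirroring how the UI resolves a dropdown choice:
# scheduler = get_schedulers_sketch("stabilityai/stable-diffusion-xl-base-1.0")["LCMScheduler"]
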
@@ -467,6 +467,13 @@ p.add_argument(
     "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
     "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
 )
+
+p.add_argument(
+    "--autogen",
+    type=bool,
+    default="False",
+    help="Only used for a gradio workaround.",
+)
 ##############################################################################
 # IREE - Vulkan supported flags
 ##############################################################################

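One caveat with the new `--autogen` flag as committed: argparse does not parse booleans with `type=bool`. Any non-empty string converts to `True`, and argparse applies `type` to string defaults as well, so `bool("False")` makes the effective default truthy. A hedged sketch of the usual workaround, not part of this commit:

import argparse

def str2bool(v):
    # argparse applies `type` to string defaults too, and
    # bool("False") is True, so convert the text explicitly.
    if isinstance(v, bool):
        return v
    return v.strip().lower() in ("1", "t", "true", "yes")

p = argparse.ArgumentParser()
p.add_argument(
    "--autogen",
    type=str2bool,
    default=False,
    help="Only used for a gradio workaround.",
)
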
@@ -97,8 +97,6 @@ if __name__ == "__main__":
         )
         return os.path.join(base_path, relative_path)
 
-    dark_theme = resource_path("ui/css/sd_dark_theme.css")
-
     from apps.stable_diffusion.web.ui import (
         txt2img_web,
         txt2img_custom_model,

@@ -210,6 +208,8 @@ if __name__ == "__main__":
             outputs,
         )
 
+    dark_theme = resource_path("ui/css/sd_dark_theme.css")
+
     with gr.Blocks(
         css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
     ) as sd_web:

@@ -255,10 +255,10 @@ if __name__ == "__main__":
                 # lora_train_web.render()
                 with gr.TabItem(label="Chat Bot", id=8):
                     stablelm_chat.render()
-                with gr.TabItem(
-                    label="Generate Sharding Config (Experimental)", id=9
-                ):
-                    model_config_web.render()
+                # with gr.TabItem(
+                #     label="Generate Sharding Config (Experimental)", id=9
+                # ):
+                #     model_config_web.render()
                 # with gr.TabItem(label="MultiModal (Experimental)", id=10):
                 #     minigpt4_web.render()
                 # with gr.TabItem(label="DocuChat Upload", id=11):

@@ -405,6 +405,12 @@ if __name__ == "__main__":
             [outputgallery_filename],
             [upscaler_init_image, tabs],
         )
+        register_outputgallery_button(
+            outputgallery_sendto_txt2img_sdxl,
+            0,
+            [outputgallery_filename],
+            [txt2img_sdxl_png_info_img, tabs],
+        )
         register_modelmanager_button(
             modelmanager_sendto_txt2img,
             0,

@@ -173,6 +173,13 @@ def chat(
         get_vulkan_target_triple,
     )
 
+    _extra_args = _extra_args + [
+        "--iree-global-opt-enable-quantized-matmul-reassociation",
+        "--iree-llvmcpu-enable-quantized-matmul-reassociation",
+        "--iree-opt-const-eval=false",
+        "--iree-opt-data-tiling=false",
+    ]
+
     if device == "vulkan":
         vulkaninfo_list = get_all_vulkan_devices()
         if vulkan_target_triple == "":

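These are the llama2 CPU fixes from the commit message: the reassociation flags keep quantized matmuls on their fast path, and `--iree-opt-data-tiling=false` turns data tiling off for the old argmax mode. The list is ordinary `iree-compile` options; a minimal sketch of how such a list is forwarded, assuming the `iree.compiler` Python package rather than SHARK's own wrappers:

import iree.compiler as ireec

extra_args = [
    "--iree-global-opt-enable-quantized-matmul-reassociation",
    "--iree-llvmcpu-enable-quantized-matmul-reassociation",
    "--iree-opt-const-eval=false",
    "--iree-opt-data-tiling=false",  # the llama2 CPU / old-argmax fix
]

# Placeholder module; real use passes the model's MLIR text.
mlir_source = "module {}"

# compile_str forwards extra_args verbatim to the underlying
# iree-compile invocation and returns a compiled VMFB blob.
vmfb = ireec.compile_str(
    mlir_source,
    target_backends=["llvm-cpu"],
    extra_args=extra_args,
)
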
@@ -250,10 +257,11 @@ def chat(
     total_time_ms = 0.001  # In order to avoid divide by zero error
     prefill_time = 0
     is_first = True
-    for text, msg, exec_time in progress.tqdm(
-        vicuna_model.generate(prompt, cli=cli),
-        desc="generating response",
-    ):
+    # for text, msg, exec_time in progress.tqdm(
+    #     vicuna_model.generate(prompt, cli=cli),
+    #     desc="generating response",
+    # ):
+    for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
         if msg is None:
             if is_first:
                 prefill_time = exec_time

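Around this loop the handler splits timing into prefill (the first `exec_time`, covering prompt processing) and decode. Assuming `exec_time` is reported in milliseconds per step, the throughput arithmetic looks roughly like this sketch; `vicuna_model`, `prompt`, and `cli` come from the surrounding code, and the token accounting is illustrative:

prefill_time = 0
total_time_ms = 0.001  # seeded > 0 so a later division can never hit zero
token_count = 0
is_first = True

for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
    if msg is None:
        if is_first:
            prefill_time = exec_time  # first step = prompt prefill
            is_first = False
        else:
            total_time_ms += exec_time
            token_count += 1

# Decode throughput, excluding the prefill step.
tokens_per_sec = token_count / (total_time_ms / 1000.0)
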
@@ -226,7 +226,12 @@ def txt2img_sdxl_inf(
     return generated_imgs, text_output, ""
 
 
-with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
+theme = gr.themes.Glass(
+    primary_hue="slate",
+    secondary_hue="gray",
+)
+
+with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
     with gr.Row(elem_id="ui_title"):
         nod_logo = Image.open(nodlogo_loc)
     with gr.Row():

@@ -288,17 +293,24 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
             )
 
         with gr.Group(elem_id="prompt_box_outer"):
+            txt2img_sdxl_autogen = gr.Checkbox(
+                label="Auto-Generate Images",
+                value=False,
+                visible=False,
+            )
             prompt = gr.Textbox(
                 label="Prompt",
                 value=args.prompts[0],
                 lines=2,
                 elem_id="prompt_box",
+                show_copy_button=True,
             )
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
                 value=args.negative_prompts[0],
                 lines=2,
                 elem_id="negative_prompt_box",
+                show_copy_button=True,
             )
         with gr.Accordion(label="LoRA Options", open=False):
             with gr.Row():

@@ -340,6 +352,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                     "DDIM",
                     "EulerAncestralDiscrete",
                     "EulerDiscrete",
+                    "LCMScheduler",
                 ],
                 allow_custom_value=False,
                 visible=True,

@@ -360,7 +373,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 512,
                 1024,
                 value=1024,
-                step=512,
+                step=256,
                 label="Height",
                 visible=True,
                 interactive=True,

@@ -369,7 +382,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 512,
                 1024,
                 value=1024,
-                step=512,
+                step=256,
                 label="Width",
                 visible=True,
                 interactive=True,

@@ -379,7 +392,6 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 value="fp16",
                 choices=[
                     "fp16",
                     "fp32",
                 ],
                 visible=False,
             )

@@ -427,12 +439,14 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 value=args.batch_size,
                 step=1,
                 label="Batch Size",
-                interactive=True,
+                interactive=False,
+                visible=False,
             )
             repeatable_seeds = gr.Checkbox(
                 args.repeatable_seeds,
                 label="Repeatable Seeds",
             )
+
         with gr.Row():
             seed = gr.Textbox(
                 value=args.seed,

@@ -484,16 +498,20 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 stop_batch = gr.Button("Stop Batch")
             with gr.Row():
                 txt2img_sdxl_sendto_img2img = gr.Button(
-                    value="Send To Img2Img"
+                    value="Send To Img2Img",
+                    visible=False,
                 )
                 txt2img_sdxl_sendto_inpaint = gr.Button(
-                    value="Send To Inpaint"
+                    value="Send To Inpaint",
+                    visible=False,
                 )
                 txt2img_sdxl_sendto_outpaint = gr.Button(
-                    value="Send To Outpaint"
+                    value="Send To Outpaint",
+                    visible=False,
                 )
                 txt2img_sdxl_sendto_upscaler = gr.Button(
-                    value="Send To Upscaler"
+                    value="Send To Upscaler",
+                    visible=False,
                 )
 
     kwargs = dict(

@@ -523,14 +541,42 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
         ],
         outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
         show_progress="minimal" if args.progress_bar else "none",
+        queue=True,
     )
 
     status_kwargs = dict(
         fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
         inputs=[batch_count, batch_size],
         outputs=txt2img_sdxl_status,
+        concurrency_limit=1,
    )
 
+    def autogen_changed(checked):
+        if checked:
+            args.autogen = True
+        else:
+            args.autogen = False
+
+    def check_last_input(prompt):
+        if not prompt.endswith(" "):
+            return True
+        elif not args.autogen:
+            return True
+        else:
+            return False
+
+    auto_gen_kwargs = dict(
+        fn=check_last_input,
+        inputs=[negative_prompt],
+        outputs=[txt2img_sdxl_status],
+        concurrency_limit=1,
+    )
+
+    txt2img_sdxl_autogen.change(
+        fn=autogen_changed,
+        inputs=[txt2img_sdxl_autogen],
+        outputs=None,
+    )
     prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
     neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
         **kwargs

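`auto_gen_kwargs` is defined here but its attachment point falls outside this excerpt; presumably it hangs off a text-change event so that, with Auto-Generate enabled, a prompt ending in a space triggers a run (`check_last_input` returns `False` only in that case). A hypothetical wiring, not shown in the diff:

# Hypothetical: re-evaluate check_last_input on every edit of the
# negative prompt box (matching the inputs= above). The actual event
# binding is outside this hunk.
negative_prompt.change(**auto_gen_kwargs)
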
@@ -538,7 +584,11 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
     generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
     stop_batch.click(
         fn=cancel_sd,
-        cancels=[prompt_submit, neg_prompt_submit, generate_click],
+        cancels=[
+            prompt_submit,
+            neg_prompt_submit,
+            generate_click,
+        ],
     )
 
     txt2img_sdxl_png_info_img.change(

@@ -588,6 +638,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
             width,
             height,
             custom_vae,
+            txt2img_sdxl_autogen,
         ],
     )
     lora_weights.change(

@@ -282,6 +282,12 @@ def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
         gr.update(),
         gr.update(),
         gr.update(),
+        gr.Checkbox(
+            label="Auto-Generate",
+            visible=False,
+            interactive=False,
+            value=False,
+        ),
     ]

@@ -317,7 +323,7 @@ default_configs = {
         gr.Textbox(label="", interactive=False, value=None, visible=False),
         gr.Textbox(
             label="Prompt",
-            value="role-playing game (RPG) style fantasy, An enchanting image featuring an adorable kitten mage wearing intricate ancient robes, holding an ancient staff, hard at work in her fantastical workshop, magic runes floating in the air",
+            value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
         ),
         gr.Slider(0, 10, value=2),
         gr.Dropdown(value="EulerAncestralDiscrete"),

@@ -325,6 +331,9 @@ default_configs = {
         512,
         512,
         "madebyollin/sdxl-vae-fp16-fix",
+        gr.Checkbox(
+            label="Auto-Generate", visible=False, interactive=True, value=False
+        ),
     ],
     "stabilityai/stable-diffusion-xl-base-1.0": [
         gr.Textbox(label="Prompt", interactive=True, visible=True),

@@ -332,9 +341,15 @@ default_configs = {
         40,
         "EulerDiscrete",
         7.5,
-        gr.Slider(value=1024, interactive=False),
-        gr.Slider(value=1024, interactive=False),
+        gr.Slider(value=768, interactive=True),
+        gr.Slider(value=768, interactive=True),
         "madebyollin/sdxl-vae-fp16-fix",
+        gr.Checkbox(
+            label="Auto-Generate",
+            visible=False,
+            interactive=False,
+            value=False,
+        ),
     ],
 }

@@ -83,7 +83,7 @@ def clean_device_info(raw_device):
         device_id = int(device_id)
 
     if device not in ["rocm", "vulkan"]:
-        device_id = ""
+        device_id = None
     if device in ["rocm", "vulkan"] and device_id == None:
         device_id = 0
     return device, device_id
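The fix distinguishes "no device index applies" (`None` for CPU-like devices, where an empty string previously leaked into device strings) from "default to index 0" for `rocm`/`vulkan`. A hedged usage sketch; the `raw_device` format (e.g. "vulkan://1") is assumed from the surrounding parsing code:

# Assumed behavior after the fix (illustrative, not from the diff):
device, device_id = clean_device_info("vulkan://1")  # ("vulkan", 1)
device, device_id = clean_device_info("vulkan")      # ("vulkan", 0)
device, device_id = clean_device_info("cpu")         # ("cpu", None)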