mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 07:18:27 -05:00
Compare commits
12 Commits
debug
...
20240602.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64e63e7130 | ||
|
|
ea8738fb1a | ||
|
|
2a5bec3c4f | ||
|
|
bb58b01d75 | ||
|
|
02285b33a4 | ||
|
|
f9a1d35b59 | ||
|
|
b1ca19a6e6 | ||
|
|
b5dea85808 | ||
|
|
e75f96f2d7 | ||
|
|
bf67e2aa3b | ||
|
|
c088247aa1 | ||
|
|
42abc6787d |
@@ -25,6 +25,14 @@ def imports():
|
||||
)
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=FutureWarning, module="huggingface-hub"
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=UserWarning, module="huggingface-hub"
|
||||
)
|
||||
|
||||
import gradio # noqa: F401
|
||||
|
||||
|
||||
@@ -99,7 +99,10 @@ class StableDiffusion:
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(None, desc="Initializing pipeline...")
|
||||
self.ui_device = device
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
@@ -112,12 +115,15 @@ class StableDiffusion:
|
||||
"custom_pipeline",
|
||||
)
|
||||
self.turbine_pipe = custom_module.StudioPipeline
|
||||
self.dynamic_steps = custom_module.dynamic_steps
|
||||
self.model_map = custom_module.MODEL_MAP
|
||||
elif self.is_sdxl:
|
||||
self.turbine_pipe = SharkSDXLPipeline
|
||||
self.dynamic_steps = False
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
else:
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.dynamic_steps = True
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
@@ -178,13 +184,19 @@ class StableDiffusion:
|
||||
external_weights=external_weights,
|
||||
custom_vae=custom_vae,
|
||||
)
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
progress(None, desc="Pipeline initialized!...")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
self,
|
||||
custom_weights,
|
||||
adapters,
|
||||
embeddings,
|
||||
is_img2img,
|
||||
compiled_pipeline,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
@@ -236,17 +248,18 @@ class StableDiffusion:
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
weights[key] = save_irpa(vae_weights_path, "vae.")
|
||||
progress(None, desc=f"Preparing pipeline for {self.ui_device}...")
|
||||
|
||||
vmfbs, weights = self.sd_pipe.check_prepared(
|
||||
mlirs, vmfbs, weights, interactive=False
|
||||
)
|
||||
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
|
||||
progress(None, desc=f"Artifacts ready!")
|
||||
progress(None, desc=f"Loading pipeline on device {self.ui_device}...")
|
||||
|
||||
self.sd_pipe.load_pipeline(
|
||||
vmfbs, weights, self.rt_device, self.compiled_pipeline
|
||||
)
|
||||
print(
|
||||
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
|
||||
)
|
||||
progress(None, desc="Pipeline loaded! Generating images...")
|
||||
return
|
||||
|
||||
def generate_images(
|
||||
@@ -261,7 +274,9 @@ class StableDiffusion:
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -273,9 +288,7 @@ class StableDiffusion:
|
||||
return img
|
||||
|
||||
|
||||
def shark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
print("\n[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
@@ -306,7 +319,7 @@ def shark_sd_fn_dict_input(
|
||||
)
|
||||
return None, ""
|
||||
if sd_kwargs["target_triple"] == "":
|
||||
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
|
||||
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
|
||||
gr.Warning(
|
||||
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
|
||||
)
|
||||
@@ -340,6 +353,8 @@ def shark_sd_fn(
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
seed_increment: str | int = 1,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
sd_kwargs = locals()
|
||||
if not isinstance(sd_init_image, list):
|
||||
@@ -413,6 +428,9 @@ def shark_sd_fn(
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
}
|
||||
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
|
||||
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
|
||||
submit_pipe_kwargs.pop("steps")
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
@@ -438,6 +456,11 @@ def shark_sd_fn(
|
||||
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
|
||||
|
||||
generated_imgs = []
|
||||
if seed in [-1, "-1"]:
|
||||
seed = randint(0, 4294967295)
|
||||
print(f"\n[LOG] Random seed: {seed}")
|
||||
progress(None, desc=f"Generating...")
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
@@ -456,14 +479,20 @@ def shark_sd_fn(
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# TODO: make seed changes over batch counts more configurable.
|
||||
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
|
||||
seed = get_next_seed(seed, seed_increment)
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
def get_next_seed(seed, seed_increment):
|
||||
if isinstance(seed_increment, int):
|
||||
return int(seed + seed_increment)
|
||||
elif seed_increment == "random":
|
||||
return randint(0, sys.maxsize)
|
||||
|
||||
|
||||
def unload_sd():
|
||||
print("Unloading models.")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
@@ -54,31 +54,31 @@ def get_available_devices():
|
||||
available_devices = []
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
# cpu_device = get_devices_by_name("cpu-sync")
|
||||
# available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
# from shark.iree_utils.vulkan_utils import (
|
||||
# get_all_vulkan_devices,
|
||||
# )
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
# vulkaninfo_list = get_all_vulkan_devices()
|
||||
# vulkan_devices = []
|
||||
# id = 0
|
||||
# for device in vulkaninfo_list:
|
||||
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
# id += 1
|
||||
# if id != 0:
|
||||
# print(f"vulkan devices are available.")
|
||||
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
# available_devices.extend(vulkan_devices)
|
||||
# metal_devices = get_devices_by_name("metal")
|
||||
# available_devices.extend(metal_devices)
|
||||
# cuda_devices = get_devices_by_name("cuda")
|
||||
# available_devices.extend(cuda_devices)
|
||||
# hip_devices = get_devices_by_name("hip")
|
||||
# available_devices.extend(hip_devices)
|
||||
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
@@ -160,6 +160,8 @@ def parse_device(device_str, target_override=""):
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
if "cpu" in device_str:
|
||||
rt_device = "local-task"
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
@@ -169,7 +171,10 @@ def parse_device(device_str, target_override=""):
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "llvm-cpu":
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
if "Ryzen 9" in device_str:
|
||||
return target_backend, "local-task", "znver4"
|
||||
else:
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def get_rocm_target_chip(device_str):
|
||||
|
||||
@@ -71,7 +71,14 @@ def save_irpa(weights_path, prepend_str):
|
||||
new_key = prepend_str + key
|
||||
archive.add_tensor(new_key, weights[key])
|
||||
|
||||
irpa_file = weights_path.replace(".safetensors", ".irpa")
|
||||
if "safetensors" in weights_path:
|
||||
irpa_file = weights_path.replace(".safetensors", ".irpa")
|
||||
elif "irpa" in weights_path:
|
||||
irpa_file = weights_path
|
||||
else:
|
||||
return Exception(
|
||||
"Invalid file format. Please provide a .safetensors or .irpa file."
|
||||
)
|
||||
archive.save(irpa_file)
|
||||
return irpa_file
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
elif "progress" in extra_info.keys():
|
||||
extra_info.pop("progress")
|
||||
generated_imgs_path = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
|
||||
@@ -35,10 +35,7 @@ p.add_argument(
|
||||
"--prompt",
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smoke coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
@@ -62,7 +59,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
default=2,
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
@@ -100,7 +97,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
default=0,
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
@@ -346,7 +343,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
default=4,
|
||||
help="Number of batches to be generated with random seeds in " "single execution.",
|
||||
)
|
||||
|
||||
|
||||
20
apps/shark_studio/tools/params_prefixer.py
Normal file
20
apps/shark_studio/tools/params_prefixer.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from apps.shark_studio.modules.ckpt_processing import save_irpa
|
||||
import argparse
|
||||
import safetensors
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="",
|
||||
help="input safetensors/irpa",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="prefix to add to all the keys in the irpa",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
output_file = save_irpa(args.input, args.prefix)
|
||||
print("saved irpa to", output_file, "with prefix", args.prefix)
|
||||
@@ -194,7 +194,7 @@ def webui():
|
||||
sd_element.render()
|
||||
with gr.TabItem(label="Output Gallery", id=1):
|
||||
outputgallery_element.render()
|
||||
with gr.TabItem(label="Chat Bot", id=2):
|
||||
with gr.TabItem(label="Chat Bot", id=2, visible=False):
|
||||
chat_element.render()
|
||||
|
||||
studio_web.queue()
|
||||
|
||||
@@ -45,12 +45,13 @@ from apps.shark_studio.modules import logger
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
sd_default_models = [
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
# "runwayml/stable-diffusion-v1-5",
|
||||
# "stabilityai/stable-diffusion-2-1-base",
|
||||
# "stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
sd_default_models.extend(get_checkpoints(model_type="scripts"))
|
||||
|
||||
|
||||
def view_json_file(file_path):
|
||||
@@ -158,6 +159,8 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_image = [Image.open(i, mode="r")]
|
||||
else:
|
||||
sd_image = None
|
||||
if not sd_json["device"]:
|
||||
sd_json["device"] = gr.update()
|
||||
|
||||
return [
|
||||
sd_json["prompt"][0],
|
||||
@@ -165,7 +168,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_image,
|
||||
sd_json["height"],
|
||||
sd_json["width"],
|
||||
sd_json["steps"],
|
||||
gr.update(),
|
||||
sd_json["strength"],
|
||||
sd_json["guidance_scale"],
|
||||
sd_json["seed"],
|
||||
@@ -198,7 +201,7 @@ def save_sd_cfg(config: dict, save_name: str):
|
||||
filepath += ".json"
|
||||
with open(filepath, mode="w") as f:
|
||||
f.write(json.dumps(config))
|
||||
return "..."
|
||||
return save_name
|
||||
|
||||
|
||||
def create_canvas(width, height):
|
||||
@@ -235,11 +238,42 @@ def base_model_changed(base_model_id):
|
||||
new_choices = get_checkpoints(
|
||||
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
|
||||
) + get_checkpoints(model_type="checkpoints")
|
||||
if "turbo" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=cmd_opts.steps,
|
||||
choices=[1, 2, 3, 4],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=False,
|
||||
)
|
||||
if "stable-diffusion-xl-base-1.0" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=40,
|
||||
choices=[20, 25, 30, 35, 40, 45, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=False,
|
||||
)
|
||||
elif ".py" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 15, 20, 28],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
else:
|
||||
new_steps = gr.Dropdown(
|
||||
value=cmd_opts.steps,
|
||||
choices=[10, 20, 30, 40, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
return gr.Dropdown(
|
||||
value=new_choices[0] if len(new_choices) > 0 else "None",
|
||||
choices=["None"] + new_choices,
|
||||
)
|
||||
return [
|
||||
gr.Dropdown(
|
||||
value=new_choices[0] if len(new_choices) > 0 else "None",
|
||||
choices=["None"] + new_choices,
|
||||
),
|
||||
new_steps,
|
||||
]
|
||||
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
@@ -256,16 +290,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="target_triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
with gr.Row():
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="target_triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -281,27 +316,29 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
label="\U000026F0\U0000FE0F Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
value="stabilityai/sdxl-turbo",
|
||||
choices=sd_default_models,
|
||||
allow_custom_value=True,
|
||||
) # base_model_id
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
512,
|
||||
1024,
|
||||
value=cmd_opts.height,
|
||||
step=8,
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384,
|
||||
512,
|
||||
1024,
|
||||
value=cmd_opts.width,
|
||||
step=8,
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="\U00002696\U0000FE0F Model Weights", open=False
|
||||
label="\U00002696\U0000FE0F Model Weights",
|
||||
open=False,
|
||||
visible=False, # DEMO
|
||||
):
|
||||
with gr.Column():
|
||||
custom_weights = gr.Dropdown(
|
||||
@@ -313,11 +350,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=["None"]
|
||||
+ get_checkpoints(os.path.basename(str(base_model_id))),
|
||||
) # custom_weights
|
||||
base_model_id.change(
|
||||
fn=base_model_changed,
|
||||
inputs=[base_model_id],
|
||||
outputs=[custom_weights],
|
||||
)
|
||||
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
|
||||
"\\", "\n\\"
|
||||
)
|
||||
@@ -369,7 +401,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
|
||||
label="\U0001F9EA\U0000FE0F Input Image Processing",
|
||||
open=False,
|
||||
visible=False,
|
||||
):
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
@@ -403,24 +437,23 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
seed = gr.Textbox(
|
||||
value=cmd_opts.seed,
|
||||
label="\U0001F331\U0000FE0F Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
info="An integer, -1 for random",
|
||||
show_copy_button=True,
|
||||
)
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="\U0001F4C5\U0000FE0F Scheduler",
|
||||
info="\U000E0020", # forces same height as seed
|
||||
value="EulerDiscrete",
|
||||
value="EulerAncestralDiscrete",
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
steps = gr.Dropdown(
|
||||
value=cmd_opts.steps,
|
||||
step=1,
|
||||
choices=[1, 2, 3, 4],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
@@ -478,17 +511,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
minimum=512,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=8,
|
||||
step=512,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
minimum=512,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=8,
|
||||
step=512,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
@@ -558,7 +591,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
lambda: gr.Tabs(selected=101),
|
||||
outputs=[sd_tabs],
|
||||
)
|
||||
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
|
||||
with gr.Tab(
|
||||
label="Input Image", id=100, visible=False
|
||||
) as sd_tab_init_image: # DEMO
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
# TODO: make this import image prompt info if it exists
|
||||
@@ -604,10 +639,10 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
visible=True,
|
||||
visible=False, # DEMO
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
True,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
@@ -618,18 +653,18 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop")
|
||||
stop_batch = gr.Button("Stop", visible=False)
|
||||
with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
Path(get_configs_path()).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
write_default_sd_configs(get_configs_path())
|
||||
default_config_file = os.path.join(
|
||||
get_configs_path(),
|
||||
"default_sd_config.json",
|
||||
"sdxl-turbo.json",
|
||||
)
|
||||
write_default_sd_configs(get_configs_path())
|
||||
sd_json = gr.JSON(
|
||||
elem_classes=["fill"],
|
||||
value=view_json_file(default_config_file),
|
||||
@@ -644,7 +679,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
if cmd_opts.configs_path
|
||||
else get_configs_path()
|
||||
),
|
||||
height=75,
|
||||
height=200,
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_sd_config = gr.Button(
|
||||
@@ -655,13 +690,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
size="sm",
|
||||
components=sd_json,
|
||||
)
|
||||
with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
# with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
load_sd_config.change(
|
||||
fn=load_sd_cfg,
|
||||
inputs=[sd_json, load_sd_config],
|
||||
@@ -718,6 +753,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
base_model_id.change(
|
||||
fn=base_model_changed,
|
||||
inputs=[base_model_id],
|
||||
outputs=[custom_weights, steps],
|
||||
)
|
||||
|
||||
pull_kwargs = dict(
|
||||
fn=pull_sd_configs,
|
||||
@@ -749,6 +789,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
outputs=[
|
||||
sd_json,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
|
||||
@@ -89,7 +89,7 @@ sdxl_turbo = r"""{
|
||||
}"""
|
||||
|
||||
default_sd_configs = {
|
||||
"default_sd_config.json": default_sd_config,
|
||||
# "default_sd_config.json": sdxl_turbo,
|
||||
"sdxl-30steps.json": sdxl_30steps,
|
||||
"sdxl-turbo.json": sdxl_turbo,
|
||||
}
|
||||
|
||||
@@ -17,8 +17,9 @@ from apps.shark_studio.web.utils.default_configs import default_sd_configs
|
||||
def write_default_sd_configs(path):
|
||||
for key in default_sd_configs.keys():
|
||||
config_fpath = os.path.join(path, key)
|
||||
with open(config_fpath, "w") as f:
|
||||
f.write(default_sd_configs[key])
|
||||
if not os.path.exists(config_fpath):
|
||||
with open(config_fpath, "w") as f:
|
||||
f.write(default_sd_configs[key])
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
@@ -87,6 +88,8 @@ def get_checkpoints_path(model_type=""):
|
||||
def get_checkpoints(model_type="checkpoints"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model_type == "scripts":
|
||||
file_types = ["shark_*.py"]
|
||||
if model_type == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
|
||||
@@ -8,8 +8,8 @@ wheel
|
||||
|
||||
torch==2.3.0
|
||||
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@deprecated-constraints#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
|
||||
|
||||
# SHARK Runner
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install --pre -r requirements.txt
|
||||
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
|
||||
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
|
||||
pip install -e .
|
||||
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
Reference in New Issue
Block a user