Compare commits

...

19 Commits

Author SHA1 Message Date
Ean Garvey
5f1b5e58d6 igpu dont error on device parse fail 2024-06-02 12:56:44 -05:00
Ean Garvey
6adae49d9b igpu restrictions 2024-06-02 12:51:53 -05:00
Ean Garvey
6abd9ff5cf Reduce available step options for turbo. 2024-06-02 11:41:23 -05:00
Ean Garvey
9957c96014 More noticeable seed changes 2024-06-02 11:39:21 -05:00
Ean Garvey
36b8c2fd6d disable pndm 2024-06-02 11:30:18 -05:00
Ean Garvey
9163c1fc50 small fixes 2024-06-02 11:28:37 -05:00
Ean Garvey
349e9f70fb Progress indicators 2024-06-02 10:18:09 -05:00
Ean Garvey
64e63e7130 znver4 device handling 2024-06-02 10:08:00 -05:00
Ean Garvey
ea8738fb1a Update SRT links 2024-06-02 09:50:09 -05:00
Ean Garvey
2a5bec3c4f Fixes for seed. 2024-06-02 09:46:22 -05:00
Ean Garvey
bb58b01d75 Switch to fixed steps, tweak config loading to prevent race condition 2024-06-01 20:15:53 -05:00
Ean Garvey
02285b33a4 More fixes for demo. 2024-06-01 19:46:52 -05:00
Ean Garvey
f9a1d35b59 Hide chatbot. 2024-06-01 14:24:37 -05:00
Ean Garvey
b1ca19a6e6 Cleanup for demo. 2024-06-01 13:42:51 -05:00
Ean Garvey
b5dea85808 Reduce UI for demos. 2024-06-01 12:00:22 -05:00
Ean Garvey
e75f96f2d7 fixup conditional 2024-06-01 12:00:11 -05:00
Ean Garvey
bf67e2aa3b Formatting 2024-06-01 11:59:10 -05:00
Ean Garvey
c088247aa1 Fix default configs, config loading, and add warnings/early returns for bad configs. 2024-06-01 11:58:51 -05:00
Ean Garvey
42abc6787d Small tweaks to ckpt processing, add tool to prefix params keys 2024-06-01 11:53:40 -05:00
13 changed files with 229 additions and 105 deletions

View File

@@ -25,6 +25,14 @@ def imports():
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(
action="ignore", category=FutureWarning, module="huggingface-hub"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, module="huggingface-hub"
)
import gradio # noqa: F401

View File

@@ -99,7 +99,10 @@ class StableDiffusion:
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
progress=gr.Progress(),
):
progress(0, desc="Initializing pipeline...")
self.ui_device = device
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
@@ -112,12 +115,15 @@ class StableDiffusion:
"custom_pipeline",
)
self.turbine_pipe = custom_module.StudioPipeline
self.dynamic_steps = custom_module.dynamic_steps
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
self.turbine_pipe = SharkSDXLPipeline
self.dynamic_steps = False
self.model_map = EMPTY_SDXL_MAP
else:
self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD_MAP
max_length = 64
target_backend, self.rt_device, triple = parse_device(device, target_triple)
@@ -158,7 +164,7 @@ class StableDiffusion:
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
progress(0.5, desc="Initializing pipeline...")
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
@@ -178,13 +184,20 @@ class StableDiffusion:
external_weights=external_weights,
custom_vae=custom_vae,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
progress(1, desc="Pipeline initialized!...")
gc.collect()
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
self,
custom_weights,
adapters,
embeddings,
is_img2img,
compiled_pipeline,
progress=gr.Progress(),
):
print(f"\n[LOG] Preparing pipeline...")
progress(0, desc="Preparing models...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
@@ -236,17 +249,18 @@ class StableDiffusion:
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
progress(0.5, desc=f"Artifacts ready!")
progress(0.75, desc=f"Loading models and weights...")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
progress(1, desc="Pipeline loaded! Generating images...")
return
def generate_images(
@@ -261,7 +275,9 @@ class StableDiffusion:
resample_type,
control_mode,
hints,
progress=gr.Progress(),
):
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
@@ -273,9 +289,7 @@ class StableDiffusion:
return img
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
@@ -283,6 +297,8 @@ def shark_sd_fn_dict_input(
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
sd_kwargs[key] = int(sd_kwargs[key])
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
@@ -306,7 +322,7 @@ def shark_sd_fn_dict_input(
)
return None, ""
if sd_kwargs["target_triple"] == "":
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
@@ -340,6 +356,8 @@ def shark_sd_fn(
resample_type: str,
controlnets: dict,
embeddings: dict,
seed_increment: str | int = 1,
progress=gr.Progress(),
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
@@ -413,6 +431,9 @@ def shark_sd_fn(
"control_mode": control_mode,
"hints": hints,
}
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
submit_pipe_kwargs.pop("steps")
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
@@ -438,6 +459,12 @@ def shark_sd_fn(
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
if seed in [-1, "-1"]:
seed = randint(0, 4294967295)
seed_increment = "random"
print(f"\n[LOG] Random seed: {seed}")
progress(None, desc=f"Generating...")
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
@@ -456,14 +483,23 @@ def shark_sd_fn(
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
seed = get_next_seed(seed, seed_increment)
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return (generated_imgs, "")
def get_next_seed(seed, seed_increment: str | int = 10):
if isinstance(seed_increment, int):
print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
return int(seed + seed_increment)
elif seed_increment == "random":
seed = randint(0, 4294967295)
print(f"\n[LOG] Random seed: {seed}")
return seed
def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj

View File

@@ -54,31 +54,31 @@ def get_available_devices():
available_devices = []
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
# cpu_device = get_devices_by_name("cpu-sync")
# available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
# from shark.iree_utils.vulkan_utils import (
# get_all_vulkan_devices,
# )
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
# vulkaninfo_list = get_all_vulkan_devices()
# vulkan_devices = []
# id = 0
# for device in vulkaninfo_list:
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
# id += 1
# if id != 0:
# print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
hip_devices = get_devices_by_name("hip")
available_devices.extend(hip_devices)
# available_devices.extend(vulkan_devices)
# metal_devices = get_devices_by_name("metal")
# available_devices.extend(metal_devices)
# cuda_devices = get_devices_by_name("cuda")
# available_devices.extend(cuda_devices)
# hip_devices = get_devices_by_name("hip")
# available_devices.extend(hip_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
@@ -160,6 +160,8 @@ def parse_device(device_str, target_override=""):
rt_device = rt_driver
if target_override:
if "cpu" in device_str:
rt_device = "local-task"
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
@@ -169,7 +171,10 @@ def parse_device(device_str, target_override=""):
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "llvm-cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
if "Ryzen 9" in device_str:
return target_backend, "local-task", "znver4"
else:
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
@@ -191,9 +196,10 @@ def get_rocm_target_chip(device_str):
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
raise AssertionError(
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
)
return None
# raise AssertionError(
# f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
# )
def get_all_devices(driver_name):

View File

@@ -71,7 +71,14 @@ def save_irpa(weights_path, prepend_str):
new_key = prepend_str + key
archive.add_tensor(new_key, weights[key])
irpa_file = weights_path.replace(".safetensors", ".irpa")
if "safetensors" in weights_path:
irpa_file = weights_path.replace(".safetensors", ".irpa")
elif "irpa" in weights_path:
irpa_file = weights_path
else:
return Exception(
"Invalid file format. Please provide a .safetensors or .irpa file."
)
archive.save(irpa_file)
return irpa_file

View File

@@ -33,6 +33,8 @@ def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
elif "progress" in extra_info.keys():
extra_info.pop("progress")
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)

View File

@@ -101,7 +101,7 @@ def export_scheduler_model(model):
scheduler_model_map = {
"PNDM": export_scheduler_model("PNDMScheduler"),
# "PNDM": export_scheduler_model("PNDMScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),

View File

@@ -35,10 +35,7 @@ p.add_argument(
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
],
help="Text of which images to be generated.",
)
@@ -62,7 +59,7 @@ p.add_argument(
p.add_argument(
"--steps",
type=int,
default=50,
default=2,
help="The number of steps to do the sampling.",
)
@@ -100,7 +97,7 @@ p.add_argument(
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
default=0,
help="The value to be used for guidance scaling.",
)
@@ -346,7 +343,7 @@ p.add_argument(
p.add_argument(
"--batch_count",
type=int,
default=1,
default=4,
help="Number of batches to be generated with random seeds in " "single execution.",
)

View File

@@ -0,0 +1,20 @@
from apps.shark_studio.modules.ckpt_processing import save_irpa
import argparse
import safetensors
parser = argparse.ArgumentParser()
parser.add_argument(
"--input",
type=str,
default="",
help="input safetensors/irpa",
)
parser.add_argument(
"--prefix",
type=str,
default="",
help="prefix to add to all the keys in the irpa",
)
args = parser.parse_args()
output_file = save_irpa(args.input, args.prefix)
print("saved irpa to", output_file, "with prefix", args.prefix)

View File

@@ -194,7 +194,7 @@ def webui():
sd_element.render()
with gr.TabItem(label="Output Gallery", id=1):
outputgallery_element.render()
with gr.TabItem(label="Chat Bot", id=2):
with gr.TabItem(label="Chat Bot", id=2, visible=False):
chat_element.render()
studio_web.queue()

View File

@@ -44,13 +44,15 @@ from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
# Disabled some models for demo purposes
sd_default_models = [
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-base-1.0",
# "runwayml/stable-diffusion-v1-5",
# "stabilityai/stable-diffusion-2-1-base",
# "stabilityai/stable-diffusion-2-1",
# "stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
sd_default_models.extend(get_checkpoints(model_type="scripts"))
def view_json_file(file_path):
@@ -158,6 +160,8 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_image = [Image.open(i, mode="r")]
else:
sd_image = None
if not sd_json["device"]:
sd_json["device"] = gr.update()
return [
sd_json["prompt"][0],
@@ -165,7 +169,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_image,
sd_json["height"],
sd_json["width"],
sd_json["steps"],
gr.update(),
sd_json["strength"],
sd_json["guidance_scale"],
sd_json["seed"],
@@ -198,7 +202,7 @@ def save_sd_cfg(config: dict, save_name: str):
filepath += ".json"
with open(filepath, mode="w") as f:
f.write(json.dumps(config))
return "..."
return save_name
def create_canvas(width, height):
@@ -235,11 +239,42 @@ def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
if "turbo" in base_model_id:
new_steps = gr.Dropdown(
value=cmd_opts.steps,
choices=[1, 2, 3],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=False,
)
if "stable-diffusion-xl-base-1.0" in base_model_id:
new_steps = gr.Dropdown(
value=40,
choices=[20, 25, 30, 35, 40, 45, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=False,
)
elif ".py" in base_model_id:
new_steps = gr.Dropdown(
value=20,
choices=[10, 15, 20, 28],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
else:
new_steps = gr.Dropdown(
value=cmd_opts.steps,
choices=[10, 20, 30, 40, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
)
return [
gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
),
new_steps,
]
with gr.Blocks(title="Stable Diffusion") as sd_element:
@@ -256,16 +291,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
visible=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
)
precision = gr.Radio(
label="Precision",
@@ -281,27 +317,33 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
value="stabilityai/sdxl-turbo",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
512,
1024,
value=cmd_opts.height,
step=8,
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
interactive=False, # DEMO
visible=False, # DEMO
)
width = gr.Slider(
384,
512,
1024,
value=cmd_opts.width,
step=8,
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
visible=False, # DEMO
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights", open=False
label="\U00002696\U0000FE0F Model Weights",
open=False,
visible=False, # DEMO
):
with gr.Column():
custom_weights = gr.Dropdown(
@@ -313,11 +355,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
@@ -369,7 +406,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
label="\U0001F9EA\U0000FE0F Input Image Processing",
open=False,
visible=False,
):
strength = gr.Slider(
0,
@@ -403,24 +442,23 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
info="An integer, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
value="EulerAncestralDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
steps = gr.Dropdown(
value=cmd_opts.steps,
step=1,
choices=[1, 2, 3, 4],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
@@ -478,17 +516,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
minimum=512,
maximum=1024,
value=512,
step=8,
step=512,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
minimum=512,
maximum=1024,
value=512,
step=8,
step=512,
)
make_canvas = gr.Button(
value="Make Canvas!",
@@ -558,7 +596,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
with gr.Tab(
label="Input Image", id=100, visible=False
) as sd_tab_init_image: # DEMO
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
@@ -604,10 +644,10 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
step=1,
label="Batch Size",
interactive=True,
visible=True,
visible=False, # DEMO
)
compiled_pipeline = gr.Checkbox(
False,
True,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
@@ -618,18 +658,18 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
stop_batch = gr.Button("Stop", visible=False)
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
write_default_sd_configs(get_configs_path())
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
"sdxl-turbo.json",
)
write_default_sd_configs(get_configs_path())
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
@@ -644,7 +684,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
height=200,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
@@ -655,13 +695,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
# with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
@@ -718,6 +758,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights, steps],
)
pull_kwargs = dict(
fn=pull_sd_configs,

View File

@@ -89,7 +89,7 @@ sdxl_turbo = r"""{
}"""
default_sd_configs = {
"default_sd_config.json": default_sd_config,
# "default_sd_config.json": sdxl_turbo,
"sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}

View File

@@ -17,8 +17,9 @@ from apps.shark_studio.web.utils.default_configs import default_sd_configs
def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
if not os.path.exists(config_fpath):
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
def safe_name(name):
@@ -87,6 +88,8 @@ def get_checkpoints_path(model_type=""):
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model_type == "scripts":
file_types = ["shark_*.py"]
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install --pre -r requirements.txt
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
pip install -e .
Write-Host "Source your venv with ./shark.venv/Scripts/activate"