mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 23:38:12 -05:00
Compare commits
3 Commits
diffusers-
...
20240530.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
151d3009bc | ||
|
|
fe142c8a0b | ||
|
|
6dfded3759 |
@@ -181,12 +181,17 @@ class StableDiffusion:
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
|
||||
def prepare_pipe(
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
weights = copy.deepcopy(self.model_map)
|
||||
if not self.is_sdxl:
|
||||
compiled_pipeline = False
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
|
||||
if custom_weights:
|
||||
custom_weights = os.path.join(
|
||||
@@ -253,7 +258,6 @@ class StableDiffusion:
|
||||
guidance_scale,
|
||||
seed,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
@@ -272,7 +276,7 @@ class StableDiffusion:
|
||||
def shark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
print("[LOG] Submitting Request...")
|
||||
print("\n[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
if sd_kwargs[key] in [None, []]:
|
||||
@@ -282,9 +286,8 @@ def shark_sd_fn_dict_input(
|
||||
if key == "seed":
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
|
||||
for i in range(1):
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
yield generated_imgs
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
return generated_imgs
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
@@ -307,7 +310,7 @@ def shark_sd_fn(
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
compiled_pipeline: bool,
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
@@ -370,6 +373,7 @@ def shark_sd_fn(
|
||||
"adapters": adapters,
|
||||
"embeddings": embeddings,
|
||||
"is_img2img": is_img2img,
|
||||
"compiled_pipeline": compiled_pipeline,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
@@ -379,7 +383,6 @@ def shark_sd_fn(
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
@@ -412,22 +415,35 @@ def shark_sd_fn(
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
if not isinstance(out_imgs, list):
|
||||
out_imgs = [out_imgs]
|
||||
# total_time = time.time() - start_time
|
||||
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
|
||||
# print(f"\n[LOG] {text_output}")
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
# break
|
||||
# else:
|
||||
save_output_img(
|
||||
out_imgs[current_batch],
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
for batch in range(batch_size):
|
||||
save_output_img(
|
||||
out_imgs[batch],
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# TODO: make seed changes over batch counts more configurable.
|
||||
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
return generated_imgs, ""
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
def unload_sd():
|
||||
print("Unloading models.")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
|
||||
@@ -19,6 +19,7 @@ from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.shark_studio.api.sd import (
|
||||
shark_sd_fn_dict_input,
|
||||
cancel_sd,
|
||||
unload_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
cnet_preview,
|
||||
@@ -119,7 +120,7 @@ def pull_sd_configs(
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
controlnets,
|
||||
embeddings,
|
||||
@@ -178,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["compiled_pipeline"],
|
||||
sd_json["resample_type"],
|
||||
sd_json["controlnets"],
|
||||
sd_json["embeddings"],
|
||||
@@ -587,21 +588,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
object_fit="fit",
|
||||
preview=True,
|
||||
)
|
||||
with gr.Row():
|
||||
std_output = gr.Textbox(
|
||||
value=f"{sd_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=True,
|
||||
label="Log",
|
||||
show_copy_button=True,
|
||||
)
|
||||
sd_element.load(
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
@@ -620,17 +606,15 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
cmd_opts.repeatable_seeds,
|
||||
label="Use Repeatable Seeds for Batches",
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Start")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
unload = gr.Button("Unload Models")
|
||||
unload.click(
|
||||
fn=unload_sd,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -701,7 +685,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
@@ -718,6 +702,22 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
with gr.Tab(label="Log", id=103) as sd_tab_log:
|
||||
with gr.Row():
|
||||
std_output = gr.Textbox(
|
||||
value=f"{sd_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=True,
|
||||
label="Log",
|
||||
show_copy_button=True,
|
||||
)
|
||||
sd_element.load(
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
|
||||
pull_kwargs = dict(
|
||||
fn=pull_sd_configs,
|
||||
@@ -741,7 +741,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
|
||||
Reference in New Issue
Block a user