mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add compiled pipeline option
This commit is contained in:
@@ -181,12 +181,17 @@ class StableDiffusion:
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
|
||||
def prepare_pipe(
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
weights = copy.deepcopy(self.model_map)
|
||||
if not self.is_sdxl:
|
||||
compiled_pipeline = False
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
|
||||
if custom_weights:
|
||||
custom_weights = os.path.join(
|
||||
@@ -253,7 +258,6 @@ class StableDiffusion:
|
||||
guidance_scale,
|
||||
seed,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
@@ -306,7 +310,7 @@ def shark_sd_fn(
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
compiled_pipeline: bool,
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
@@ -369,6 +373,7 @@ def shark_sd_fn(
|
||||
"adapters": adapters,
|
||||
"embeddings": embeddings,
|
||||
"is_img2img": is_img2img,
|
||||
"compiled_pipeline": compiled_pipeline,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
@@ -378,7 +383,6 @@ def shark_sd_fn(
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
|
||||
@@ -120,7 +120,7 @@ def pull_sd_configs(
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
controlnets,
|
||||
embeddings,
|
||||
@@ -179,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["compiled_pipeline"],
|
||||
sd_json["resample_type"],
|
||||
sd_json["controlnets"],
|
||||
sd_json["embeddings"],
|
||||
@@ -606,9 +606,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
cmd_opts.repeatable_seeds,
|
||||
label="Use Repeatable Seeds for Batches",
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Start")
|
||||
@@ -685,7 +685,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
@@ -741,7 +741,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
|
||||
Reference in New Issue
Block a user