Add compiled pipeline option

This commit is contained in:
Ean Garvey
2024-05-30 15:03:58 -04:00
parent fe142c8a0b
commit 151d3009bc
2 changed files with 15 additions and 11 deletions

View File

@@ -181,12 +181,17 @@ class StableDiffusion:
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
if custom_weights:
custom_weights = os.path.join(
@@ -253,7 +258,6 @@ class StableDiffusion:
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
@@ -306,7 +310,7 @@ def shark_sd_fn(
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
@@ -369,6 +373,7 @@ def shark_sd_fn(
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
@@ -378,7 +383,6 @@ def shark_sd_fn(
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,

View File

@@ -120,7 +120,7 @@ def pull_sd_configs(
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
@@ -179,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
@@ -606,9 +606,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Use Repeatable Seeds for Batches",
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
@@ -685,7 +685,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
@@ -741,7 +741,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,