Progress indicators

This commit is contained in:
Ean Garvey
2024-06-02 10:18:09 -05:00
parent 64e63e7130
commit 349e9f70fb
2 changed files with 8 additions and 8 deletions

View File

@@ -101,7 +101,7 @@ class StableDiffusion:
external_weights: str = "safetensors",
progress=gr.Progress(),
):
progress(None, desc="Initializing pipeline...")
progress(0, desc="Initializing pipeline...")
self.ui_device = device
self.precision = precision
self.compiled_pipeline = False
@@ -164,7 +164,7 @@ class StableDiffusion:
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
progress(0.5, desc="Initializing pipeline...")
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
@@ -184,7 +184,7 @@ class StableDiffusion:
external_weights=external_weights,
custom_vae=custom_vae,
)
progress(None, desc="Pipeline initialized!...")
progress(1, desc="Pipeline initialized!...")
gc.collect()
def prepare_pipe(
@@ -196,6 +196,7 @@ class StableDiffusion:
compiled_pipeline,
progress=gr.Progress(),
):
progress(0, desc="Preparing models...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
@@ -248,18 +249,18 @@ class StableDiffusion:
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
progress(None, desc=f"Preparing pipeline for {self.ui_device}...")
progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
progress(None, desc=f"Artifacts ready!")
progress(None, desc=f"Loading pipeline on device {self.ui_device}...")
progress(.5, desc=f"Artifacts ready!")
progress(0.75, desc=f"Loading models and weights...")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
progress(None, desc="Pipeline loaded! Generating images...")
progress(1, desc="Pipeline loaded! Generating images...")
return
def generate_images(

View File

@@ -789,7 +789,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
outputs=[
sd_json,
],
show_progress=False,
)
status_kwargs = dict(