From 54e57f7771de236e94d21792019fc925f4beb34b Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 25 May 2023 12:03:21 -0500 Subject: [PATCH] Revive SD downloads from shark_tank. (#1465) --- .../pipelines/pipeline_shark_stable_diffusion_utils.py | 10 +++++++--- apps/stable_diffusion/src/utils/utils.py | 1 - shark/shark_downloader.py | 8 ++++++++ tank_version.json | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index e08cd9e2..fcd0af2f 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -86,8 +86,10 @@ class StableDiffusionPipeline: self.text_encoder = self.sd_model.clip() else: try: + breakpoint() self.text_encoder = get_clip() - except: + except Exception as e: + print(e) print("download pipeline failed, falling back to import_mlir") self.text_encoder = self.sd_model.clip() @@ -104,7 +106,8 @@ class StableDiffusionPipeline: else: try: self.unet = get_unet() - except: + except Exception as e: + print(e) print("download pipeline failed, falling back to import_mlir") self.unet = self.sd_model.unet() @@ -121,7 +124,8 @@ class StableDiffusionPipeline: else: try: self.vae = get_vae() - except: + except Exception as e: + print(e) print("download pipeline failed, falling back to import_mlir") self.vae = self.sd_model.vae() diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index ef773354..b7726a51 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -83,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]): # Set local shark_tank cache directory. shark_args.local_tank_cache = args.local_tank_cache - from shark.shark_downloader import download_model if "cuda" in args.device: diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 657f271c..8005ecc1 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -61,6 +61,8 @@ def download_public_file( continue destination_filename = os.path.join(destination_folder_name, blob_name) + if os.path.isdir(destination_filename): + continue with open(destination_filename, "wb") as f: with tqdm.wrapattr(f, "write", total=blob.size) as file_obj: storage_client.download_blob_to_file(blob, file_obj) @@ -210,6 +212,9 @@ def download_model( + "_BS" + str(import_args["batch_size"]) ) + elif any(model in model_name for model in ["clip", "unet", "vae"]): + # TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation. + model_dir_name = model_name else: model_dir_name = model_name + "_" + frontend model_dir = os.path.join(WORKDIR, model_dir_name) @@ -270,6 +275,9 @@ def download_model( tuned_str = "" if tuned is None else "_" + tuned suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir" filename = os.path.join(model_dir, model_name + suffix) + print( + f"Verifying that model artifacts were downloaded successfully to {filename}..." + ) if not os.path.exists(filename): from tank.generate_sharktank import gen_shark_files diff --git a/tank_version.json b/tank_version.json index 6826ad37..8e3c5b59 100644 --- a/tank_version.json +++ b/tank_version.json @@ -1,3 +1,3 @@ { - "version": "2023-03-31_02d52bb" + "version": "nightly" }