Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-06 20:43:50 -05:00)
Revive SD downloads from shark_tank. (#1465)
@@ -86,8 +86,10 @@ class StableDiffusionPipeline:
             self.text_encoder = self.sd_model.clip()
         else:
             try:
+                breakpoint()
                 self.text_encoder = get_clip()
-            except:
+            except Exception as e:
+                print(e)
                 print("download pipeline failed, falling back to import_mlir")
                 self.text_encoder = self.sd_model.clip()
@@ -104,7 +106,8 @@ class StableDiffusionPipeline:
         else:
             try:
                 self.unet = get_unet()
-            except:
+            except Exception as e:
+                print(e)
                 print("download pipeline failed, falling back to import_mlir")
                 self.unet = self.sd_model.unet()
@@ -121,7 +124,8 @@ class StableDiffusionPipeline:
         else:
             try:
                 self.vae = get_vae()
-            except:
+            except Exception as e:
+                print(e)
                 print("download pipeline failed, falling back to import_mlir")
                 self.vae = self.sd_model.vae()
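The three hunks above apply the same change to the CLIP, UNet, and VAE submodels: try to fetch a prebuilt artifact from shark_tank, and on any failure print the exception and fall back to compiling locally through import_mlir. A minimal sketch of that pattern, assuming a hypothetical load_submodel helper that stands in for the repeated try/except blocks:

    def load_submodel(fetch_from_tank, compile_locally):
        # Prefer a cached shark_tank artifact; fall back to a local import_mlir build.
        try:
            return fetch_from_tank()
        except Exception as e:
            print(e)
            print("download pipeline failed, falling back to import_mlir")
            return compile_locally()

    # Hypothetical usage mirroring the diff:
    # self.text_encoder = load_submodel(get_clip, self.sd_model.clip)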
@@ -83,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
 
     # Set local shark_tank cache directory.
     shark_args.local_tank_cache = args.local_tank_cache
 
     from shark.shark_downloader import download_model
 
     if "cuda" in args.device:
@@ -61,6 +61,8 @@ def download_public_file(
             continue
 
         destination_filename = os.path.join(destination_folder_name, blob_name)
+        if os.path.isdir(destination_filename):
+            continue
         with open(destination_filename, "wb") as f:
             with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
                 storage_client.download_blob_to_file(blob, file_obj)
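The two added lines skip blobs whose local destination already exists as a directory, which would otherwise make the open(..., "wb") below fail. For reference, a hedged, self-contained sketch of this kind of download loop (anonymous GCS client for the public bucket, per-blob progress via tqdm.wrapattr); it is not the repository's exact function:

    import os

    from google.cloud import storage
    from tqdm import tqdm

    def download_public_folder(bucket_name, prefix, destination_folder_name):
        # An anonymous client suffices for a public bucket such as the shark_tank one.
        storage_client = storage.Client.create_anonymous_client()
        os.makedirs(destination_folder_name, exist_ok=True)
        for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
            blob_name = blob.name.split("/")[-1]
            destination_filename = os.path.join(destination_folder_name, blob_name)
            # Skip entries that already exist locally as directories (the added check above).
            if os.path.isdir(destination_filename):
                continue
            with open(destination_filename, "wb") as f:
                # Wrap the file's write() so tqdm reports download progress in bytes.
                with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
                    storage_client.download_blob_to_file(blob, file_obj)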
@@ -210,6 +212,9 @@ def download_model(
             + "_BS"
             + str(import_args["batch_size"])
         )
+    elif any(model in model_name for model in ["clip", "unet", "vae"]):
+        # TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
+        model_dir_name = model_name
     else:
         model_dir_name = model_name + "_" + frontend
     model_dir = os.path.join(WORKDIR, model_dir_name)
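With the added elif, SD submodels (clip/unet/vae) keep a plain directory name, batched models embed the batch size, and everything else keeps the frontend suffix. A small hedged sketch of that decision; the condition on the first branch is not visible in the hunk, so a placeholder parameter stands in for it:

    import os

    def choose_model_dir(workdir, model_name, frontend, import_args, batched=False):
        if batched:  # placeholder for the elided condition above the hunk
            model_dir_name = model_name + "_BS" + str(import_args["batch_size"])
        elif any(model in model_name for model in ["clip", "unet", "vae"]):
            # SD submodels keep a plain name; device/frontend tags are deferred to .vmfb time.
            model_dir_name = model_name
        else:
            model_dir_name = model_name + "_" + frontend
        return os.path.join(workdir, model_dir_name)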
@@ -270,6 +275,9 @@ def download_model(
     tuned_str = "" if tuned is None else "_" + tuned
     suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
     filename = os.path.join(model_dir, model_name + suffix)
+    print(
+        f"Verifying that model artifacts were downloaded successfully to {filename}..."
+    )
     if not os.path.exists(filename):
         from tank.generate_sharktank import gen_shark_files
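The added print makes the verification step visible: after the download, download_model checks that the expected .mlir file exists and, if it does not, regenerates the artifacts locally via tank.generate_sharktank. A hedged sketch of that check; the gen_shark_files argument list is an assumption, not taken from the diff:

    import os

    def verify_or_generate(model_dir, model_name, dyn_str, frontend, tuned=None):
        tuned_str = "" if tuned is None else "_" + tuned
        suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
        filename = os.path.join(model_dir, model_name + suffix)
        print(
            f"Verifying that model artifacts were downloaded successfully to {filename}..."
        )
        if not os.path.exists(filename):
            # Regenerate locally; the exact arguments gen_shark_files takes are assumed here.
            from tank.generate_sharktank import gen_shark_files
            gen_shark_files(model_name, frontend, model_dir)
        return filename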
@@ -1,3 +1,3 @@
 {
-    "version": "2023-03-31_02d52bb"
+    "version": "nightly"
 }
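Pinning the tank version to "nightly" instead of the dated 2023-03-31_02d52bb tag makes consumers of this file track the latest nightly artifact set. A hedged sketch of reading such a version file; the file path is an assumption:

    import json

    def read_tank_version(path="shark_tank_version.json"):  # path is an assumption
        # "nightly" resolves to the most recent artifact set; a dated tag such as
        # "2023-03-31_02d52bb" pins a specific snapshot.
        with open(path) as f:
            return json.load(f)["version"]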