mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
Revive SD downloads from shark_tank. (#1465)
This commit is contained in:
@@ -86,8 +86,10 @@ class StableDiffusionPipeline:
|
|||||||
self.text_encoder = self.sd_model.clip()
|
self.text_encoder = self.sd_model.clip()
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
breakpoint()
|
||||||
self.text_encoder = get_clip()
|
self.text_encoder = get_clip()
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("download pipeline failed, falling back to import_mlir")
|
print("download pipeline failed, falling back to import_mlir")
|
||||||
self.text_encoder = self.sd_model.clip()
|
self.text_encoder = self.sd_model.clip()
|
||||||
|
|
||||||
@@ -104,7 +106,8 @@ class StableDiffusionPipeline:
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.unet = get_unet()
|
self.unet = get_unet()
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("download pipeline failed, falling back to import_mlir")
|
print("download pipeline failed, falling back to import_mlir")
|
||||||
self.unet = self.sd_model.unet()
|
self.unet = self.sd_model.unet()
|
||||||
|
|
||||||
@@ -121,7 +124,8 @@ class StableDiffusionPipeline:
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.vae = get_vae()
|
self.vae = get_vae()
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("download pipeline failed, falling back to import_mlir")
|
print("download pipeline failed, falling back to import_mlir")
|
||||||
self.vae = self.sd_model.vae()
|
self.vae = self.sd_model.vae()
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
|||||||
|
|
||||||
# Set local shark_tank cache directory.
|
# Set local shark_tank cache directory.
|
||||||
shark_args.local_tank_cache = args.local_tank_cache
|
shark_args.local_tank_cache = args.local_tank_cache
|
||||||
|
|
||||||
from shark.shark_downloader import download_model
|
from shark.shark_downloader import download_model
|
||||||
|
|
||||||
if "cuda" in args.device:
|
if "cuda" in args.device:
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ def download_public_file(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||||
|
if os.path.isdir(destination_filename):
|
||||||
|
continue
|
||||||
with open(destination_filename, "wb") as f:
|
with open(destination_filename, "wb") as f:
|
||||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||||
storage_client.download_blob_to_file(blob, file_obj)
|
storage_client.download_blob_to_file(blob, file_obj)
|
||||||
@@ -210,6 +212,9 @@ def download_model(
|
|||||||
+ "_BS"
|
+ "_BS"
|
||||||
+ str(import_args["batch_size"])
|
+ str(import_args["batch_size"])
|
||||||
)
|
)
|
||||||
|
elif any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||||
|
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
|
||||||
|
model_dir_name = model_name
|
||||||
else:
|
else:
|
||||||
model_dir_name = model_name + "_" + frontend
|
model_dir_name = model_name + "_" + frontend
|
||||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||||
@@ -270,6 +275,9 @@ def download_model(
|
|||||||
tuned_str = "" if tuned is None else "_" + tuned
|
tuned_str = "" if tuned is None else "_" + tuned
|
||||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||||
filename = os.path.join(model_dir, model_name + suffix)
|
filename = os.path.join(model_dir, model_name + suffix)
|
||||||
|
print(
|
||||||
|
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||||
|
)
|
||||||
if not os.path.exists(filename):
|
if not os.path.exists(filename):
|
||||||
from tank.generate_sharktank import gen_shark_files
|
from tank.generate_sharktank import gen_shark_files
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"version": "2023-03-31_02d52bb"
|
"version": "nightly"
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user