Compare commits

...

3 Commits

4 changed files with 13 additions and 12 deletions

View File

@@ -310,9 +310,12 @@ def _prepare_attn_mask(
def download_model(destination_folder, model_name):
download_public_file(
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
)
if model_name == "bloom":
subprocess.run(["gsutil", "cp", "-r", "gs://shark_tank/sharded_bloom/bloom/", f"{destination_folder}"])
else:
download_public_file(
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
)
def compile_embeddings(embeddings_layer, input_ids, path):
@@ -726,11 +729,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.create_mlirs and args.large_model_memory_efficient:
print(
"Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
)
if not os.path.isdir(args.model_path):
os.mkdir(args.model_path)
@@ -744,6 +742,9 @@ if __name__ == "__main__":
print(
"WARNING: It is not advised to turn on both download and create_mlirs"
)
if args.model_name == "bloom" and (args.create_mlirs or args.download):
urllib.request.urlretrieve("https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_00001-of-00072.bin", f"{args.model_path}/pytorch_model_00001-of-00072.bin")
if args.download:
download_model(args.model_path, args.model_name)
if args.create_mlirs:

View File

@@ -324,7 +324,7 @@ if __name__ == "__main__":
mlir_str = bytes(mlir_str, "utf-8")
if config["n_embed"] == 14336:
if "n_embed" in config.keys() and config["n_embed"] == 14336:
def get_state_dict():
d = torch.load(

View File

@@ -294,7 +294,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
#allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
@@ -406,7 +406,7 @@ def get_iree_runtime_config(device):
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=shark_args.device_allocator,
#allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
return config

View File

@@ -44,4 +44,4 @@ def get_iree_cpu_args():
error_message = f"OS Type {os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [f"--iree-llvm-target-triple={target_triple}"]