Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-11 23:08:19 -05:00)

Compare commits: debug ... AMD-Shark-

3 commits

| Author | SHA1 | Date |
|---|---|---|
| | 6c5705cf09 | |
| | 2f891a6c23 | |
| | f1fb363403 | |
@@ -310,9 +310,12 @@ def _prepare_attn_mask(
 
 
 def download_model(destination_folder, model_name):
-    download_public_file(
-        f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
-    )
+    if model_name == "bloom":
+        subprocess.run(["gsutil", "cp", "-r", "gs://shark_tank/sharded_bloom/bloom/", f"{destination_folder}"])
+    else:
+        download_public_file(
+            f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
+        )
 
 
 def compile_embeddings(embeddings_layer, input_ids, path):
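For reference, a standalone sketch of the new bloom download path, assuming the `gsutil` CLI is installed and on `PATH`. The `check=True` argument is an addition for illustration, not part of the patch:

```python
import subprocess

def download_bloom_shards(destination_folder):
    # Recursively copy the pre-sharded bloom artifacts from the public bucket.
    # check=True makes a failed copy raise CalledProcessError instead of
    # passing silently (the patched code above omits it).
    subprocess.run(
        [
            "gsutil",
            "cp",
            "-r",
            "gs://shark_tank/sharded_bloom/bloom/",
            destination_folder,
        ],
        check=True,
    )
```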
@@ -726,11 +729,6 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    if args.create_mlirs and args.large_model_memory_efficient:
-        print(
-            "Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
-        )
-
     if not os.path.isdir(args.model_path):
         os.mkdir(args.model_path)
 
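For orientation, the flags touched across these hunks plausibly come from an argparse setup along the following lines. This is a hypothetical reconstruction: the option names are taken from the `args.*` references in the diff, but the defaults and requirements are guesses:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="bloom")  # guessed default
parser.add_argument("--model_path", required=True)    # guessed requirement
parser.add_argument("--download", action="store_true")
parser.add_argument("--create_mlirs", action="store_true")
parser.add_argument("--large_model_memory_efficient", action="store_true")
args = parser.parse_args()
```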
@@ -744,6 +742,9 @@ if __name__ == "__main__":
         print(
             "WARNING: It is not advised to turn on both download and create_mlirs"
         )
+    if args.model_name == "bloom" and (args.create_mlirs or args.download):
+
+        urllib.request.urlretrieve("https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_00001-of-00072.bin", f"{args.model_path}/pytorch_model_00001-of-00072.bin")
     if args.download:
         download_model(args.model_path, args.model_name)
     if args.create_mlirs:
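The added lines fetch the first of BLOOM's 72 weight shards directly from Hugging Face. `urllib.request.urlretrieve` streams straight to disk; a minimal sketch with a progress callback (the `reporthook` signature is standard-library behavior, and the URL is the one from the hunk):

```python
import urllib.request

URL = "https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_00001-of-00072.bin"

def report(block_num, block_size, total_size):
    # Called periodically by urlretrieve with download progress.
    done = block_num * block_size
    if total_size > 0:
        print(f"\r{done / total_size:.1%}", end="")

urllib.request.urlretrieve(URL, "pytorch_model_00001-of-00072.bin", reporthook=report)
```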
@@ -324,7 +324,7 @@ if __name__ == "__main__":
 
     mlir_str = bytes(mlir_str, "utf-8")
 
-    if config["n_embed"] == 14336:
+    if "n_embed" in config.keys() and config["n_embed"] == 14336:
 
         def get_state_dict():
             d = torch.load(
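The added guard avoids a `KeyError` when `n_embed` is absent from the config. An equivalent and slightly tighter idiom, shown here only as an alternative, uses `dict.get`:

```python
# dict.get returns None when the key is missing, so no KeyError is raised
# and the comparison simply evaluates to False.
if config.get("n_embed") == 14336:
    ...
```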
@@ -294,7 +294,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
 
         haldevice = haldriver.create_device(
             haldriver.query_available_devices()[device_idx]["device_id"],
-            allocators=shark_args.device_allocator,
+            #allocators=shark_args.device_allocator,
         )
         config = ireert.Config(device=haldevice)
     else:
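For context, a minimal sketch of the device-by-index path this hunk touches, using only the `iree.runtime` calls visible in the diff; the driver name and index below are placeholders:

```python
import iree.runtime as ireert

driver_name = "local-task"  # placeholder; SHARK derives this from `device`
device_idx = 0              # placeholder index into the available-device list

haldriver = ireert.get_driver(driver_name)
devices = haldriver.query_available_devices()  # list of dicts with "device_id"
haldevice = haldriver.create_device(devices[device_idx]["device_id"])
config = ireert.Config(device=haldevice)
```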
@@ -406,7 +406,7 @@ def get_iree_runtime_config(device):
     haldriver = ireert.get_driver(device)
     haldevice = haldriver.create_device_by_uri(
         device,
-        allocators=shark_args.device_allocator,
+        #allocators=shark_args.device_allocator,
     )
     config = ireert.Config(device=haldevice)
     return config
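Both this hunk and the previous one comment out the `allocators=` keyword when creating the HAL device; given the `debug` branch name, this presumably takes the custom device allocator out of the picture while isolating a runtime issue. A usage sketch of the resulting config, assuming the standard `iree.runtime` context API:

```python
import iree.runtime as ireert

# get_iree_runtime_config is the function patched in the hunk above;
# "local-task" is a placeholder device URI.
config = get_iree_runtime_config("local-task")
ctx = ireert.SystemContext(config=config)  # load compiled modules into ctx as usual
```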
@@ -44,4 +44,4 @@ def get_iree_cpu_args():
         error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
         raise Exception(error_message)
     print(f"Target triple found:{target_triple}")
-    return [f"--iree-llvmcpu-target-triple={target_triple}"]
+    return [f"--iree-llvm-target-triple={target_triple}"]
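`get_iree_cpu_args()` derives a host target triple before formatting it into the compiler flag. A hypothetical sketch of such a derivation, independent of SHARK's actual implementation (the OS-to-triple mapping below is an assumption for illustration):

```python
import platform

def guess_target_triple():
    # Combine CPU architecture and OS into an LLVM-style target triple.
    arch = platform.machine()   # e.g. "x86_64" or "arm64"
    system = platform.system()  # "Linux", "Darwin", or "Windows"
    if system == "Linux":
        return f"{arch}-linux-gnu"
    if system == "Darwin":
        return f"{arch}-apple-darwin"
    if system == "Windows":
        return f"{arch}-pc-windows-msvc"
    raise RuntimeError(f"Unsupported OS: {system}")
```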