mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Modify Falcon-180b-GPTQ sharded pipeline
This commit is contained in:
@@ -169,8 +169,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
print(f"[DEBUG] Trying to download vmfb from shark_tank")
|
||||
download_public_file(
|
||||
-"gs://shark_tank/falcon/sharded/vmfb/"
|
||||
+f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
|
||||
+ str(self.falcon_vmfb_path),
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
@@ -195,7 +196,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
-"gs://shark_tank/falcon/sharded/mlir/"
|
||||
+f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
|
||||
+ str(self.falcon_mlir_path),
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
@@ -286,7 +287,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
num_in_features = 4544
|
||||
else:
|
||||
num_in_features = 14848
|
||||
-sample_attention_mask.to(dtype=torch.bool)
|
||||
+sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
|
||||
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, num_in_features], dtype=torch.float32
|
||||
|
||||
Reference in New Issue
Block a user