Modify Falcon-180b-GPTQ sharded pipeline

commit 205e57683a
parent 2866d665ee
Author: Vivek Khandelwal
Date:   2023-10-17 14:31:45 +00:00


@@ -169,8 +169,9 @@ class ShardedFalcon(SharkLLMBase):
         if args.use_precompiled_model:
             if not self.falcon_vmfb_path.exists():
                 # Downloading VMFB from shark_tank
+                print(f"[DEBUG] Trying to download vmfb from shark_tank")
                 download_public_file(
-                    "gs://shark_tank/falcon/sharded/vmfb/"
+                    f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
                     + str(self.falcon_vmfb_path),
                     self.falcon_vmfb_path.absolute(),
                     single_file=True,
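
Both download hunks make the same change: the remote directory is now keyed on `args.falcon_variant_to_use`, so each Falcon variant resolves to its own `vmfb`/`mlir` folder under `gs://shark_tank/falcon/sharded/`. A minimal sketch of the resulting path construction (the variant string and filenames below are illustrative, not taken from the repository):

```python
from pathlib import Path

def sharded_artifact_url(variant: str, kind: str, local_path: Path) -> str:
    # Same concatenation as in the diff: a variant-specific directory under
    # gs://shark_tank/falcon/sharded/ followed by the local artifact name.
    return (
        f"gs://shark_tank/falcon/sharded/falcon_{variant}/{kind}/"
        + str(local_path)
    )

# Hypothetical filenames, shown only to illustrate the layout.
print(sharded_artifact_url("180b", "vmfb", Path("falcon_180b_gptq.vmfb")))
# gs://shark_tank/falcon/sharded/falcon_180b/vmfb/falcon_180b_gptq.vmfb
print(sharded_artifact_url("180b", "mlir", Path("falcon_180b_gptq.mlir")))
# gs://shark_tank/falcon/sharded/falcon_180b/mlir/falcon_180b_gptq.mlir
```
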
@@ -195,7 +196,7 @@ class ShardedFalcon(SharkLLMBase):
                 # Downloading MLIR from shark_tank
                 print(f"[DEBUG] Trying to download mlir from shark_tank")
                 download_public_file(
-                    "gs://shark_tank/falcon/sharded/mlir/"
+                    f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
                     + str(self.falcon_mlir_path),
                     self.falcon_mlir_path.absolute(),
                     single_file=True,
@@ -286,7 +287,7 @@ class ShardedFalcon(SharkLLMBase):
             num_in_features = 4544
         else:
             num_in_features = 14848
-        sample_attention_mask.to(dtype=torch.bool)
+        sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
         sample_hidden_states = torch.zeros(
             [1, 100, num_in_features], dtype=torch.float32
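
The last hunk fixes a silent no-op: `torch.Tensor.to` returns a converted copy and never modifies the tensor in place, so without the assignment the attention mask stayed in its original dtype. A standalone check (the mask shape here is arbitrary):

```python
import torch

mask = torch.ones(1, 100)            # float32 by default
mask.to(dtype=torch.bool)            # returns a new tensor; `mask` is untouched
print(mask.dtype)                    # torch.float32 -- the original bug

mask = mask.to(dtype=torch.bool)     # rebind the name to the converted tensor
print(mask.dtype)                    # torch.bool -- what the fixed line does
```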