Modify falcon pipeline to add fp16 support (#1551)

This commit is contained in:
Vivek Khandelwal
2023-06-19 22:27:13 +05:30
committed by GitHub
parent 20ddd96ef7
commit 015d0132c3
3 changed files with 76 additions and 43 deletions

View File

@@ -28,6 +28,7 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
)
@@ -40,7 +41,12 @@ parser.add_argument(
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
@@ -59,12 +65,12 @@ class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -85,7 +91,7 @@ class Falcon(SharkLLMBase):
return tokenizer
def get_src_model(self):
print("Loading src model")
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -93,9 +99,26 @@ class Falcon(SharkLLMBase):
return falcon_model
def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
if args.use_precompiled_vmfb:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
@@ -106,27 +129,26 @@ class Falcon(SharkLLMBase):
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if not mlir_generated:
compilation_input_ids = torch.randint(
@@ -184,6 +206,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
@@ -192,17 +215,6 @@ class Falcon(SharkLLMBase):
return shark_module
def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
@@ -375,6 +387,8 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
@@ -428,18 +442,35 @@ if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path("falcon.mlir")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon",
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,

View File

@@ -37,4 +37,5 @@ from apps.stable_diffusion.src.utils.utils import (
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -370,6 +370,7 @@ def transform_fx(fx_g):
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
]:
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict