Add Llama2 13B int4 fp16 support (#1784)

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

Author: Abhishek Varma
Date: 2023-08-23 22:30:32 +05:30 (committed by GitHub)
Parent: 7ee3e4ba5d
Commit: db990826d3
5 changed files with 387 additions and 42 deletions


@@ -37,7 +37,8 @@ from apps.language_models.src.model_wrappers.vicuna4 import (
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
SecondVicuna7B,
SecondVicuna13B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -50,7 +51,6 @@ from shark.shark_inference import SharkInference
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
parser = argparse.ArgumentParser(
prog="vicuna runner",
description="runs a vicuna model",
@@ -108,7 +108,7 @@ parser.add_argument(
"--model_name",
type=str,
default="vicuna",
choices=["vicuna", "llama2_7b", "llama2_70b"],
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
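With the new choice in place, the 13B model is selected the same way as the existing ones. A minimal sketch against the parser defined above (only --model_name is exercised here; all other runner flags keep their defaults):

# Minimal sketch, assuming the `parser` object built in this file.
args = parser.parse_args(["--model_name", "llama2_13b"])
print(args.model_name)  # -> llama2_13b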
parser.add_argument(
@@ -189,7 +189,7 @@ class VicunaBase(SharkLLMBase):
return vicuna_model
def combine_mlir_scripts(
self, first_vicuna_mlir, second_vicuna_mlir, output_name
self, first_vicuna_mlir, second_vicuna_mlir, output_name, model_name=None
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBIG] output_name = {output_name}")
@@ -363,7 +363,8 @@ class VicunaBase(SharkLLMBase):
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
if not (model_name and "llama2_13b" in model_name):
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
@@ -1254,6 +1255,8 @@ class UnshardedVicuna(VicunaBase):
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_13b":
self.hf_model_path = "meta-llama/Llama-2-13b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
@@ -1336,7 +1339,7 @@ class UnshardedVicuna(VicunaBase):
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(self, module):
def write_in_dynamic_inputs1(self, module, model_name):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
@@ -1355,7 +1358,7 @@ class UnshardedVicuna(VicunaBase):
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
@@ -1365,12 +1368,21 @@ class UnshardedVicuna(VicunaBase):
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if self.precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
@@ -1395,6 +1407,7 @@ class UnshardedVicuna(VicunaBase):
def compile(self, download_vmfb=False):
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
print(f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}")
if not self.vicuna_vmfb_path.exists() and download_vmfb:
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
@@ -1420,16 +1433,20 @@ class UnshardedVicuna(VicunaBase):
mlir_generated = False
if self.load_mlir_from_shark_tank:
# download MLIR from shark tank
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
else:
for suffix in ["mlir", "mlirbc"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
break
self.vicuna_mlir_path = self.get_model_path("mlir")
if not mlir_generated:
print(
f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
)
@@ -1465,12 +1482,11 @@ class UnshardedVicuna(VicunaBase):
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=True
if self.precision in ["fp16", "int4"]
else False,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
@@ -1523,6 +1539,7 @@ class UnshardedVicuna(VicunaBase):
if self.cache_vicunas:
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)
print("Finished writing IR after dynamic")
if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
@@ -1533,27 +1550,41 @@ class UnshardedVicuna(VicunaBase):
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
if self.model_name == "llama2_13b":
dim1 = 40
total_tuple = 80
else:
dim1 = 32
total_tuple = 64
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
for _ in range(total_tuple)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
if self.model_name == "llama2_13b":
model = SecondVicuna13B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=True
if self.precision in ["fp16", "int4"]
else False,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
f16_input_mask=[False] + [True] * total_tuple,
mlir_type="torchscript",
)
del model
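The dim1/total_tuple values above follow from the same configs: Llama-2 13B has 40 decoder layers and 40 KV heads, so the flattened compile-time cache is 80 tensors (a key and a value per layer) of shape 1x40x19x128, versus 64 tensors of 1x32x19x128 for the 7B path. A compact sketch of building the same dummy cache (the helper name is illustrative, not part of the commit):

import torch

# Illustrative sketch of the dummy past_key_values used at compile time.
def dummy_pkv(model_name: str):
    n = 40 if model_name == "llama2_13b" else 32  # layer count and KV head count coincide
    return tuple(
        torch.zeros([1, n, 19, 128], dtype=torch.float32) for _ in range(2 * n)
    )

pkv = dummy_pkv("llama2_13b")
assert len(pkv) == 80 and pkv[0].shape == (1, 40, 19, 128)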
@@ -1561,7 +1592,7 @@ class UnshardedVicuna(VicunaBase):
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
f16_input_mask=[False] + [True] * total_tuple,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
@@ -1583,7 +1614,6 @@ class UnshardedVicuna(VicunaBase):
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
second_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
@@ -1607,11 +1637,13 @@ class UnshardedVicuna(VicunaBase):
str(second_module)
)
if self.cache_vicunas:
with open(f"second_{self.precision}.mlir", "w+") as f:
with open(f"second_{self.precision}.mlir", 'w') as f:
f.write(second_module)
print("Finished writing IR after dynamic")
combined_module = self.combine_mlir_scripts(
first_module, second_module, self.vicuna_mlir_path
first_module, second_module, self.vicuna_mlir_path, self.model_name
)
del first_module, second_module
@@ -1715,6 +1747,15 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
@@ -1835,6 +1876,7 @@ if __name__ == "__main__":
model_list = {
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
"llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
while True:


@@ -47,7 +47,7 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,


@@ -48,7 +48,7 @@ class FirstVicuna(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
class SecondVicuna7B(torch.nn.Module):
def __init__(
self,
model_path,
@@ -290,6 +290,296 @@ class SecondVicuna(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
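SecondVicuna13B.forward lists all 81 inputs explicitly (the token plus 80 cache tensors), presumably so the torchscript/FX import sees a flat, fixed-arity signature with one graph input per cache tensor. Functionally, the body only re-pairs consecutive tensors into per-layer (key, value) tuples; a compact sketch of that regrouping (for illustration only, not the committed implementation):

# Regroup a flat sequence of cache tensors into per-layer (key, value) pairs,
# as the explicit forward() above does by hand for its 80 cache inputs.
def regroup_past_key_values(flat_caches):
    return tuple(
        (flat_caches[i], flat_caches[i + 1]) for i in range(0, len(flat_caches), 2)
    )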
class CombinedModel(torch.nn.Module):
def __init__(
self,
@@ -298,7 +588,8 @@ class CombinedModel(torch.nn.Module):
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
# The 13B flow does not go through CombinedModel yet, hence `SecondVicuna7B` here.
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids)
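For reference, a hedged instantiation sketch of the new 13B wrapper defined in this file; the constructor arguments mirror what the unsharded runner passes, and the auth token is a placeholder:

# Hedged usage sketch; "hf_..." is a placeholder, not a real token.
second = SecondVicuna13B(
    "meta-llama/Llama-2-13b-chat-hf",
    precision="int4",
    weight_group_size=128,
    model_name="llama2_13b",
    hf_auth_token="hf_...",
)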


@@ -24,6 +24,7 @@ past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
@@ -43,6 +44,15 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
@@ -91,6 +101,7 @@ def create_prompt(model_name, history):
"vicuna4",
"vicuna1p3",
"llama2_7b",
"llama2_13b",
"llama2_70b",
]:
conversation = "".join(
@@ -176,6 +187,7 @@ def chat(
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_13b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import ShardedVicuna