Compare commits
1 Commit

Author  SHA1        Message                           Date
dan     489a858af1  enforce fp32 accumulates for cpu  2023-10-29 18:59:00 +00:00
2 changed files with 22 additions and 17 deletions

View File

@@ -1407,8 +1407,8 @@ class UnshardedVicuna(VicunaBase):
         elif "llama2_70b" in self.model_name:
             pkv_tensor_shape = "tensor<1x8x?x128x"
         else:
-            pkv_tensor_shape = "tensor<1x32x?x128x"
-        if self.precision in ["fp16", "int4", "int8"]:
+            pkv_tensor_shape = "tensor<1x?x32x128x"
+        if self.device != "cpu":  # precision in ["fp16", "int4", "int8"]
             pkv_tensor_shape += "f16>"
         else:
             pkv_tensor_shape += "f32>"
@@ -1416,9 +1416,9 @@ class UnshardedVicuna(VicunaBase):
         while module:
             line = module.pop(0)
             if "%c19_i64 = arith.constant 19 : i64" in line:
-                new_lines.append("%c2 = arith.constant 2 : index")
+                new_lines.append("%c1 = arith.constant 1 : index")
                 new_lines.append(
-                    f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
+                    f"%dim_4_int = tensor.dim %arg1, %c1 : {pkv_tensor_shape}"
                 )
                 new_lines.append(
                     "%dim_i64 = arith.index_cast %dim_4_int : index to i64"
@@ -1480,6 +1480,8 @@ class UnshardedVicuna(VicunaBase):
                     mlir_generated = True
                     break
+        print(self.device)
+        print(self.device == "cpu")
         if not mlir_generated:
             print(f"[DEBUG] mlir not found")
@@ -1507,7 +1509,7 @@ class UnshardedVicuna(VicunaBase):
             model = FirstVicuna(
                 self.hf_model_path,
                 self.precision,
-                "fp32" if self.device == "cpu" else "fp16",
+                "fp32",
                 self.weight_group_size,
                 self.model_name,
                 self.hf_auth_token,
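This constructor call sits in the CPU branch, so the device conditional was dead logic; the commit pins the accumulate dtype there to fp32, per its title. One plausible motivation, shown with a toy running sum: an fp16 accumulator stalls once the total outgrows the format's mantissa, while fp32 accumulation stays accurate.

    import numpy as np

    # fp16: adding 0.1 stops changing the sum near 256, because 0.1
    # falls below half a unit in the last place at that magnitude.
    acc16 = np.float16(0.0)
    for _ in range(10_000):
        acc16 = acc16 + np.float16(0.1)
    print(acc16)                     # 256.0 -- the fp16 sum stalls
    print(np.float32(0.1) * 10_000)  # ~1000.0 with fp32 accumulation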
@@ -1516,13 +1518,13 @@ class UnshardedVicuna(VicunaBase):
             model = FirstVicunaGPU(
                 self.hf_model_path,
                 self.precision,
-                "fp32" if self.device == "cpu" else "fp16",
+                "fp16",
                 self.weight_group_size,
                 self.model_name,
                 self.hf_auth_token,
             )
             print(f"[DEBUG] generating torchscript graph")
-            is_f16 = self.precision in ["fp16", "int4"]
+            is_f16 = self.device != "cpu"
             ts_graph = import_with_fx(
                 model,
                 firstVicunaCompileInput,
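In the GPU branch the accumulate dtype likewise becomes a constant "fp16", and is_f16 is now derived from the device rather than the weight precision, so even int4/int8 weights get a half-precision module off-CPU. A sketch of what such a flag typically gates, assuming import_with_fx halves the module when it is set (mirroring what get_f16_inputs later does to the inputs):

    import torch

    device = "vulkan"         # hypothetical non-CPU target
    is_f16 = device != "cpu"  # was: self.precision in ["fp16", "int4"]
    model = torch.nn.Linear(8, 8)
    if is_f16:
        model = model.half()  # traced module runs in half precision
    assert next(model.parameters()).dtype == torch.float16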
@@ -1605,7 +1607,7 @@ class UnshardedVicuna(VicunaBase):
                 dim1 = 32
                 total_tuple = 64
             pkv = tuple(
-                (torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
+                (torch.zeros([1, 19, dim1, 128], dtype=torch.float32))
                 for _ in range(total_tuple)
             )
             secondVicunaCompileInput = (compilation_input_ids,) + pkv
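The placeholder cache used to trace the second model follows suit: the zeros tensors are built seq-major, with 19 as the trace-time sequence length (total_tuple of 64 presumably being two tensors per layer across 32 layers):

    import torch

    dim1, total_tuple = 32, 64
    pkv = tuple(
        torch.zeros([1, 19, dim1, 128], dtype=torch.float32)
        for _ in range(total_tuple)
    )
    assert pkv[0].shape == (1, 19, 32, 128)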
@@ -1666,7 +1668,7 @@ class UnshardedVicuna(VicunaBase):
                 self.hf_auth_token,
             )
             print(f"[DEBUG] generating torchscript graph")
-            is_f16 = self.precision in ["fp16", "int4"]
+            is_f16 = self.device != "cpu"
             ts_graph = import_with_fx(
                 model,
                 secondVicunaCompileInput,
@@ -1676,7 +1678,7 @@ class UnshardedVicuna(VicunaBase):
                 mlir_type="torchscript",
             )
             del model
-            if self.precision in ["fp16", "int4"]:
+            if self.device != "cpu":
                 secondVicunaCompileInput = get_f16_inputs(
                     secondVicunaCompileInput,
                     True,
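The f16 input conversion is gated on device as well. get_f16_inputs is SHARK's own helper; a rough, hypothetical stand-in for how it is used here would cast the floating-point example inputs to half while leaving the integer token ids alone:

    import torch

    def f16_inputs(inputs):
        # hypothetical stand-in for get_f16_inputs as called above
        return tuple(t.half() if t.is_floating_point() else t for t in inputs)

    example = (
        torch.ones(1, 1, dtype=torch.int64),  # token ids stay integer
        torch.zeros(1, 19, 32, 128),          # cache entry -> f16
    )
    assert f16_inputs(example)[1].dtype == torch.float16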
@@ -1686,7 +1688,7 @@ class UnshardedVicuna(VicunaBase):
             for i in range(len(secondVicunaCompileInput)):
                 if i != 0:
                     secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
-                        secondVicunaCompileInput[i], dynamic_axes=[2]
+                        secondVicunaCompileInput[i], dynamic_axes=[1]
                     )
             secondVicunaCompileInput = tuple(secondVicunaCompileInput)
             print(f"[DEBUG] generating torch mlir")
@@ -1739,9 +1741,9 @@ class UnshardedVicuna(VicunaBase):
             )
             combined_module = save_mlir(
                 combined_module,
-                model_name="combined_llama",
+                model_name=self.vicuna_mlir_path,
                 mlir_dialect="tm_tensor",
-                dir=self.vicuna_mlir_path,
+                dir=str(os.getcwd()),
             )
             del first_module, second_module
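The combined module's destination changes too: it is now saved under the pipeline's own mlir name in the current working directory instead of as combined_llama under vicuna_mlir_path. Assuming vicuna_mlir_path is a bare file name, the resulting path is simply:

    import os
    from pathlib import Path

    vicuna_mlir_path = Path("vicuna_fp16.mlir")  # hypothetical value
    print(Path(os.getcwd()) / vicuna_mlir_path)  # e.g. /work/vicuna_fp16.mlir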

View File

@@ -52,8 +52,8 @@ class FirstVicuna(torch.nn.Module):
         temp_past_key_values = op.past_key_values
         for item in temp_past_key_values:
-            return_vals.append(item[0])
-            return_vals.append(item[1])
+            return_vals.append(item[0].transpose(1, 2))
+            return_vals.append(item[1].transpose(1, 2))
         return tuple(return_vals)
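On the model-wrapper side, FirstVicuna's forward now transposes each returned cache tensor out of the Hugging Face heads-major layout into the seq-major layout the compiled IR expects:

    import torch

    # HF caches come back as [batch, num_heads, seq_len, head_dim]
    hf_kv = torch.zeros(1, 32, 19, 128)
    out = hf_kv.transpose(1, 2)  # -> [1, 19, 32, 128] for the IR
    assert out.shape == (1, 19, 32, 128)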
@@ -295,6 +295,9 @@ class SecondVicuna7B(torch.nn.Module):
                 i64,
             ),
         )
+        past_key_values = [(x[0].transpose(1, 2), x[1].transpose(1, 2))
+                           for x in past_key_values]
+        past_key_values = tuple(past_key_values)
         op = self.model(
             input_ids=token, use_cache=True, past_key_values=past_key_values
         )
@@ -303,8 +306,8 @@ class SecondVicuna7B(torch.nn.Module):
         return_vals.append(token)
         temp_past_key_values = op.past_key_values
         for item in temp_past_key_values:
-            return_vals.append(item[0])
-            return_vals.append(item[1])
+            return_vals.append(item[0].transpose(1, 2))
+            return_vals.append(item[1].transpose(1, 2))
         return tuple(return_vals)
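SecondVicuna7B does the inverse on the way in, transposing the incoming seq-major cache back to heads-major before calling the Hugging Face model, and repeats the output transpose on the way back, so the seq-major layout is what crosses the compiled boundary in both directions. The round trip is lossless:

    import torch

    x = torch.randn(1, 19, 32, 128)  # layout at the compiled boundary
    hf_in = x.transpose(1, 2)        # heads-major for the HF attention
    back = hf_in.transpose(1, 2)     # layout returned by the wrapper
    assert torch.equal(back, x)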