mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
1 Commits
llm-rest-a
...
fp16cpu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
489a858af1 |
@@ -1407,8 +1407,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
elif "llama2_70b" in self.model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if self.precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape = "tensor<1x?x32x128x"
|
||||
if self.device!="cpu:" : #precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
@@ -1416,9 +1416,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append("%c2 = arith.constant 1 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
f"%dim_4_int = tensor.dim %arg1, %c1 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
@@ -1480,6 +1480,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
mlir_generated = True
|
||||
break
|
||||
|
||||
print(self.device)
|
||||
print(self.device=="cpu")
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] mlir not found")
|
||||
|
||||
@@ -1507,7 +1509,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32" if self.device=="cpu" else "fp16",
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
@@ -1516,13 +1518,13 @@ class UnshardedVicuna(VicunaBase):
|
||||
model = FirstVicunaGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32" if self.device=="cpu" else "fp16",
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
is_f16 = self.device!="cpu"
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
@@ -1605,7 +1607,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
(torch.zeros([1, 19, dim1, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
@@ -1666,7 +1668,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
is_f16 = self.device!="cpu"
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
@@ -1676,7 +1678,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
if self.device != "cpu":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
@@ -1686,7 +1688,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
secondVicunaCompileInput[i], dynamic_axes=[1]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -1739,9 +1741,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
combined_module = save_mlir(
|
||||
combined_module,
|
||||
model_name="combined_llama",
|
||||
model_name="self.vicuna_mlir_path",
|
||||
mlir_dialect="tm_tensor",
|
||||
dir=self.vicuna_mlir_path,
|
||||
dir=str(os.getcwd()),
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class FirstVicuna(torch.nn.Module):
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return_vals.append(item[0].transpose(1,2))
|
||||
return_vals.append(item[1].transpose(1,2))
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
@@ -295,6 +295,9 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
i64,
|
||||
),
|
||||
)
|
||||
|
||||
past_key_values = [(x[0].transpose(1,2), x[0].transpose(1,2)) for x in past_key_values]
|
||||
past_key_values = tuple(past_key_values)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
@@ -303,8 +306,8 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return_vals.append(item[0].transpose(1,2))
|
||||
return_vals.append(item[1].transpose(1,2))
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user