Split CPU/GPU definitions conditionally outside of torch contexts. (#1879)

Ean Garvey
2023-10-09 18:46:41 -05:00
committed by GitHub
parent 3b825579a7
commit 66f6e79d68
2 changed files with 1257 additions and 35 deletions
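
The pattern throughout this change is to pick the CPU or GPU model wrapper with an ordinary Python branch before any torch tracing or compilation begins, so the traced graph itself never branches on device. A minimal sketch of that selection, assuming the wrapper classes imported in the diff below share a constructor signature (the build_first_vicuna helper is hypothetical, for illustration only):

from apps.language_models.src.model_wrappers.vicuna_model import FirstVicuna
from apps.language_models.src.model_wrappers.vicuna_model_gpu import FirstVicunaGPU


def build_first_vicuna(device, hf_model_path, precision,
                       weight_group_size, model_name, hf_auth_token):
    # Hypothetical helper: the device check happens here, outside any
    # torch.jit / FX tracing context, mirroring the diff below.
    wrapper_cls = FirstVicuna if "cpu" in device else FirstVicunaGPU
    target_dtype = "fp32" if "cpu" in device else "fp16"
    return wrapper_cls(
        hf_model_path,
        precision,
        target_dtype,
        weight_group_size,
        model_name,
        hf_auth_token,
    )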


@@ -44,6 +44,12 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
     SecondVicuna13B,
     SecondVicuna70B,
 )
+from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
+    FirstVicunaGPU,
+    SecondVicuna7BGPU,
+    SecondVicuna13BGPU,
+    SecondVicuna70BGPU,
+)
 from apps.language_models.utils import (
     get_vmfb_from_path,
 )
@@ -442,9 +448,13 @@ class VicunaBase(SharkLLMBase):
             _logits = output["logits"]
             _past_key_values = output["past_key_values"]
             _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
-        else:
+        elif "cpu" in self.device:
             _past_key_values = output[1:]
             _token = int(output[0].to_host())
+        else:
+            _logits = torch.tensor(output[0].to_host())
+            _past_key_values = output[1:]
+            _token = torch.argmax(_logits[:, -1, :], dim=1)
         _detok = self.tokenizer.decode(_token, skip_special_tokens=False)
         ret_dict = {
@@ -452,6 +462,8 @@ class VicunaBase(SharkLLMBase):
             "detok": _detok,
             "past_key_values": _past_key_values,
         }
+        if "cpu" not in self.device:
+            ret_dict["logits"] = _logits
         if cli:
             print(f" token : {_token} | detok : {_detok}")
@@ -1491,14 +1503,24 @@ class UnshardedVicuna(VicunaBase):
                 compilation_input_ids
             ).reshape([1, 19])
             firstVicunaCompileInput = (compilation_input_ids,)
-            model = FirstVicuna(
-                self.hf_model_path,
-                self.precision,
-                "fp32" if self.device=="cpu" else "fp16",
-                self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
-            )
+            if "cpu" in self.device:
+                model = FirstVicuna(
+                    self.hf_model_path,
+                    self.precision,
+                    "fp32" if self.device=="cpu" else "fp16",
+                    self.weight_group_size,
+                    self.model_name,
+                    self.hf_auth_token,
+                )
+            else:
+                model = FirstVicunaGPU(
+                    self.hf_model_path,
+                    self.precision,
+                    "fp32" if self.device=="cpu" else "fp16",
+                    self.weight_group_size,
+                    self.model_name,
+                    self.hf_auth_token,
+                )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
             ts_graph = import_with_fx(
@@ -1587,33 +1609,62 @@ class UnshardedVicuna(VicunaBase):
                 for _ in range(total_tuple)
             )
             secondVicunaCompileInput = (compilation_input_ids,) + pkv
-            if self.model_name == "llama2_13b":
-                model = SecondVicuna13B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
-            elif self.model_name == "llama2_70b":
-                model = SecondVicuna70B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
+            if "cpu" in self.device:
+                if self.model_name == "llama2_13b":
+                    model = SecondVicuna13B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                elif self.model_name == "llama2_70b":
+                    model = SecondVicuna70B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                else:
+                    model = SecondVicuna7B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
             else:
-                model = SecondVicuna7B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
+                if self.model_name == "llama2_13b":
+                    model = SecondVicuna13BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                elif self.model_name == "llama2_70b":
+                    model = SecondVicuna70BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                else:
+                    model = SecondVicuna7BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
             ts_graph = import_with_fx(
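
The nested branches above amount to a two-key dispatch: device family times model size, with the target dtype fixed per family ("fp32" on CPU, "fp16" otherwise). A hypothetical table-driven equivalent, using the wrapper classes imported at the top of this diff:

from apps.language_models.src.model_wrappers.vicuna_model import (
    SecondVicuna7B, SecondVicuna13B, SecondVicuna70B,
)
from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
    SecondVicuna7BGPU, SecondVicuna13BGPU, SecondVicuna70BGPU,
)

# Hypothetical dispatch table equivalent to the if/elif/else chains above.
_SECOND_VICUNA = {
    ("cpu", "llama2_13b"): (SecondVicuna13B, "fp32"),
    ("cpu", "llama2_70b"): (SecondVicuna70B, "fp32"),
    ("cpu", "default"): (SecondVicuna7B, "fp32"),
    ("gpu", "llama2_13b"): (SecondVicuna13BGPU, "fp16"),
    ("gpu", "llama2_70b"): (SecondVicuna70BGPU, "fp16"),
    ("gpu", "default"): (SecondVicuna7BGPU, "fp16"),
}


def pick_second_vicuna(device, model_name):
    family = "cpu" if "cpu" in device else "gpu"
    size = model_name if model_name in ("llama2_13b", "llama2_70b") else "default"
    return _SECOND_VICUNA[(family, size)]

Either form behaves the same; the commit keeps the explicit branches, which stay closer to the surrounding code.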
@@ -1742,6 +1793,8 @@ class UnshardedVicuna(VicunaBase):
         prefill_time = time.time() - prefill_st_time
         token = generated_token_op["token"]
+        if "cpu" not in self.device:
+            logits = generated_token_op["logits"]
         pkv = generated_token_op["past_key_values"]
         detok = generated_token_op["detok"]
         yield detok, None, prefill_time
@@ -1757,6 +1810,8 @@ class UnshardedVicuna(VicunaBase):
                 "past_key_values": pkv,
                 "sv": self.shark_model,
             }
+            if "cpu" not in self.device:
+                params["logits"] = logits
             decode_st_time = time.time()
             generated_token_op = self.generate_new_token(
@@ -1765,6 +1820,8 @@ class UnshardedVicuna(VicunaBase):
             decode_time_ms = (time.time() - decode_st_time)*1000
             token = generated_token_op["token"]
+            if "cpu" not in self.device:
+                logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
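
Both hunks above thread logits through the decode loop only on non-CPU devices, since the CPU path never materializes them. A hedged sketch of one decode step under that convention (run_decode_step and its argument shapes are hypothetical, for illustration only):

def run_decode_step(generate_new_token, params, device):
    # Hypothetical wrapper over one decode iteration.
    out = generate_new_token(params)
    token = out["token"]
    pkv = out["past_key_values"]
    # Only non-CPU backends return logits; CPU returns token and cache.
    logits = out.get("logits") if "cpu" not in device else None
    return token, pkv, logits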

File diff suppressed because it is too large.