Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
Split CPU/GPU definitions conditionally outside of torch contexts. (#1879)
@@ -44,6 +44,12 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
     SecondVicuna13B,
     SecondVicuna70B,
 )
+from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
+    FirstVicunaGPU,
+    SecondVicuna7BGPU,
+    SecondVicuna13BGPU,
+    SecondVicuna70BGPU,
+)
 from apps.language_models.utils import (
     get_vmfb_from_path,
 )
@@ -442,9 +448,13 @@ class VicunaBase(SharkLLMBase):
             _logits = output["logits"]
             _past_key_values = output["past_key_values"]
             _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
-        else:
+        elif "cpu" in self.device:
             _past_key_values = output[1:]
             _token = int(output[0].to_host())
+        else:
+            _logits = torch.tensor(output[0].to_host())
+            _past_key_values = output[1:]
+            _token = torch.argmax(_logits[:, -1, :], dim=1)
 
         _detok = self.tokenizer.decode(_token, skip_special_tokens=False)
         ret_dict = {
@@ -452,6 +462,8 @@ class VicunaBase(SharkLLMBase):
             "detok": _detok,
             "past_key_values": _past_key_values,
         }
+        if "cpu" not in self.device:
+            ret_dict["logits"] = _logits
 
         if cli:
             print(f" token : {_token} | detok : {_detok}")
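Read together, the two hunks above give VicunaBase.generate_new_token a three-way split: an eager-torch path that returns a dict, a SHARK CPU path that keeps only the sampled token, and a SHARK GPU path that copies logits back to the host before the argmax. A minimal standalone sketch of that flow follows; the isinstance test for the first branch is an assumption, since the hunk does not show that condition.

import torch

def decode_step_output(output, device):
    # Condensed sketch of the branching added above; the dict test for
    # the eager-torch branch is assumed, not shown in the hunk.
    if isinstance(output, dict):  # eager torch path
        logits = output["logits"]
        past_key_values = output["past_key_values"]
        token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
    elif "cpu" in device:  # SHARK CPU path: token only, no logits kept
        past_key_values = output[1:]
        token = int(output[0].to_host())
    else:  # SHARK GPU path: bring logits to host, then argmax
        logits = torch.tensor(output[0].to_host())
        past_key_values = output[1:]
        token = torch.argmax(logits[:, -1, :], dim=1)
    ret_dict = {"token": token, "past_key_values": past_key_values}
    if "cpu" not in device:  # logits travel onward only off-CPU
        ret_dict["logits"] = logits
    return ret_dict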
@@ -1491,14 +1503,24 @@ class UnshardedVicuna(VicunaBase):
                 compilation_input_ids
             ).reshape([1, 19])
             firstVicunaCompileInput = (compilation_input_ids,)
-            model = FirstVicuna(
-                self.hf_model_path,
-                self.precision,
-                "fp32" if self.device=="cpu" else "fp16",
-                self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
-            )
+            if "cpu" in self.device:
+                model = FirstVicuna(
+                    self.hf_model_path,
+                    self.precision,
+                    "fp32" if self.device=="cpu" else "fp16",
+                    self.weight_group_size,
+                    self.model_name,
+                    self.hf_auth_token,
+                )
+            else:
+                model = FirstVicunaGPU(
+                    self.hf_model_path,
+                    self.precision,
+                    "fp32" if self.device=="cpu" else "fp16",
+                    self.weight_group_size,
+                    self.model_name,
+                    self.hf_auth_token,
+                )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
             ts_graph = import_with_fx(
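The hunk above picks the prefill wrapper by device: FirstVicuna on CPU, FirstVicunaGPU (from the new vicuna_model_gpu module) otherwise, with both branches still passing the conditional dtype expression. A small illustrative helper, not part of the commit, captures the same choice with the dtype collapsed to the value each branch effectively sees; it assumes FirstVicuna lives in the same vicuna_model module as the Second* wrappers shown in the import hunk.

from apps.language_models.src.model_wrappers.vicuna_model import FirstVicuna
from apps.language_models.src.model_wrappers.vicuna_model_gpu import FirstVicunaGPU

def pick_first_vicuna(device):
    # Hypothetical helper: returns the wrapper class plus the cast dtype
    # each branch of the hunk normally uses ("fp32" on CPU, "fp16" on GPU).
    if "cpu" in device:
        return FirstVicuna, "fp32"
    return FirstVicunaGPU, "fp16"

# Usage mirroring the constructor call in the hunk:
#   cls, dtype = pick_first_vicuna(self.device)
#   model = cls(self.hf_model_path, self.precision, dtype,
#               self.weight_group_size, self.model_name, self.hf_auth_token)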
@@ -1587,33 +1609,62 @@ class UnshardedVicuna(VicunaBase):
                 for _ in range(total_tuple)
             )
             secondVicunaCompileInput = (compilation_input_ids,) + pkv
-            if self.model_name == "llama2_13b":
-                model = SecondVicuna13B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
-            elif self.model_name == "llama2_70b":
-                model = SecondVicuna70B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
-            else:
-                model = SecondVicuna7B(
-                    self.hf_model_path,
-                    self.precision,
-                    "fp32" if self.device=="cpu" else "fp16",
-                    self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
-                )
+            if "cpu" in self.device:
+                if self.model_name == "llama2_13b":
+                    model = SecondVicuna13B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                elif self.model_name == "llama2_70b":
+                    model = SecondVicuna70B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                else:
+                    model = SecondVicuna7B(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp32",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+            else:
+                if self.model_name == "llama2_13b":
+                    model = SecondVicuna13BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                elif self.model_name == "llama2_70b":
+                    model = SecondVicuna70BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
+                else:
+                    model = SecondVicuna7BGPU(
+                        self.hf_model_path,
+                        self.precision,
+                        "fp16",
+                        self.weight_group_size,
+                        self.model_name,
+                        self.hf_auth_token,
+                    )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
             ts_graph = import_with_fx(
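The second-stage selection above enumerates six class/precision combinations in nested branches. The same mapping can be written as a table keyed by model name; the sketch below is an illustration of the logic the hunk implements, not code from the commit, and it assumes the 7B wrappers are importable from the same modules as their 13B/70B siblings.

from apps.language_models.src.model_wrappers.vicuna_model import (
    SecondVicuna7B, SecondVicuna13B, SecondVicuna70B,
)
from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
    SecondVicuna7BGPU, SecondVicuna13BGPU, SecondVicuna70BGPU,
)

# (cpu_class, gpu_class) per model name; the fallback row is the 7B pair,
# matching the hunk's else branches.
_SECOND_VICUNA = {
    "llama2_13b": (SecondVicuna13B, SecondVicuna13BGPU),
    "llama2_70b": (SecondVicuna70B, SecondVicuna70BGPU),
}

def pick_second_vicuna(model_name, device):
    cpu_cls, gpu_cls = _SECOND_VICUNA.get(
        model_name, (SecondVicuna7B, SecondVicuna7BGPU)
    )
    # CPU wrappers are built with "fp32", GPU wrappers with "fp16",
    # exactly as the hunk hard-codes them.
    return (cpu_cls, "fp32") if "cpu" in device else (gpu_cls, "fp16")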
@@ -1742,6 +1793,8 @@ class UnshardedVicuna(VicunaBase):
             prefill_time = time.time() - prefill_st_time
 
             token = generated_token_op["token"]
+            if "cpu" not in self.device:
+                logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
             yield detok, None, prefill_time
@@ -1757,6 +1810,8 @@ class UnshardedVicuna(VicunaBase):
                 "past_key_values": pkv,
                 "sv": self.shark_model,
             }
+            if "cpu" not in self.device:
+                params["logits"] = logits
 
             decode_st_time = time.time()
             generated_token_op = self.generate_new_token(
@@ -1765,6 +1820,8 @@ class UnshardedVicuna(VicunaBase):
             decode_time_ms = (time.time() - decode_st_time)*1000
 
             token = generated_token_op["token"]
+            if "cpu" not in self.device:
+                logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
 
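The last three hunks repeat one guard through the prefill and decode loops: generate_new_token now omits "logits" from its result on CPU, so every consumer must test the device before reading it. A sketch of that unpacking pattern, assuming generated_token_op is the ret_dict built earlier:

def unpack_token_op(generated_token_op, device):
    # "logits" exists in the result dict only off-CPU, so the lookup is
    # guarded; token, past_key_values, and detok are always present.
    token = generated_token_op["token"]
    pkv = generated_token_op["past_key_values"]
    detok = generated_token_op["detok"]
    logits = generated_token_op["logits"] if "cpu" not in device else None
    return token, logits, pkv, detok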
1165  apps/language_models/src/model_wrappers/vicuna_model_gpu.py  (new file)
File diff suppressed because it is too large