Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03 03:00:17 -04:00)
fix device_idx for non-layer vmfbs
This commit is contained in:
@@ -1384,6 +1384,8 @@ class ShardedVicuna(VicunaBase):
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.norm(?:\.|\s|$)"
         )
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         norm = self.compile_norm(
             norm,
             torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
@@ -1395,6 +1397,8 @@ class ShardedVicuna(VicunaBase):
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
         )
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         embeddings = self.compile_embedding(
             embeddings,
             (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
@@ -1406,6 +1410,8 @@ class ShardedVicuna(VicunaBase):
         device_idx = self.get_device_index(
             r"vicuna\.model\.lm_head(?:\.|\s|$)"
         )
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         lmhead = self.compile_lmhead(
             lmhead,
             torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
Reference in New Issue
Block a user