Fix device_idx for non-layer vmfbs: hard-code it to device 0

This commit is contained in:
PhaneeshB
2023-12-06 01:03:11 +05:30
parent e5ed167f03
commit 93f583f0be

View File

@@ -1384,6 +1384,8 @@ class ShardedVicuna(VicunaBase):
device_idx = self.get_device_index(
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
)
# Hard-code device_idx to 0 for non-layer vmfbs (overrides the index computed above)
device_idx = 0
norm = self.compile_norm(
norm,
torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
@@ -1395,6 +1397,8 @@ class ShardedVicuna(VicunaBase):
device_idx = self.get_device_index(
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
)
# Hard-code device_idx to 0 for non-layer vmfbs (overrides the index computed above)
device_idx = 0
embeddings = self.compile_embedding(
embeddings,
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
@@ -1406,6 +1410,8 @@ class ShardedVicuna(VicunaBase):
device_idx = self.get_device_index(
r"vicuna\.model\.lm_head(?:\.|\s|$)"
)
# Hard-code device_idx to 0 for non-layer vmfbs (overrides the index computed above)
device_idx = 0
lmhead = self.compile_lmhead(
lmhead,
torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),