Add a model config generator (#1511)

The model config generator takes a PyTorch model as input and generates a JSON file listing the model's layers and other properties that define how it is sharded on a particular hardware target.
This commit is contained in:
Nithin Meganathan
2023-06-09 15:32:00 -07:00
committed by GitHub
parent 1980d7b2c3
commit 34f1295349
2 changed files with 60 additions and 0 deletions

View File

@@ -237,3 +237,25 @@ class SecondVicuna(torch.nn.Module):
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class CombinedModel(torch.nn.Module):
    """Chain ``FirstVicuna`` and ``SecondVicuna`` into a single module.

    The first model is run on the input ids; the argmax token at the last
    sequence position of its logits is then passed, together with the
    remaining outputs of the first model, to the second model.

    Args:
        first_vicuna_model_path: Path or identifier forwarded to
            ``FirstVicuna``.
        second_vicuna_model_path: Path or identifier forwarded to
            ``SecondVicuna``.
    """

    def __init__(
        self,
        first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
        second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
    ):
        super().__init__()
        self.first_vicuna = FirstVicuna(first_vicuna_model_path)
        self.second_vicuna = SecondVicuna(second_vicuna_model_path)

    def forward(self, input_ids):
        """Run both models back to back and return the second model's output.

        Args:
            input_ids: Input forwarded to the first model as ``input_ids``.

        Returns:
            Whatever ``self.second_vicuna`` returns for the packed
            ``(token, *remaining_first_outputs)`` tuple.
        """
        # NOTE(review): FirstVicuna.forward is not visible here; confirm it
        # accepts the ``use_cache`` keyword.
        first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
        logits = first_output[0]
        # Everything after the logits is forwarded untouched to the second
        # model (presumably the past key/value tensors -- verify upstream).
        pkv = first_output[1:]

        # torch.as_tensor returns an existing tensor as-is, so no copy is
        # made and no "copy construct from a tensor" UserWarning is emitted
        # (torch.tensor(logits) did both and also detached the value from
        # the surrounding graph).
        last_position_logits = torch.as_tensor(logits)[:, -1, :]
        token = torch.argmax(last_position_logits, dim=1)
        # The fixed [1, 1] reshape only succeeds for a single selected token,
        # i.e. a batch dimension of one.
        token = token.to(torch.int64).reshape([1, 1])

        # NOTE(review): the second model is called with ONE tuple argument;
        # confirm SecondVicuna.forward expects a packed tuple rather than
        # unpacked positional tensors.
        second_vicuna_input = (token,) + tuple(pkv)
        second_output = self.second_vicuna(second_vicuna_input)
        return second_output