mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
16 lines · 397 B · Python
import torch
|
|
|
|
|
|
class OPTForCausalLMModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper that exposes only the logits of a wrapped model.

    The wrapped model is called with ``input_ids`` and ``attention_mask`` passed
    as keyword arguments. The ``logits`` attribute of whatever it returns is
    handed back to the caller.
    """

    def __init__(self, model):
        """Keep a reference to *model* on the wrapper.

        Args:
            model: the model to wrap. Any callable whose return value has a
                ``logits`` attribute works. Presumably a Hugging Face
                ``OPTForCausalLM``, judging by the class name; TODO confirm
                with callers.
        """
        super().__init__()
        # Assigning a Module to an attribute registers it as a child submodule.
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return only its logits.

        Args:
            input_ids: forwarded unchanged as the ``input_ids`` keyword.
            attention_mask: forwarded unchanged as the ``attention_mask`` keyword.

        Returns:
            The ``logits`` attribute of the wrapped model's output.
        """
        # NOTE(review): returning a bare tensor instead of the full output object
        # presumably keeps the traced/exported graph simple; confirm with callers.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return model_output.logits
|