# Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
# (synced 2026-02-19 11:56:43 -05:00).
#
# Commit message:
#   * Change script to the 1.3B model and add a PyTorch comparison.
#   * Fix CLI command.
#   * Match OPT transformers model updates and numerics against the latest version.
#   * Clean up OPT sentence-completion script.
#   * Fix formatting and add standalone validation scripts.
#   * Add minimal OPT wrapper and example with import_with_fx.
#   * Rename OPT full-model wrapper.
#   * Clean up test scripts for OPT.
#
# File: 16 lines, 397 B, Python.
import torch


class OPTForCausalLMModel(torch.nn.Module):
    """Wrap a causal-LM module so that ``forward`` returns only its logits.

    The wrapped ``model`` is called with ``input_ids`` and ``attention_mask``
    as keyword arguments, and the ``.logits`` attribute of whatever it returns
    is handed back.

    NOTE(review): presumably ``model`` is a Hugging Face ``OPTForCausalLM``
    (or similar) whose output object exposes ``.logits``. Unwrapping to a
    single tensor is presumably done so tracing/export tools (e.g.
    ``import_with_fx``) see a plain tensor output — confirm against callers.
    """

    def __init__(self, model: torch.nn.Module):
        """Store ``model``; assigning it to ``self.model`` registers it as a submodule."""
        super().__init__()
        # nn.Module.__setattr__ registers this as a child module, so its
        # parameters and buffers are reachable via parameters(), .to(), etc.
        self.model = model

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Run the wrapped model and return its ``.logits``.

        Args:
            input_ids: Token ids, forwarded unchanged as ``input_ids=``
                (shape/dtype expectations are set by the wrapped model).
            attention_mask: Mask forwarded unchanged as ``attention_mask=``.

        Returns:
            The ``.logits`` attribute of the wrapped model's output.
        """
        # Pass by keyword so argument order in the wrapped model's signature
        # does not matter.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.logits