mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
OPT Refactor (#1516)
* Change script to 1.3b model and add pytorch comparison
* fix CLI command
* Match OPT transformers model updates + numerics against latest version
* Cleanup OPT sentence completion script.
* Fix formatting and add standalone validation scripts.
* Add minimal OPT wrapper and example with import_with_fx
* Rename OPT full model wrapper.
* Cleanup test scripts for OPT.
This commit is contained in:
@@ -1,30 +1,27 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
MAX_NEW_TOKENS = 200
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 30
|
||||
MAX_NEW_TOKENS = 20
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
"facebook/" + model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"This is a sample input for generating the model.",
|
||||
"What is the meaning of life?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
@@ -34,54 +31,37 @@ def create_module(model_name, tokenizer, device):
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
mlir_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
return_str=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(inputs[0], attention_mask=inputs[1], return_dict=False)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
)
|
||||
vmfb_name = f"{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
shark_module.load_module(vmfb_name + ".vmfb")
|
||||
|
||||
results = shark_module("forward", inputs)
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
return shark_module
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
@@ -129,15 +109,19 @@ if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
)
|
||||
vmfb_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu.vmfb"
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
opt_shark_module = create_module(OPT_MODEL, tokenizer, "cpu")
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3B"
|
||||
OPT_MODEL = "facebook/opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
|
||||
|
||||
@@ -24,60 +23,50 @@ class OPTModuleTester:
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
input_ids, attention_mask = (
|
||||
inputs.data["input_ids"],
|
||||
inputs.data["attention_mask"],
|
||||
model_inputs = tokenizer(
|
||||
"The meaning of life is",
|
||||
padding="max_length",
|
||||
max_length=30,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
np.save("opt_inputs.npy", input_ids.detach())
|
||||
mlir_path = "./OPT1-3b_causallm_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
input_ids,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(input_ids, return_dict=False)
|
||||
|
||||
# mlir_importer = SharkImporter(
|
||||
# model,
|
||||
# (input,),
|
||||
# frontend="torch",
|
||||
# )
|
||||
# minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
# is_dynamic=dynamic, tracing_required=True
|
||||
# )
|
||||
inputs = (
|
||||
model_inputs.data["input_ids"],
|
||||
model_inputs.data["attention_mask"],
|
||||
)
|
||||
act_out = opt_model(
|
||||
inputs[0], attention_mask=inputs[1], return_dict=False
|
||||
)[0]
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
)
|
||||
del opt_model
|
||||
opt_filename = f"./{OPT_FS_NAME}_causallm_30_torch_{device}"
|
||||
mlir_path = os.path.join(opt_filename, ".mlir")
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(mlir_module)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
mlir_module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
results = shark_module("forward", (input_ids,))
|
||||
results = shark_module("forward", inputs)
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
@@ -90,11 +79,11 @@ class OPTModuleTester:
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
assert compare_tensors(act_out[0].detach(), results[0])
|
||||
np.testing.assert_allclose(act_out[0].detach(), results[0])
|
||||
|
||||
if self.benchmark:
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(input_ids, attention_mask),
|
||||
inputs,
|
||||
"opt",
|
||||
dynamic,
|
||||
device,
|
||||
|
||||
47
tank/examples/opt/shark_hf_base_opt.py
Normal file
47
tank/examples/opt/shark_hf_base_opt.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
|
||||
# Standalone example: import the OPT causal-LM through import_with_fx, compile
# it with SharkInference on the "cpu-sync" device, and print the SHARK logits
# next to the logits of the unmodified PyTorch model for the same prompt.
model_name = "facebook/opt-1.3b"
base_model = OPTForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# The wrapper takes (input_ids, attention_mask) and returns only the logits.
model = OPTForCausalLMModel(base_model)

prompt = "What is the meaning of life?"
model_inputs = tokenizer(prompt, return_tensors="pt")
# Positional inputs in the order expected by the wrapper's forward().
inputs = tuple(
    model_inputs[key] for key in ("input_ids", "attention_mask")
)

# NOTE(review): debug=True together with save_dir="." presumably writes
# intermediate artifacts to the current directory -- confirm against
# import_with_fx.
mlir_module, func_name = import_with_fx(
    model=model,
    inputs=inputs,
    is_f16=False,
    debug=True,
    model_name=model_name.split("/")[1],
    save_dir=".",
)

shark_module = SharkInference(
    mlir_module,
    device="cpu-sync",
    mlir_dialect="tm_tensor",
)
shark_module.compile()

# Generated logits.
logits = shark_module("forward", inputs=inputs)
print("SHARK module returns logits:")
print(logits[0])

# Reference run: call the Hugging Face model directly with the same tensors.
hf_logits = base_model.forward(inputs[0], inputs[1], return_dict=False)[0]

print("PyTorch baseline returns logits:")
print(hf_logits)
|
||||
15
tank/examples/opt/shark_opt_wrapper.py
Normal file
15
tank/examples/opt/shark_opt_wrapper.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class OPTForCausalLMModel(torch.nn.Module):
    """Thin wrapper that reduces a causal-LM to ``(input_ids, attention_mask) -> logits``.

    The wrapped model is called with ``input_ids`` and ``attention_mask`` as
    keyword arguments and only its logits are returned, so the module has a
    fixed two-argument positional signature and a single output.
    """

    def __init__(self, model):
        """Store the model to wrap.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask``
                keyword arguments (e.g. a ``transformers`` ``OPTForCausalLM``).
        """
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its logits.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Attention mask forwarded to the wrapped model.

        Returns:
            ``output.logits`` when the wrapped model returns an object with a
            ``logits`` attribute, or the first element when it returns a plain
            tuple.
        """
        combine_input_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        output = self.model(**combine_input_dict)
        # A model configured with return_dict=False returns a plain tuple,
        # which has no ``.logits`` attribute. Assumes logits come first in
        # that tuple, which holds when no labels are passed (none are here).
        if isinstance(output, tuple):
            return output[0]
        return output.logits
|
||||
@@ -313,6 +313,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=config.enable_bias,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
@@ -320,10 +321,16 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
|
||||
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -449,7 +456,14 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
self.layer_norm = None
|
||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
@@ -646,6 +660,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user