OPT Refactor (#1516)

* Change script to 1.3b model and add pytorch comparison

* fix CLI command

* Match OPT transformers model updates + numerics against latest version

* Cleanup OPT sentence completion script.

* Fix formatting and add standalone validation scripts.

* Add minimal OPT wrapper and example with import_with_fx

* Rename OPT full model wrapper.

* Cleanup test scripts for OPT.
This commit is contained in:
Ean Garvey
2023-06-13 22:40:07 -05:00
committed by GitHub
parent 5562d1dfda
commit f53e3594c3
5 changed files with 156 additions and 104 deletions

View File

@@ -1,30 +1,27 @@
import unittest
import os
import pytest
import torch_mlir
import torch
import numpy as np
from shark_hf_opt import OPTForCausalLM
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
)
from shark.shark_inference import SharkInference
from tank.model_utils import compare_tensors
from transformers import AutoTokenizer
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-350m"
OPT_MODEL_66B = "facebook/opt-66b"
MAX_SEQUENCE_LENGTH = 256
MAX_NEW_TOKENS = 200
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 30
MAX_NEW_TOKENS = 20
def create_module(model_name, tokenizer, device):
opt_model = OPTForCausalLM.from_pretrained(
"facebook/" + model_name, return_dict=False
)
opt_model.eval()
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"This is a sample input for generating the model.",
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
@@ -34,54 +31,37 @@ def create_module(model_name, tokenizer, device):
encoded_inputs["input_ids"],
encoded_inputs["attention_mask"],
)
mlir_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
module = torch_mlir.compile(
opt_model,
inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=True,
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
return_str=True,
)
model_mlir = module.operation.get_asm(
large_elements_limit=None, enable_debug_info=True
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
func_name = "forward"
act_out = opt_model(inputs[0], attention_mask=inputs[1], return_dict=False)
shark_module = SharkInference(
model_mlir,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
)
vmfb_name = f"{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
shark_module.save_module(module_name=vmfb_name)
shark_module.load_module(vmfb_name + ".vmfb")
results = shark_module("forward", inputs)
print(
"SHARK logits have shape: ",
str(results[0].shape) + " : " + str(results[0]),
)
print(
"PyTorch logits have shape: "
+ str(act_out[0].shape)
+ " : "
+ str(act_out[0])
)
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
return shark_module
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def shouldStop(tokens):
@@ -129,15 +109,19 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
"facebook/" + OPT_MODEL, use_fast=False
)
vmfb_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu.vmfb"
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
if os.path.isfile(vmfb_path):
opt_shark_module = SharkInference(mlir_module=None, device="cpu")
opt_shark_module.load_module(vmfb_path)
else:
opt_shark_module = create_module(OPT_MODEL, tokenizer, "cpu")
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
opt_shark_module.load_module(vmfb_path)
while True:
try:
new_text = input("Give me a sentence to complete:")
new_text_init = new_text
words_list = []
for i in range(MAX_NEW_TOKENS):

View File

@@ -1,17 +1,16 @@
import unittest
import os
import pytest
import torch_mlir
import torch
import numpy as np
from shark_hf_opt import OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from tank.model_utils import compare_tensors
from transformers import AutoTokenizer
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "facebook/opt-1.3B"
OPT_MODEL = "facebook/opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
OPT_MODEL_66B = "facebook/opt-66b"
@@ -24,60 +23,50 @@ class OPTModuleTester:
def create_and_check_module(self, dynamic, device, model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# config = OPTConfig()
# opt_model = OPTModel(config)
opt_model = OPTForCausalLM.from_pretrained(
model_name, return_dict=False
)
opt_model.eval()
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
input_ids, attention_mask = (
inputs.data["input_ids"],
inputs.data["attention_mask"],
model_inputs = tokenizer(
"The meaning of life is",
padding="max_length",
max_length=30,
truncation=True,
return_tensors="pt",
)
np.save("opt_inputs.npy", input_ids.detach())
mlir_path = "./OPT1-3b_causallm_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
module = torch_mlir.compile(
opt_model,
input_ids,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=True,
)
model_mlir = module.operation.get_asm(
large_elements_limit=None, enable_debug_info=True
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
func_name = "forward"
act_out = opt_model(input_ids, return_dict=False)
# mlir_importer = SharkImporter(
# model,
# (input,),
# frontend="torch",
# )
# minilm_mlir, func_name = mlir_importer.import_mlir(
# is_dynamic=dynamic, tracing_required=True
# )
inputs = (
model_inputs.data["input_ids"],
model_inputs.data["attention_mask"],
)
act_out = opt_model(
inputs[0], attention_mask=inputs[1], return_dict=False
)[0]
(
mlir_module,
func_name,
) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
)
del opt_model
opt_filename = f"./{OPT_FS_NAME}_causallm_30_torch_{device}"
mlir_path = os.path.join(opt_filename, ".mlir")
with open(mlir_path, "w") as f:
f.write(mlir_module)
print(f"Saved mlir at {mlir_path}")
shark_module = SharkInference(
model_mlir,
mlir_module,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=self.benchmark,
)
shark_module.compile()
results = shark_module("forward", (input_ids,))
results = shark_module("forward", inputs)
print(
"SHARK logits have shape: ",
str(results[0].shape) + " : " + str(results[0]),
@@ -90,11 +79,11 @@ class OPTModuleTester:
)
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
assert compare_tensors(act_out[0].detach(), results[0])
np.testing.assert_allclose(act_out[0].detach(), results[0])
if self.benchmark:
shark_module.shark_runner.benchmark_all_csv(
(input_ids, attention_mask),
inputs,
"opt",
dynamic,
device,

View File

@@ -0,0 +1,47 @@
import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark_opt_wrapper import OPTForCausalLMModel
model_name = "facebook/opt-1.3b"
base_model = OPTForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = OPTForCausalLMModel(base_model)
prompt = "What is the meaning of life?"
model_inputs = tokenizer(prompt, return_tensors="pt")
inputs = (
model_inputs["input_ids"],
model_inputs["attention_mask"],
)
(
mlir_module,
func_name,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=False,
debug=True,
model_name=model_name.split("/")[1],
save_dir=".",
)
shark_module = SharkInference(
mlir_module,
device="cpu-sync",
mlir_dialect="tm_tensor",
)
shark_module.compile()
# Generated logits.
logits = shark_module("forward", inputs=inputs)
print("SHARK module returns logits:")
print(logits[0])
hf_logits = base_model.forward(inputs[0], inputs[1], return_dict=False)[0]
print("PyTorch baseline returns logits:")
print(hf_logits)

View File

@@ -0,0 +1,15 @@
import torch
class OPTForCausalLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits

View File

@@ -313,6 +313,7 @@ class OPTDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@@ -320,10 +321,16 @@ class OPTDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine,
)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine,
)
def forward(
self,
@@ -449,7 +456,14 @@ class OPTDecoder(OPTPreTrainedModel):
else:
self.project_in = None
self.layer_norm = None
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
@@ -646,6 +660,9 @@ class OPTDecoder(OPTPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)