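# Sharded BLOOM pipeline stage runner for SHARK/IREE.
#
# This script is meant to be invoked once per pipeline stage (presumably by a
# driver process), with the stage selected by `layer_name`:
#   "start"           - tokenize the prompt, run the word-embedding and
#                       embedding-layernorm modules, and precompute the alibi
#                       and causal-mask tensors
#   "0".."n_layer-1"  - run a single BloomBlock on the previous stage's output
#   "end"             - apply the final layernorm and LM head, pick the next
#                       token greedily, and append it to prompt.txt
# Intermediate activations are handed between invocations as .pt files in
# working_dir, alongside the per-module .mlir sources and compiled .vmfb files.
#
# Positional arguments:
#   working_dir layer_name will_compile device device_idx prompt
# Example invocation (script name and paths are illustrative only):
#   python sharded_bloom_stage.py ./bloom-560m start True cpu None "Hello there"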
import sys
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
import re
from shark.shark_inference import SharkInference
import torch
import torch.nn as nn
from collections import OrderedDict
from transformers.models.bloom.modeling_bloom import (
    BloomBlock,
    build_alibi_tensor,
)
import time
import json


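# The mask helpers below appear to be adapted from the attention-mask
# construction in transformers' modeling_bloom, kept local so the masks can be
# precomputed for the sharded pipeline without instantiating the full model.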
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    batch_size, source_length = mask.size()
    tgt_len = tgt_len if tgt_len is not None else source_length

    expanded_mask = (
        mask[:, None, None, :]
        .expand(batch_size, 1, tgt_len, source_length)
        .to(dtype)
    )

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )


def _prepare_attn_mask(
    attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            past_key_values_length=past_key_values_length,
        ).to(attention_mask.device)

    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = _expand_mask(
            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )
        combined_attention_mask = (
            expanded_attn_mask
            if combined_attention_mask is None
            else expanded_attn_mask + combined_attention_mask
        )

    return combined_attention_mask


def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    past_key_values_length: int = 0,
):
    """
    Make the causal mask used for self-attention.
    """
    batch_size, target_length = input_ids_shape
    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
    mask_cond = torch.arange(mask.size(-1))
    intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
    mask.masked_fill_(intermediate_mask, 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [
                torch.zeros(
                    target_length, past_key_values_length, dtype=dtype
                ),
                mask,
            ],
            dim=-1,
        )
    expanded_mask = mask[None, None, :, :].expand(
        batch_size, 1, target_length, target_length + past_key_values_length
    )
    return expanded_mask


if __name__ == "__main__":
    working_dir = sys.argv[1]
    layer_name = sys.argv[2]
    will_compile = sys.argv[3]
    device = sys.argv[4]
    device_idx = sys.argv[5]
    prompt = sys.argv[6]

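    # sys.argv delivers everything as strings; coerce device_idx and
    # will_compile into the types the rest of the script expects.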
    if device_idx.lower().strip() == "none":
        device_idx = None
    else:
        device_idx = int(device_idx)

    if will_compile.lower().strip() == "true":
        will_compile = True
    else:
        will_compile = False

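    # Model hyperparameters come from the checkpoint's config.json; key names
    # differ between BLOOM variants, hence the KeyError fallbacks below.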
f = open(f"{working_dir}/config.json")
|
|
config = json.load(f)
|
|
f.close()
|
|
|
|
layers_initialized = False
|
|
try:
|
|
n_embed = config["n_embed"]
|
|
except KeyError:
|
|
n_embed = config["hidden_size"]
|
|
vocab_size = config["vocab_size"]
|
|
n_layer = config["n_layer"]
|
|
try:
|
|
n_head = config["num_attention_heads"]
|
|
except KeyError:
|
|
n_head = config["n_head"]
|
|
|
|
    if not os.path.isdir(working_dir):
        os.mkdir(working_dir)

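    # Stage "start": tokenize the prompt, run the word-embedding and
    # embedding-layernorm modules on the CPU, and precompute the alibi and
    # causal-mask tensors that every BloomBlock stage will reuse.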
    if layer_name == "start":
        tokenizer = AutoTokenizer.from_pretrained(working_dir)
        input_ids = tokenizer.encode(prompt, return_tensors="pt")

        mlir_str = ""

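        # Compile the word-embeddings MLIR to a .vmfb when requested;
        # otherwise reuse the previously compiled artifact in working_dir.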
        if will_compile:
            f = open(f"{working_dir}/word_embeddings.mlir", encoding="utf-8")
            mlir_str = f.read()
            f.close()

        mlir_str = bytes(mlir_str, "utf-8")

        shark_module = SharkInference(
            mlir_str,
            device="cpu",
            mlir_dialect="tm_tensor",
            device_idx=None,
        )

        if will_compile:
            shark_module.save_module(
                module_name=f"{working_dir}/word_embeddings",
                extra_args=[
                    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                    "--iree-stream-resource-max-allocation-size=1000000000",
                    "--iree-codegen-check-ir-before-llvm-conversion=false",
                ],
            )

        shark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
        input_embeds = shark_module(
            inputs=(input_ids,), function_name="forward"
        )
        input_embeds = torch.tensor(input_embeds).float()

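        # Same pattern for the embedding-layernorm module: optionally compile,
        # then load and run it on the embedded prompt.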
mlir_str = ""
|
|
|
|
if will_compile:
|
|
f = open(
|
|
f"{working_dir}/word_embeddings_layernorm.mlir",
|
|
encoding="utf-8",
|
|
)
|
|
mlir_str = f.read()
|
|
f.close()
|
|
|
|
shark_module = SharkInference(
|
|
mlir_str,
|
|
device="cpu",
|
|
mlir_dialect="tm_tensor",
|
|
device_idx=None,
|
|
)
|
|
|
|
if will_compile:
|
|
shark_module.save_module(
|
|
module_name=f"{working_dir}/word_embeddings_layernorm",
|
|
extra_args=[
|
|
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
|
"--iree-stream-resource-max-allocation-size=1000000000",
|
|
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
|
],
|
|
)
|
|
|
|
shark_module.load_module(
|
|
f"{working_dir}/word_embeddings_layernorm.vmfb"
|
|
)
|
|
hidden_states = shark_module(
|
|
inputs=(input_embeds,), function_name="forward"
|
|
)
|
|
hidden_states = torch.tensor(hidden_states).float()
|
|
|
|
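        # Persist the stage-0 activations and precompute the tensors shared by
        # every BloomBlock stage: the alibi bias and the expanded causal mask.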
        torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt")

        attention_mask = torch.ones(
            [hidden_states.shape[0], len(input_ids[0])]
        )

        attention_mask = torch.tensor(attention_mask).float()

        alibi = build_alibi_tensor(
            attention_mask,
            n_head,
            hidden_states.dtype,
            device="cpu",
        )

        torch.save(alibi, f"{working_dir}/alibi.pt")

        causal_mask = _prepare_attn_mask(
            attention_mask, input_ids.size(), input_embeds, 0
        )
        causal_mask = torch.tensor(causal_mask).float()

        torch.save(causal_mask, f"{working_dir}/causal_mask.pt")

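    # Stage "<N>" (0 .. n_layer-1): run BloomBlock N on the activations saved
    # by the previous stage and save the result for stage N+1.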
    elif layer_name in [str(x) for x in range(n_layer)]:
        hidden_states = torch.load(
            f"{working_dir}/hidden_states_{layer_name}.pt"
        )
        alibi = torch.load(f"{working_dir}/alibi.pt")
        causal_mask = torch.load(f"{working_dir}/causal_mask.pt")

        mlir_str = ""

        if will_compile:
            f = open(
                f"{working_dir}/bloom_block_{layer_name}.mlir",
                encoding="utf-8",
            )
            mlir_str = f.read()
            f.close()

        mlir_str = bytes(mlir_str, "utf-8")

        shark_module = SharkInference(
            mlir_str,
            device=device,
            mlir_dialect="tm_tensor",
            device_idx=device_idx,
        )

        if will_compile:
            shark_module.save_module(
                module_name=f"{working_dir}/bloom_block_{layer_name}",
                extra_args=[
                    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                    "--iree-stream-resource-max-allocation-size=1000000000",
                    "--iree-codegen-check-ir-before-llvm-conversion=false",
                ],
            )

        shark_module.load_module(
            f"{working_dir}/bloom_block_{layer_name}.vmfb"
        )

        output = shark_module(
            inputs=(
                hidden_states.detach().numpy(),
                alibi.detach().numpy(),
                causal_mask.detach().numpy(),
            ),
            function_name="forward",
        )

        hidden_states = torch.tensor(output[0]).float()

        torch.save(
            hidden_states,
            f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt",
        )

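    # Stage "end": apply the final layernorm (ln_f) and the LM head, then pick
    # the next token greedily and append it to prompt.txt.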
    elif layer_name == "end":
        mlir_str = ""

        if will_compile:
            f = open(f"{working_dir}/ln_f.mlir", encoding="utf-8")
            mlir_str = f.read()
            f.close()

        mlir_str = bytes(mlir_str, "utf-8")

        shark_module = SharkInference(
            mlir_str,
            device="cpu",
            mlir_dialect="tm_tensor",
            device_idx=None,
        )

        if will_compile:
            shark_module.save_module(
                module_name=f"{working_dir}/ln_f",
                extra_args=[
                    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                    "--iree-stream-resource-max-allocation-size=1000000000",
                    "--iree-codegen-check-ir-before-llvm-conversion=false",
                ],
            )

        shark_module.load_module(f"{working_dir}/ln_f.vmfb")

        hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")

        hidden_states = shark_module(
            inputs=(hidden_states,), function_name="forward"
        )

mlir_str = ""
|
|
|
|
if will_compile:
|
|
f = open(f"{working_dir}/lm_head.mlir", encoding="utf-8")
|
|
mlir_str = f.read()
|
|
f.close()
|
|
|
|
mlir_str = bytes(mlir_str, "utf-8")
|
|
|
|
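        # The 176B checkpoint (hidden size 14336) is special-cased: its LM head
        # is the tied word-embedding matrix, loaded directly from the first
        # checkpoint shard with torch rather than run through SHARK, presumably
        # to avoid compiling the very large projection.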
if "n_embed" in config.keys() and config["n_embed"] == 14336:
|
|
|
|
def get_state_dict():
|
|
d = torch.load(
|
|
f"{working_dir}/pytorch_model_00001-of-00072.bin"
|
|
)
|
|
return OrderedDict(
|
|
(k.replace("word_embeddings.", ""), v)
|
|
for k, v in d.items()
|
|
)
|
|
|
|
def load_causal_lm_head():
|
|
linear = nn.utils.skip_init(
|
|
nn.Linear, 14336, 250880, bias=False, dtype=torch.float
|
|
)
|
|
linear.load_state_dict(get_state_dict(), strict=False)
|
|
return linear.float()
|
|
|
|
lm_head = load_causal_lm_head()
|
|
|
|
logits = lm_head(torch.tensor(hidden_states).float())
|
|
|
|
        else:
            shark_module = SharkInference(
                mlir_str,
                device="cpu",
                mlir_dialect="tm_tensor",
                device_idx=None,
            )

            if will_compile:
                shark_module.save_module(
                    module_name=f"{working_dir}/lm_head",
                    extra_args=[
                        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                        "--iree-stream-resource-max-allocation-size=1000000000",
                        "--iree-codegen-check-ir-before-llvm-conversion=false",
                    ],
                )

            shark_module.load_module(f"{working_dir}/lm_head.vmfb")

            logits = shark_module(
                inputs=(hidden_states,), function_name="forward"
            )

        logits = torch.tensor(logits).float()

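        # Greedy next-token selection; the updated prompt is written to
        # prompt.txt, presumably so the driver can feed it back on the next
        # iteration.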
        tokenizer = AutoTokenizer.from_pretrained(working_dir)

        next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1))

        f = open(f"{working_dir}/prompt.txt", "w+")
        f.write(prompt + next_token)
        f.close()