diff --git a/.gitignore b/.gitignore
index 9fc97820a1..b33e0ea231 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,4 @@ disassemblers/cuda_ioctl_sniffer
 datasets/cifar-10-python.tar.gz
 datasets/librispeech/
 datasets/imagenet/
+datasets/squad/
diff --git a/datasets/squad.py b/datasets/squad.py
new file mode 100644
index 0000000000..495b90c9f4
--- /dev/null
+++ b/datasets/squad.py
@@ -0,0 +1,148 @@
+import json
+import os
+from pathlib import Path
+from transformers import BertTokenizer
+import numpy as np
+from extra.utils import download_file
+
+BASEDIR = Path(__file__).parent.parent / "datasets/squad"
+def init_dataset():
+  os.makedirs(BASEDIR, exist_ok=True)
+  download_file("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
+  with open(BASEDIR / "dev-v1.1.json") as f:
+    data = json.load(f)["data"]
+
+  examples = []
+  for article in data:
+    for paragraph in article["paragraphs"]:
+      text = paragraph["context"]
+      doc_tokens = []
+      prev_is_whitespace = True
+      for c in text:
+        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
+          prev_is_whitespace = True
+        else:
+          if prev_is_whitespace:
+            doc_tokens.append(c)
+          else:
+            doc_tokens[-1] += c
+          prev_is_whitespace = False
+
+      for qa in paragraph["qas"]:
+        qa_id = qa["id"]
+        q_text = qa["question"]
+
+        examples.append({
+          "id": qa_id,
+          "question": q_text,
+          "context": doc_tokens,
+          "answers": list(map(lambda x: x["text"], qa["answers"]))
+        })
+  return examples
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+  best_score, best_span_index = None, None
+  for di, (doc_start, doc_length) in enumerate(doc_spans):
+    end = doc_start + doc_length - 1
+    if position < doc_start:
+      continue
+    if position > end:
+      continue
+    num_left_context = position - doc_start
+    num_right_context = end - position
+    score = min(num_left_context, num_right_context) + 0.01 * doc_length
+    if best_score is None or score > best_score:
+      best_score = score
+      best_span_index = di
+  return cur_span_index == best_span_index
+
+def convert_example_to_features(example, tokenizer):
+  query_tokens = tokenizer.tokenize(example["question"])
+
+  if len(query_tokens) > 64:
+    query_tokens = query_tokens[:64]
+
+  tok_to_orig_index = []
+  orig_to_tok_index = []
+  all_doc_tokens = []
+  for i, token in enumerate(example["context"]):
+    orig_to_tok_index.append(len(all_doc_tokens))
+    sub_tokens = tokenizer.tokenize(token)
+    for sub_token in sub_tokens:
+      tok_to_orig_index.append(i)
+      all_doc_tokens.append(sub_token)
+
+  max_tokens_for_doc = 384 - len(query_tokens) - 3
+
+  doc_spans = []
+  start_offset = 0
+  while start_offset < len(all_doc_tokens):
+    length = len(all_doc_tokens) - start_offset
+    length = min(length, max_tokens_for_doc)
+    doc_spans.append((start_offset, length))
+    if start_offset + length == len(all_doc_tokens):
+      break
+    start_offset += min(length, 128)
+
+  outputs = []
+  for di, (doc_start, doc_length) in enumerate(doc_spans):
+    tokens = []
+    token_to_orig_map = {}
+    token_is_max_context = {}
+    segment_ids = []
+    tokens.append("[CLS]")
+    segment_ids.append(0)
+    for token in query_tokens:
+      tokens.append(token)
+      segment_ids.append(0)
+    tokens.append("[SEP]")
+    segment_ids.append(0)
+
+    for i in range(doc_length):
+      split_token_index = doc_start + i
+      token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+      token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
+      tokens.append(all_doc_tokens[split_token_index])
+      segment_ids.append(1)
+    tokens.append("[SEP]")
+    segment_ids.append(1)
+
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+
+    while len(input_ids) < 384:
+      input_ids.append(0)
+      input_mask.append(0)
+      segment_ids.append(0)
+
+    assert len(input_ids) == 384
+    assert len(input_mask) == 384
+    assert len(segment_ids) == 384
+
+    outputs.append({
+      "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
+      "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
+      "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
+      "token_to_orig_map": token_to_orig_map,
+      "token_is_max_context": token_is_max_context,
+      "tokens": tokens,
+    })
+
+  return outputs
+
+def iterate(tokenizer, start=0):
+  examples = init_dataset()
+  print(f"there are {len(examples)} pairs in the dataset")
+
+  for i in range(start, len(examples)):
+    example = examples[i]
+    features = convert_example_to_features(example, tokenizer)
+    # we need to yield all features here as the f1 score is the maximum over all features
+    yield features, example
+
+if __name__ == "__main__":
+  tokenizer = BertTokenizer(str(Path(__file__).parent.parent / "weights/bert_vocab.txt"))
+
+  X, Y = next(iterate(tokenizer))
+  print(" ".join(X[0]["tokens"]))
+  print(X[0]["input_ids"].shape, Y)
diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py
new file mode 100644
index 0000000000..454edca64d
--- /dev/null
+++ b/examples/mlperf/helpers.py
@@ -0,0 +1,134 @@
+from collections import OrderedDict
+import unicodedata
+
+def _get_best_indices(logits, n_best_size):
+  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
+  return list(map(lambda x: x[0], index_and_score))[:n_best_size]
+
+def _is_punctuation(char):
+  if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
+    return True
+  return unicodedata.category(char).startswith("P")
+
+def _is_whitespace(char):
+  if char == " " or char == "\t" or char == "\n" or char == "\r":
+    return True
+  return unicodedata.category(char) == "Zs"
+
+def _is_control(char):
+  if char == "\t" or char == "\n" or char == "\r":
+    return False
+  return unicodedata.category(char).startswith("C")
+
+def _run_split_on_punc(text):
+  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
+    return [text]
+  start_new_word = True
+  output = []
+  for i in range(len(text)):
+    if _is_punctuation(char := text[i]):
+      output.append([char])
+      start_new_word = True
+    else:
+      if start_new_word:
+        output.append([])
+      start_new_word = False
+      output[-1].append(char)
+  return ["".join(x) for x in output]
+
+def _run_strip_accents(text):
+  output = []
+  for char in unicodedata.normalize("NFD", text):
+    if unicodedata.category(char) != "Mn":
+      output.append(char)
+  return "".join(output)
+
+def _clean_text(text):
+  output = []
+  for char in text:
+    if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
+      output.append(" " if _is_whitespace(char) else char)
+  return "".join(output)
+
+def _get_final_text(pred_text, orig_text):
+  def _strip_spaces(text):
+    ns_text = ""
+    ns_to_s_map = OrderedDict()
+    for i, c in enumerate(text):
+      if c == " ":
+        continue
+      ns_to_s_map[len(ns_text)] = i
+      ns_text += c
+    return ns_text, ns_to_s_map
+
+  orig_tokens = _clean_text(orig_text).strip().split()
+  split_tokens = []
+  for token in orig_tokens:
+    if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
+      token = token.lower()
+      token = _run_strip_accents(token)
+    split_tokens.extend(_run_split_on_punc(token))
+
+  tok_text = " ".join(" ".join(split_tokens).strip().split())
+  start_position = tok_text.find(pred_text)
+  if start_position == -1:
+    return orig_text
+  end_position = start_position + len(pred_text) - 1
+
+  orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
+  tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
+  if len(orig_ns_text) != len(tok_ns_text):
+    return orig_text
+  tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()}
+
+  orig_start_position = None
+  if start_position in tok_s_to_ns_map:
+    if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map:
+      orig_start_position = orig_ns_to_s_map[ns_start_position]
+  if orig_start_position is None:
+    return orig_text
+
+  orig_end_position = None
+  if end_position in tok_s_to_ns_map:
+    if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map:
+      orig_end_position = orig_ns_to_s_map[ns_end_position]
+  if orig_end_position is None:
+    return orig_text
+
+  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
+  return output_text
+
+def get_bert_qa_prediction(features, example, start_end_logits):
+  prelim_predictions = []
+  for i, feature in enumerate(features):
+    for start_index in _get_best_indices(start_end_logits[i][0], 20):
+      for end_index in _get_best_indices(start_end_logits[i][1], 20):
+        if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
+          continue
+        if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
+          continue
+        if not feature["token_is_max_context"].get(start_index, False):
+          continue
+        if end_index < start_index or end_index - start_index + 1 > 30:
+          continue
+
+        prelim_predictions.append({
+          "feature_index": i,
+          "start_index": start_index,
+          "end_index": end_index,
+          "start_logit": start_end_logits[i][0, start_index],
+          "end_logit": start_end_logits[i][1, end_index]
+        })
+  predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
+
+  if len(predictions) > 0:
+    feature = features[predictions[0]["feature_index"]]
+    tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
+    orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
+    orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
+    orig_tokens = example["context"][orig_doc_start:(orig_doc_end + 1)]
+    tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "")
+    tok_text = " ".join(tok_text.strip().split())
+    orig_text = " ".join(orig_tokens)
+    return _get_final_text(tok_text, orig_text)
+  return "empty"
diff --git a/examples/mlperf/metrics.py b/examples/mlperf/metrics.py
index 4644ec12db..e4ac9f9f15 100644
--- a/examples/mlperf/metrics.py
+++ b/examples/mlperf/metrics.py
@@ -1,3 +1,7 @@
+import re
+import string
+from collections import Counter
+
 def levenshtein(a, b):
   n, m = len(a), len(b)
   if n > m:
@@ -23,3 +27,18 @@ def word_error_rate(x, y):
     words += len(r_list)
     scores += levenshtein(h_list, r_list)
   return float(scores) / words, float(scores), words
+
+def normalize_string(s):
+  s = "".join(c for c in s.lower() if c not in string.punctuation)
+  s = re.sub(r'\b(a|an|the)\b', ' ', s)
+  return " ".join(s.split())
+
+def f1_score(x, y):
+  xt = normalize_string(x).split()
+  yt = normalize_string(y).split()
+  ct = Counter(xt) & Counter(yt)
+  if (ns := sum(ct.values())) == 0:
+    return 0.0
+  p = ns / len(xt)
+  r = ns / len(yt)
+  return 2 * p * r / (p + r)
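Note (not part of the patch): the f1_score added to examples/mlperf/metrics.py above is the standard SQuAD token-overlap F1. Both strings are normalized (lowercased, punctuation stripped, the articles a/an/the removed), then precision and recall are taken over the multiset intersection of the remaining tokens. A minimal sanity-check sketch, assuming it is run from the repo root:

from examples.mlperf.metrics import f1_score, normalize_string

print(normalize_string("The Norman conquest!"))  # -> "norman conquest" (article and punctuation dropped)
print(f1_score("in the 10th century", "10th and 11th centuries"))  # one shared token -> 2/7, about 0.29
print(f1_score("Denmark, Iceland and Norway", "Denmark, Iceland and Norway"))  # exact match -> 1.0
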
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index a9fc3c8ff9..ef66fd6834 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -1,6 +1,8 @@
 import time
+from pathlib import Path
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
 from tinygrad.helpers import getenv
 
 def eval_resnet():
@@ -69,6 +71,42 @@ def eval_rnnt():
       print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
       st = time.perf_counter()
 
+def eval_bert():
+  # Bert-QA
+  from models.bert import BertForQuestionAnswering
+  mdl = BertForQuestionAnswering()
+  mdl.load_from_pretrained()
+
+  @TinyJit
+  def run(input_ids, input_mask, segment_ids):
+    return mdl(input_ids, input_mask, segment_ids).realize()
+
+  from datasets.squad import iterate
+  from examples.mlperf.helpers import get_bert_qa_prediction
+  from examples.mlperf.metrics import f1_score
+  from transformers import BertTokenizer
+
+  tokenizer = BertTokenizer(str(Path(__file__).parent.parent.parent / "weights/bert_vocab.txt"))
+
+  c = 0
+  f1 = 0.0
+  st = time.perf_counter()
+  for X, Y in iterate(tokenizer):
+    mt = time.perf_counter()
+    outs = []
+    for x in X:
+      outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
+
+    pred = get_bert_qa_prediction(X, Y, outs)
+    print(f"pred: {pred}\nans: {Y['answers']}")
+    f1 += max([f1_score(pred, ans) for ans in Y["answers"]])
+    c += 1
+    print(f"f1: {f1/c}, raw: {f1}, c: {c}\n")
+
+    st = time.perf_counter()
+
 if __name__ == "__main__":
   # inference only
   Tensor.training = False
diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py
index 3ab91937f6..6b7c37a9e4 100644
--- a/examples/mlperf/model_spec.py
+++ b/examples/mlperf/model_spec.py
@@ -1,6 +1,7 @@
 # load each model here, quick benchmark
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import GlobalCounters, getenv
+import numpy as np
 
 def test_model(model, *inputs):
   GlobalCounters.reset()
@@ -35,8 +36,13 @@ def spec_rnnt():
   test_model(mdl, x, y)
 
 def spec_bert():
-  # TODO: BERT-large
-  pass
+  from models.bert import BertForQuestionAnswering
+  mdl = BertForQuestionAnswering()
+  mdl.load_from_pretrained()
+  x = Tensor.randn(1, 384)
+  am = Tensor.randn(1, 384)
+  tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
+  test_model(mdl, x, am, tt)
 
 if __name__ == "__main__":
   # inference only for now
diff --git a/models/bert.py b/models/bert.py
new file mode 100644
index 0000000000..0a06c82cab
--- /dev/null
+++ b/models/bert.py
@@ -0,0 +1,178 @@
+from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
+from tinygrad.nn import Linear, LayerNorm, Embedding
+import numpy as np
+from extra.utils import download_file, get_child
+from pathlib import Path
+
+
+class BertForQuestionAnswering:
+  def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
+    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
+    self.qa_outputs = Linear(hidden_size, 2)
+
+  def load_from_pretrained(self):
+    fn = Path(__file__).parent.parent / "weights/bert_for_qa.pt"
+    download_file("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
+    fn_vocab = Path(__file__).parent.parent / "weights/bert_vocab.txt"
+    download_file("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
+
+    import torch
+    with open(fn, "rb") as f:
+      state_dict = torch.load(f, map_location="cpu")
+
+    for k, v in state_dict.items():
+      if "dropout" in k: continue # skip dropout
+      if "pooler" in k: continue # skip pooler
+      get_child(self, k).assign(v.numpy()).realize()
+
+  def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
+    sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
+    logits = self.qa_outputs(sequence_output)
+    start_logits, end_logits = logits.chunk(2, dim=-1)
+    start_logits = start_logits.reshape(-1, 1)
+    end_logits = end_logits.reshape(-1, 1)
+
+    return Tensor.stack([start_logits, end_logits])
+
+class Bert:
+  def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
+    self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
+    self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)
+
+  def __call__(self, input_ids, attention_mask, token_type_ids):
+    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+    embedding_output = self.embeddings(input_ids, token_type_ids)
+    encoder_outputs = self.encoder(embedding_output, extended_attention_mask)
+
+    return encoder_outputs
+
+class BertEmbeddings:
+  def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
+    self.word_embeddings = Embedding(vocab_size, hidden_size)
+    self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
+    self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
+    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
+    self.dropout = hidden_dropout_prob
+
+  def __call__(self, input_ids, token_type_ids):
+    input_shape = input_ids.shape
+    seq_length = input_shape[1]
+
+    position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
+    words_embeddings = self.word_embeddings(input_ids)
+    position_embeddings = self.position_embeddings(position_ids)
+    token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+    embeddings = words_embeddings + position_embeddings + token_type_embeddings
+    embeddings = self.LayerNorm(embeddings)
+    embeddings = embeddings.dropout(self.dropout)
+    return embeddings
+
+class BertEncoder:
+  def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
+    self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]
+
+  def __call__(self, hidden_states, attention_mask):
+    for layer in self.layer:
+      hidden_states = layer(hidden_states, attention_mask)
+    return hidden_states
+
+class BertLayer:
+  def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
+    self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
+    self.intermediate = BertIntermediate(hidden_size, intermediate_size)
+    self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
+
+  def __call__(self, hidden_states, attention_mask):
+    attention_output = self.attention(hidden_states, attention_mask)
+    intermediate_output = self.intermediate(attention_output)
+    layer_output = self.output(intermediate_output, attention_output)
+    return layer_output
+
+class BertOutput:
+  def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
+    self.dense = Linear(intermediate_size, hidden_size)
+    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
+    self.dropout = hidden_dropout_prob
+
+  def __call__(self, hidden_states, input_tensor):
+    hidden_states = self.dense(hidden_states)
+    hidden_states = hidden_states.dropout(self.dropout)
+    hidden_states = self.LayerNorm(hidden_states + input_tensor)
+    return hidden_states
+
+# approximation of the error function
+def erf(x):
+  t = (1 + 0.3275911 * x.abs()).reciprocal()
+  return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
+
+class BertIntermediate:
+  def __init__(self, hidden_size, intermediate_size):
+    self.dense = Linear(hidden_size, intermediate_size)
+
+  def __call__(self, hidden_states):
+    x = self.dense(hidden_states)
+    # tinygrad gelu is openai gelu but we need the original bert gelu
+    return x * 0.5 * (1.0 + erf(x / 1.41421))
+
+class BertAttention:
+  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
+    self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
+    self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
+
+  def __call__(self, hidden_states, attention_mask):
+    self_output = self.self(hidden_states, attention_mask)
+    attention_output = self.output(self_output, hidden_states)
+    return attention_output
+
+class BertSelfAttention:
+  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
+    self.num_attention_heads = num_attention_heads
+    self.attention_head_size = int(hidden_size / num_attention_heads)
+    self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+    self.query = Linear(hidden_size, self.all_head_size)
+    self.key = Linear(hidden_size, self.all_head_size)
+    self.value = Linear(hidden_size, self.all_head_size)
+
+    self.dropout = attention_probs_dropout_prob
+
+  def __call__(self, hidden_states, attention_mask):
+    mixed_query_layer = self.query(hidden_states)
+    mixed_key_layer = self.key(hidden_states)
+    mixed_value_layer = self.value(hidden_states)
+
+    query_layer = self.transpose_for_scores(mixed_query_layer)
+    key_layer = self.transpose_for_scores(mixed_key_layer)
+    value_layer = self.transpose_for_scores(mixed_value_layer)
+
+    attention_scores = query_layer @ key_layer.transpose(2, 3)
+    attention_scores = attention_scores / self.attention_head_size**0.5
+    attention_scores = attention_scores + attention_mask
+    attention_probs = attention_scores.softmax()
+    attention_probs = attention_probs.dropout(self.dropout)
+
+    context_layer = attention_probs @ value_layer
+    context_layer = context_layer.transpose(1, 2)
+    context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
+
+    return context_layer
+
+  def transpose_for_scores(self, x):
+    x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
+    return x.transpose(1, 2)
+
+class BertSelfOutput:
+  def __init__(self, hidden_size, hidden_dropout_prob):
+    self.dense = Linear(hidden_size, hidden_size)
+    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
+    self.dropout = hidden_dropout_prob
+
+  def __call__(self, hidden_states, input_tensor):
+    hidden_states = self.dense(hidden_states)
+    hidden_states = hidden_states.dropout(self.dropout)
+    hidden_states = self.LayerNorm(hidden_states + input_tensor)
+    return hidden_states
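
Note (not part of the patch): end to end, each SQuAD feature produced by datasets/squad.py is a (1, 384) triple of input_ids / input_mask / segment_ids, and BertForQuestionAnswering returns the stacked start/end logits with shape (2, 384, 1), which is the layout get_bert_qa_prediction indexes as start_end_logits[i][0] and start_end_logits[i][1]. A minimal usage sketch (essentially what eval_bert in examples/mlperf/model_eval.py does, without the JIT and timing), assuming it is run from the repo root:

from pathlib import Path
from tinygrad.tensor import Tensor
from transformers import BertTokenizer
from models.bert import BertForQuestionAnswering
from datasets.squad import iterate
from examples.mlperf.helpers import get_bert_qa_prediction
from examples.mlperf.metrics import f1_score

mdl = BertForQuestionAnswering()
mdl.load_from_pretrained()  # fetches weights/bert_for_qa.pt and weights/bert_vocab.txt from Zenodo

tokenizer = BertTokenizer(str(Path("weights/bert_vocab.txt")))
features, example = next(iterate(tokenizer))  # all sliding-window features for the first question

# one (2, 384, 1) start/end logit array per feature
outs = [mdl(Tensor(f["input_ids"]), Tensor(f["input_mask"]), Tensor(f["segment_ids"])).numpy() for f in features]

pred = get_bert_qa_prediction(features, example, outs)
print(pred, max(f1_score(pred, ans) for ans in example["answers"]))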