Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
Wikipedia preprocessing script (#4229)
* Preprocessing script
* short seq prob
* comments + env vars
* Add preprocessing reference. Add test
* lint fix + add eval test support
* whitespaces
* point to commit
* comment
* rename
* better comments
extra/datasets/wikipedia.py (new file, 411 lines)
@@ -0,0 +1,411 @@
|
||||
# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
|
||||
# This is a modified version of the original script:
|
||||
# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
|
||||
# ENV VARS:
|
||||
# MAX_SEQ_LENGTH - Maximum sequence length
|
||||
# MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
|
||||
# RANDOM_SEED - Random seed
|
||||
# DUPE_FACTOR - Number of times to duplicate the input data with different masks
|
||||
# MASKED_LM_PROB - Probability of masking a token
|
||||
# SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
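#
# Example invocation (hypothetical paths; the values shown match the in-code defaults):
#   BASEDIR=/path/to/wiki MAX_SEQ_LENGTH=512 MAX_PREDICTIONS_PER_SEQ=76 python wikipedia.py pre-train 0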
|
||||
|
||||
import os, sys, pickle, random, unicodedata
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from tqdm.contrib.concurrent import process_map
|
||||
|
||||
from tinygrad.helpers import diskcache, getenv
|
||||
|
||||
BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")
|
||||
|
||||
################### Tokenization #####################
|
||||
|
||||
def _is_whitespace(char:str) -> bool:
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
return unicodedata.category(char) == "Zs"
|
||||
|
||||
def _is_control(char:str) -> bool:
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
return unicodedata.category(char).startswith("C")
|
||||
|
||||
def _is_punctuation(char:str) -> bool:
|
||||
# range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
|
||||
# range(58, 65) -> : ; < = > ? @
|
||||
# range(91, 97) -> [ \ ] ^ _ `
|
||||
# range(123, 127) -> { | } ~
|
||||
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
|
||||
return True
|
||||
return unicodedata.category(char).startswith("P")
|
||||
|
||||
def _is_chinese_char(cp:int) -> bool:
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _run_split_on_punc(text:str) -> list[str]:
|
||||
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
return [text]
|
||||
start_new_word = True
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
if _is_punctuation(char := text[i]):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _run_strip_accents(text:str) -> str:
|
||||
output = []
|
||||
for char in unicodedata.normalize("NFD", text):
|
||||
if unicodedata.category(char) != "Mn":
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _clean_text(text:str) -> str:
|
||||
output = []
|
||||
for char in text:
|
||||
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
|
||||
output.append(" " if _is_whitespace(char) else char)
|
||||
return "".join(output)
|
||||
|
||||
def _tokenize_chinese_chars(text:str) -> str:
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if _is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
if not (text := text.strip()): return []
|
||||
return text.split()
|
||||
|
||||
def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
|
||||
text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
|
||||
output_tokens = []
|
||||
for token in text.strip().split():
|
||||
chars = list(token)
|
||||
if len(chars) > 200:
|
||||
output_tokens.append("[UNK]")
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0: substr = "##" + substr
|
||||
if substr in vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad: output_tokens.append("[UNK]")
|
||||
else: output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, vocab_file):
|
||||
self.vocab = {}
|
||||
with open(vocab_file) as f:
|
||||
for line in f:
|
||||
line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line
|
||||
if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
|
||||
def tokenize(self, text:str) -> list[str]:
|
||||
# BasicTokenizer
|
||||
split_tokens = []
|
||||
for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))):
|
||||
split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
|
||||
split_tokens = " ".join(split_tokens).strip().split()
|
||||
# WordpieceTokenizer
|
||||
tokens = []
|
||||
for token in split_tokens:
|
||||
tokens.extend(_wordpiece_tokenize(token, self.vocab))
|
||||
return tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
|
||||
def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids]
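# Example (WordPiece example taken from the reference tokenization.py docstring; assumes
# the pieces exist in vocab.txt): tokenize("unaffable") -> ["un", "##aff", "##able"]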
|
||||
|
||||
##################### Feature transformation #####################
|
||||
|
||||
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
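# Truncate from the front or the back at random to avoid biases (as in the reference script).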
|
||||
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
|
||||
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
|
||||
cand_indices = []
|
||||
for i, token in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
cand_indices.append(i)
|
||||
|
||||
rng.shuffle(cand_indices)
|
||||
output_tokens = list(tokens)
|
||||
num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))
|
||||
|
||||
masked_lms = []
|
||||
covered_indices = set()
|
||||
for index in cand_indices:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if index in covered_indices:
|
||||
continue
|
||||
covered_indices.add(index)
|
||||
|
||||
masked_token = None
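# As in the reference script: 80% of the time replace with [MASK]; otherwise keep the
# original token or substitute a random vocab word with equal probability.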
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
masked_lms.append((index, tokens[index]))
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x[0])
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p[0])
|
||||
masked_lm_labels.append(p[1])
|
||||
|
||||
return output_tokens, masked_lm_positions, masked_lm_labels
|
||||
|
||||
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
|
||||
max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]
|
||||
|
||||
target_seq_length = max_num_tokens
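# As in the reference script: with probability SHORT_SEQ_PROB use a shorter target length
# to reduce the mismatch between pre-training and fine-tuning sequence lengths.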
|
||||
if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(doc):
|
||||
segment = doc[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(doc) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
|
||||
tokens_b = []
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(documents) - 1)
|
||||
if random_document_index != di:
|
||||
break
|
||||
|
||||
random_document = documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
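# Segments not used for tokens_a are "put back" so they are not wasted (same trick as the reference script).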
|
||||
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
|
||||
instances.append({
|
||||
"tokens": tokens,
|
||||
"segment_ids": segment_ids,
|
||||
"masked_lm_positions": masked_lm_positions,
|
||||
"masked_lm_labels": masked_lm_labels,
|
||||
"is_random_next": is_random_next
|
||||
})
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
return instances
|
||||
|
||||
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
|
||||
documents = [[]]
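# Input format (same as the reference script): one sentence per line, blank lines delimit documents.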
|
||||
with open(BASEDIR / fn) as f:
|
||||
for line in f.readlines():
|
||||
if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break
|
||||
if not (line := line.strip()): documents.append([])
|
||||
if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
|
||||
documents = [x for x in documents if x]
|
||||
rng.shuffle(documents)
|
||||
return documents
|
||||
|
||||
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
|
||||
instances = []
|
||||
for _ in range(getenv('DUPE_FACTOR', 10)):
|
||||
for di, doc in enumerate(documents):
|
||||
instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents))
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
|
||||
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = instance["segment_ids"]
|
||||
|
||||
max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
|
||||
|
||||
assert len(input_ids) <= max_seq_length
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
masked_lm_positions = instance["masked_lm_positions"]
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
|
||||
while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76):
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
|
||||
next_sentence_label = 1 if instance["is_random_next"] else 0
|
||||
|
||||
return {
|
||||
"input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
|
||||
"input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
|
||||
"segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
|
||||
"masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
|
||||
"masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
|
||||
"masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
|
||||
"next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
|
||||
}
|
||||
|
||||
def process_part(part:int):
|
||||
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
|
||||
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
|
||||
for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
|
||||
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
|
||||
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
|
||||
rng = random.Random(getenv('RANDOM_SEED', 12345))
|
||||
|
||||
if val:
|
||||
tqdm.write("Getting samples from dataset")
|
||||
documents = get_documents(rng, tokenizer, "results4/eval.txt")
|
||||
instances = get_instances(rng, tokenizer, documents)
|
||||
|
||||
tqdm.write(f"There are {len(instances)} samples in the dataset")
|
||||
tqdm.write(f"Picking 10000 samples")
|
||||
|
||||
pick_ratio = len(instances) / 10000
|
||||
picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
|
||||
for batch in range(10):
|
||||
yield picks[batch*1000:(batch+1)*1000]
|
||||
else:
|
||||
documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
|
||||
instances = get_instances(rng, tokenizer, documents)
|
||||
|
||||
while len(instances) > 0:
|
||||
batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
|
||||
batch = instances[:batch_size]
|
||||
del instances[:batch_size]
|
||||
yield [instance_to_features(instance, tokenizer) for instance in batch]
|
||||
|
||||
##################### Load files #####################
|
||||
|
||||
def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl")))
|
||||
|
||||
@diskcache
|
||||
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
|
||||
|
||||
assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
|
||||
|
||||
if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
|
||||
os.makedirs(BASEDIR / "eval", exist_ok=True)
|
||||
|
||||
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
|
||||
with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
elif sys.argv[1] == "pre-train":
|
||||
os.makedirs(BASEDIR / "train", exist_ok=True)
|
||||
if sys.argv[2] == "all": # Use all 500 parts for training generation
|
||||
process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', os.cpu_count()), chunksize=1)
|
||||
else: # Use a specific part for training generation
|
||||
part = int(sys.argv[2])
|
||||
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
|
||||
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
|
||||
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
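A minimal illustrative sketch (not part of the file above) of loading one of the pickled training batches it produces, assuming `python wikipedia.py pre-train 0` has been run; the basedir path is hypothetical:

import pickle
from pathlib import Path

basedir = Path("/path/to/wiki")                 # hypothetical BASEDIR
with open(basedir / "train" / "0" / "0_0.pkl", "rb") as f:
    batch = pickle.load(f)                      # list of up to 1000 feature dicts

sample = batch[0]
print(sample["input_ids"].shape)                # (1, 512) with the default MAX_SEQ_LENGTH
print(sample["masked_lm_positions"].shape)      # (1, 76) with the default MAX_PREDICTIONS_PER_SEQ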
test/external/mlperf_bert/preprocessing/create_pretraining_data.py (new file, vendored, 435 lines)
@@ -0,0 +1,435 @@
|
||||
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
|
||||
# NOTE: This is a direct copy of the original script
|
||||
# NOTE: With python 3.7.12, pip install tensorflow=1.15.5
|
||||
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # NOTE: This is a workaround for protobuf issue
|
||||
|
||||
import collections
|
||||
import random
|
||||
import tokenization
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
flags = tf.flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("input_file", None,
|
||||
"Input raw text file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"output_file", None,
|
||||
"Output TF example file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string("vocab_file", None,
|
||||
"The vocabulary file that the BERT model was trained on.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_lower_case", True,
|
||||
"Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
|
||||
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
||||
|
||||
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
||||
"Maximum number of masked LM predictions per sequence.")
|
||||
|
||||
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"dupe_factor", 10,
|
||||
"Number of times to duplicate the input data (with different masks).")
|
||||
|
||||
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"short_seq_prob", 0.1,
|
||||
"Probability of creating sequences which are shorter than the "
|
||||
"maximum length.")
|
||||
|
||||
|
||||
class TrainingInstance(object):
|
||||
"""A single training instance (sentence pair)."""
|
||||
|
||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||
is_random_next):
|
||||
self.tokens = tokens
|
||||
self.segment_ids = segment_ids
|
||||
self.is_random_next = is_random_next
|
||||
self.masked_lm_positions = masked_lm_positions
|
||||
self.masked_lm_labels = masked_lm_labels
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.tokens]))
|
||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||
s += "is_random_next: %s\n" % self.is_random_next
|
||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||
[str(x) for x in self.masked_lm_positions]))
|
||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||
max_predictions_per_seq, output_files):
|
||||
"""Create TF example files from `TrainingInstance`s."""
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||
|
||||
writer_index = 0
|
||||
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = list(instance.segment_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
masked_lm_positions = list(instance.masked_lm_positions)
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
|
||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(input_ids)
|
||||
features["input_mask"] = create_int_feature(input_mask)
|
||||
features["segment_ids"] = create_int_feature(segment_ids)
|
||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
|
||||
writers[writer_index].write(tf_example.SerializeToString())
|
||||
writer_index = (writer_index + 1) % len(writers)
|
||||
|
||||
total_written += 1
|
||||
|
||||
if inst_index < 20:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("tokens: %s" % " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
values = []
|
||||
if feature.int64_list.value:
|
||||
values = feature.int64_list.value
|
||||
elif feature.float_list.value:
|
||||
values = feature.float_list.value
|
||||
tf.logging.info(
|
||||
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
tf.logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_float_feature(values):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||
max_predictions_per_seq, rng):
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
|
||||
# Input file format:
|
||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||
# sentence boundaries for the "next sentence prediction" task).
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
|
||||
# Empty lines are used as document delimiters
|
||||
if not line:
|
||||
all_documents.append([])
|
||||
tokens = tokenizer.tokenize(line)
|
||||
if tokens:
|
||||
all_documents[-1].append(tokens)
|
||||
|
||||
# Remove empty documents
|
||||
all_documents = [x for x in all_documents if x]
|
||||
rng.shuffle(all_documents)
|
||||
|
||||
vocab_words = list(tokenizer.vocab.keys())
|
||||
instances = []
|
||||
for _ in range(dupe_factor):
|
||||
for document_index in range(len(all_documents)):
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
|
||||
|
||||
def create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
|
||||
# Account for [CLS], [SEP], [SEP]
|
||||
max_num_tokens = max_seq_length - 3
|
||||
|
||||
# We *usually* want to fill up the entire sequence since we are padding
|
||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||
# computation. However, we *sometimes*
|
||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||
# The `target_seq_length` is just a rough target however, whereas
|
||||
# `max_seq_length` is a hard limit.
|
||||
target_seq_length = max_num_tokens
|
||||
if rng.random() < short_seq_prob:
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
|
||||
# We DON'T just concatenate all of the tokens from a document into a long
|
||||
# sequence and choose an arbitrary split point because this would make the
|
||||
# next sentence prediction task too easy. Instead, we split the input into
|
||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||
# input.
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(document):
|
||||
segment = document[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||
# (first) sentence.
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
|
||||
tokens_b = []
|
||||
# Random next
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
|
||||
# This should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document
|
||||
# we're processing.
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||
if random_document_index != document_index:
|
||||
break
|
||||
|
||||
random_document = all_documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
|
||||
# We didn't actually use these segments so we "put them back" so
|
||||
# they don't go to waste.
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
# Actual next
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
is_random_next=is_random_next,
|
||||
masked_lm_positions=masked_lm_positions,
|
||||
masked_lm_labels=masked_lm_labels)
|
||||
instances.append(instance)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
|
||||
return instances
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||
["index", "label"])
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||
max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates the predictions for the masked LM objective."""
|
||||
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
cand_indexes.append(i)
|
||||
|
||||
rng.shuffle(cand_indexes)
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
covered_indexes.add(index)
|
||||
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||
|
||||
|
||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
|
||||
|
||||
# We want to sometimes truncate from the front and sometimes from the
|
||||
# back to add more randomness and avoid biases.
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
|
||||
tf.logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s", input_file)
|
||||
|
||||
rng = random.Random(FLAGS.random_seed)
|
||||
instances = create_training_instances(
|
||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
rng)
|
||||
|
||||
output_files = FLAGS.output_file.split(",")
|
||||
tf.logging.info("*** Writing to output files ***")
|
||||
for output_file in output_files:
|
||||
tf.logging.info(" %s", output_file)
|
||||
|
||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq, output_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
tf.app.run()
|
||||
test/external/mlperf_bert/preprocessing/external_test_preprocessing_part.py (new file, vendored, 82 lines)
@@ -0,0 +1,82 @@
|
||||
# USAGE:
|
||||
# 1. Download raw text data with `wikipedia_download.py`
|
||||
|
||||
# 2. Install python==3.7.12 and tensorflow==1.15.5
|
||||
# Run `create_pretraining_data.py` to create TFRecords for a specific part (this will take some time)
|
||||
# Command: python3 create_pretraining_data.py --input_file=/path/to/part-00XXX-of-00500 --vocab_file=/path/to/vocab.txt \
|
||||
# --output_file=/path/to/output.tfrecord --max_seq_length=512 --max_predictions_per_seq=76
|
||||
#
|
||||
# 2.1 For eval: --input_file=/path/to/eval.txt and
|
||||
# Command: python3 pick_eval_samples.py --input_tfrecord=/path/to/eval.tfrecord --output_tfrecord=/path/to/output_eval.tfrecord
|
||||
|
||||
# 3. Run `wikipedia.py` to preprocess the data with tinygrad (Use python > 3.7)
|
||||
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train XXX (NOTE: part number needs to match part of step 2)
|
||||
# This will output to /path/to/basedir/train/XXX
|
||||
#
|
||||
# 3.1 For eval:
|
||||
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-eval
|
||||
# This will output to /path/to/basedir/eval
|
||||
|
||||
# 4. Run this script to verify the correctness of the preprocessing script for a specific part
|
||||
# Command: python3 external_test_preprocessing_part.py --preprocessed_part_dir=/path/to/basedir/part --tf_records=/path/to/output.tfrecord
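# On success the script prints the total sample count for the part and the per-sample feature
# comparison finishes without assertion errors; a mismatch reports the feature name and sample index.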
|
||||
import os, argparse, pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
# This is a workaround for protobuf issue
|
||||
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
||||
|
||||
def _parse_function(proto, max_seq_length, max_predictions_per_seq):
|
||||
feature_description = {
|
||||
'input_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
'input_mask': tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
'segment_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
'masked_lm_positions': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
'masked_lm_ids': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
'masked_lm_weights': tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||
'next_sentence_labels': tf.io.FixedLenFeature([1], tf.int64),
|
||||
}
|
||||
return tf.io.parse_single_example(proto, feature_description)
|
||||
|
||||
def load_dataset(file_path, max_seq_length=512, max_predictions_per_seq=76):
|
||||
dataset = tf.data.TFRecordDataset(file_path)
|
||||
parse_function = lambda proto: _parse_function(proto, max_seq_length, max_predictions_per_seq) # noqa: E731
|
||||
return dataset.map(parse_function)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Verify the correctness of the preprocessing script for a specific part",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--preprocessed_part_dir", type=str, default=None,
|
||||
help="Path to dir with preprocessed samples from `wikipedia.py`")
|
||||
parser.add_argument("--tf_records", type=str, default=None,
|
||||
help="Path to TFRecords file from `create_pretraining_data.py` (Reference implementation)")
|
||||
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it as 512")
|
||||
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it as 76")
|
||||
parser.add_argument("--is_eval", type=bool, default=False, help="Whether to run eval or train preprocessing")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert os.path.isdir(args.preprocessed_part_dir), f"The specified directory {args.preprocessed_part_dir} does not exist."
|
||||
assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."
|
||||
|
||||
preprocessed_samples = []
|
||||
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1].split(".")[0]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
|
||||
with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
|
||||
samples = pickle.load(f)
|
||||
preprocessed_samples.extend(samples)
|
||||
|
||||
dataset = load_dataset(args.tf_records, args.max_seq_length, args.max_predictions_per_seq)
|
||||
tf_record_count = sum(1 for _ in dataset)
|
||||
assert tf_record_count == len(preprocessed_samples), f"Samples in reference: {tf_record_count} != Preprocessed samples: {len(preprocessed_samples)}"
|
||||
print(f"Total samples in the part: {tf_record_count}")
|
||||
|
||||
for i, (reference_example, preprocessed_sample) in tqdm(enumerate(zip(dataset, preprocessed_samples)), desc="Checking samples", total=len(preprocessed_samples)): # noqa: E501
|
||||
feature_keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights", "next_sentence_labels"]
|
||||
for key in feature_keys:
|
||||
reference_example_feature = reference_example[key].numpy()
|
||||
assert (reference_example_feature == preprocessed_sample[key]).all(), \
|
||||
f"{key} are not equal at index {i}\nReference: {reference_example_feature}\nPreprocessed: {preprocessed_sample[key]}"
|
||||
test/external/mlperf_bert/preprocessing/pick_eval_samples.py (new file, vendored, 127 lines)
@@ -0,0 +1,127 @@
|
||||
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/pick_eval_samples.py
|
||||
# NOTE: This is a direct copy of the original script
|
||||
"""Script for picking certain number of sampels.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import collections
|
||||
import tensorflow as tf
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Eval sample picker for BERT.")
|
||||
parser.add_argument(
|
||||
'--input_tfrecord',
|
||||
type=str,
|
||||
default='',
|
||||
help='Input tfrecord path')
|
||||
parser.add_argument(
|
||||
'--output_tfrecord',
|
||||
type=str,
|
||||
default='',
|
||||
help='Output tfrecord path')
|
||||
parser.add_argument(
|
||||
'--num_examples_to_pick',
|
||||
type=int,
|
||||
default=10000,
|
||||
help='Number of examples to pick')
|
||||
parser.add_argument(
|
||||
'--max_seq_length',
|
||||
type=int,
|
||||
default=512,
|
||||
help='The maximum number of tokens within a sequence.')
|
||||
parser.add_argument(
|
||||
'--max_predictions_per_seq',
|
||||
type=int,
|
||||
default=76,
|
||||
help='The maximum number of predictions within a sequence.')
|
||||
args = parser.parse_args()
|
||||
|
||||
max_seq_length = args.max_seq_length
|
||||
max_predictions_per_seq = args.max_predictions_per_seq
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def decode_record(record):
|
||||
"""Decodes a record to a TensorFlow example."""
|
||||
name_to_features = {
|
||||
"input_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"input_mask":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"segment_ids":
|
||||
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"masked_lm_positions":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_ids":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
"masked_lm_weights":
|
||||
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||
"next_sentence_labels":
|
||||
tf.FixedLenFeature([1], tf.int64),
|
||||
}
|
||||
|
||||
example = tf.parse_single_example(record, name_to_features)
|
||||
|
||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||
# So cast all int64 to int32.
|
||||
for name in list(example.keys()):
|
||||
t = example[name]
|
||||
if t.dtype == tf.int64:
|
||||
t = tf.to_int32(t)
|
||||
example[name] = t
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_float_feature(values):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tic = time.time()
|
||||
tf.enable_eager_execution()
|
||||
|
||||
d = tf.data.TFRecordDataset(args.input_tfrecord)
|
||||
num_examples = 0
|
||||
records = []
|
||||
for record in d:
|
||||
records.append(record)
|
||||
num_examples += 1
|
||||
|
||||
writer = tf.python_io.TFRecordWriter(args.output_tfrecord)
|
||||
i = 0
|
||||
pick_ratio = num_examples / args.num_examples_to_pick
|
||||
num_examples_picked = 0
|
||||
for i in range(args.num_examples_to_pick):
|
||||
example = decode_record(records[int(i * pick_ratio)])
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(
|
||||
example["input_ids"].numpy())
|
||||
features["input_mask"] = create_int_feature(
|
||||
example["input_mask"].numpy())
|
||||
features["segment_ids"] = create_int_feature(
|
||||
example["segment_ids"].numpy())
|
||||
features["masked_lm_positions"] = create_int_feature(
|
||||
example["masked_lm_positions"].numpy())
|
||||
features["masked_lm_ids"] = create_int_feature(
|
||||
example["masked_lm_ids"].numpy())
|
||||
features["masked_lm_weights"] = create_float_feature(
|
||||
example["masked_lm_weights"].numpy())
|
||||
features["next_sentence_labels"] = create_int_feature(
|
||||
example["next_sentence_labels"].numpy())
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
writer.write(tf_example.SerializeToString())
|
||||
num_examples_picked += 1
|
||||
|
||||
writer.close()
|
||||
toc = time.time()
|
||||
logging.info("Picked %d examples out of %d samples in %.2f sec",
|
||||
num_examples_picked, num_examples, toc - tic)
|
||||
test/external/mlperf_bert/preprocessing/tokenization.py (new file, vendored, 415 lines)
@@ -0,0 +1,415 @@
|
||||
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/tokenization.py
|
||||
# NOTE: This is a direct copy of the original script
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from absl import flags
|
||||
import six
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"preserve_unused_tokens", False,
|
||||
"If True, Wordpiece tokenization will not be applied to words in the vocab."
|
||||
)
|
||||
|
||||
_UNUSED_TOKEN_RE = re.compile("^\\[unused\\d+\\]$")
|
||||
|
||||
|
||||
def preserve_token(token, vocab):
|
||||
"""Returns True if the token should forgo tokenization and be preserved."""
|
||||
if not FLAGS.preserve_unused_tokens:
|
||||
return False
|
||||
if token not in vocab:
|
||||
return False
|
||||
return bool(_UNUSED_TOKEN_RE.search(token))
|
||||
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode): # noqa: F821
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode): # noqa: F821
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
with tf.gfile.GFile(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
if token not in vocab:
|
||||
vocab[token] = len(vocab)
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(
|
||||
do_lower_case=do_lower_case, vocab=self.vocab)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
if preserve_token(token, self.vocab):
|
||||
split_tokens.append(token)
|
||||
continue
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True, vocab=tuple()):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
vocab: A container of tokens to not mutate during tokenization.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
self.vocab = vocab
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if preserve_token(token, self.vocab):
|
||||
split_tokens.append(token)
|
||||
continue
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||