From 69341144ba4ccdf19335a563df22b5385fbec6ae Mon Sep 17 00:00:00 2001 From: Elias Wahl <82230675+Eliulm@users.noreply.github.com> Date: Tue, 23 Apr 2024 16:28:01 +0200 Subject: [PATCH] Wikipedia preprocessing script (#4229) * Preprocessing script * short seq prob * comments + env vars * Add preprocessing reference. Add test * lint fix + add eval test support * whitespaces * point to commit * comment * rename * better comments --- extra/datasets/wikipedia.py | 411 +++++++++++++++++ .../preprocessing/create_pretraining_data.py | 435 ++++++++++++++++++ .../external_test_preprocessing_part.py | 82 ++++ .../preprocessing/pick_eval_samples.py | 127 +++++ .../mlperf_bert/preprocessing/tokenization.py | 415 +++++++++++++++++ 5 files changed, 1470 insertions(+) create mode 100644 extra/datasets/wikipedia.py create mode 100644 test/external/mlperf_bert/preprocessing/create_pretraining_data.py create mode 100644 test/external/mlperf_bert/preprocessing/external_test_preprocessing_part.py create mode 100644 test/external/mlperf_bert/preprocessing/pick_eval_samples.py create mode 100644 test/external/mlperf_bert/preprocessing/tokenization.py diff --git a/extra/datasets/wikipedia.py b/extra/datasets/wikipedia.py new file mode 100644 index 0000000000..32e3513339 --- /dev/null +++ b/extra/datasets/wikipedia.py @@ -0,0 +1,411 @@ +# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training +# This is a modified version of the original script: +# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py +# ENV VARS: +# MAX_SEQ_LENGTH - Maximum sequence length +# MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence +# RANDOM_SEED - Random seed +# DUPE_FACTOR - Number of times to duplicate the input data with different masks +# MASKED_LM_PROB - Probability of masking a token +# SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH + +import os, sys, pickle, random, unicodedata +from pathlib import Path +import numpy as np +from tqdm import tqdm +from tqdm.contrib.concurrent import process_map + +from tinygrad.helpers import diskcache, getenv + +BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki") + +################### Tokenization ##################### + +def _is_whitespace(char:str) -> bool: + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + return unicodedata.category(char) == "Zs" + +def _is_control(char:str) -> bool: + if char == "\t" or char == "\n" or char == "\r": + return False + return unicodedata.category(char).startswith("C") + +def _is_punctuation(char:str) -> bool: + # range(33, 48) -> ! " # $ % & ' ( ) * + , - . / + # range(58, 65) -> : ; < = > ? 
@ + # range(91, 97) -> [ \ ] ^ _ + # range(123, 127) -> { | } ~ + if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127): + return True + return unicodedata.category(char).startswith("P") + +def _is_chinese_char(cp:int) -> bool: + if ((cp >= 0x4E00 and cp <= 0x9FFF) or + (cp >= 0x3400 and cp <= 0x4DBF) or + (cp >= 0x20000 and cp <= 0x2A6DF) or + (cp >= 0x2A700 and cp <= 0x2B73F) or + (cp >= 0x2B740 and cp <= 0x2B81F) or + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or + (cp >= 0x2F800 and cp <= 0x2FA1F)): + return True + return False + +def _run_split_on_punc(text:str) -> list[str]: + if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): + return [text] + start_new_word = True + output = [] + for i in range(len(text)): + if _is_punctuation(char := text[i]): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + return ["".join(x) for x in output] + +def _run_strip_accents(text:str) -> str: + output = [] + for char in unicodedata.normalize("NFD", text): + if unicodedata.category(char) != "Mn": + output.append(char) + return "".join(output) + +def _clean_text(text:str) -> str: + output = [] + for char in text: + if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)): + output.append(" " if _is_whitespace(char) else char) + return "".join(output) + +def _tokenize_chinese_chars(text:str) -> str: + output = [] + for char in text: + cp = ord(char) + if _is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + +def whitespace_tokenize(text): + if not (text := text.strip()): return [] + return text.split() + +def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]: + text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text + output_tokens = [] + for token in text.strip().split(): + chars = list(token) + if len(chars) > 200: + output_tokens.append("[UNK]") + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: substr = "##" + substr + if substr in vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: output_tokens.append("[UNK]") + else: output_tokens.extend(sub_tokens) + return output_tokens + +class Tokenizer: + def __init__(self, vocab_file): + self.vocab = {} + with open(vocab_file) as f: + for line in f: + line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line + if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + + def tokenize(self, text:str) -> list[str]: + # BasicTokenizer + split_tokens = [] + for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))): + split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower()))) + split_tokens = " ".join(split_tokens).strip().split() + # WordpieceTokenizer + tokens = [] + for token in split_tokens: + tokens.extend(_wordpiece_tokenize(token, self.vocab)) + return tokens + + def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens] + def 
convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids] + +##################### Feature transformation ##################### + +def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None: + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + if rng.random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + +def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]: + cand_indices = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + cand_indices.append(i) + + rng.shuffle(cand_indices) + output_tokens = list(tokens) + num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15))))) + + masked_lms = [] + covered_indices = set() + for index in cand_indices: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indices: + continue + covered_indices.add(index) + + masked_token = None + if rng.random() < 0.8: + masked_token = "[MASK]" + else: + if rng.random() < 0.5: + masked_token = tokens[index] + else: + masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)] + + output_tokens[index] = masked_token + masked_lms.append((index, tokens[index])) + masked_lms = sorted(masked_lms, key=lambda x: x[0]) + + masked_lm_positions = [] + masked_lm_labels = [] + for p in masked_lms: + masked_lm_positions.append(p[0]) + masked_lm_labels.append(p[1]) + + return output_tokens, masked_lm_positions, masked_lm_labels + +def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]: + max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP] + + target_seq_length = max_num_tokens + if rng.random() < getenv("SHORT_SEQ_PROB", 0.1): + target_seq_length = rng.randint(2, max_num_tokens) + + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(doc): + segment = doc[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(doc) - 1 or current_length >= target_seq_length: + if current_chunk: + a_end = 1 + if len(current_chunk) >= 2: + a_end = rng.randint(1, len(current_chunk) - 1) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + is_random_next = False + if len(current_chunk) == 1 or rng.random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + for _ in range(10): + random_document_index = rng.randint(0, len(documents) - 1) + if random_document_index != di: + break + + random_document = documents[random_document_index] + random_start = rng.randint(0, len(random_document) - 1) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = [] + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + 
for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys())) + instances.append({ + "tokens": tokens, + "segment_ids": segment_ids, + "masked_lm_positions": masked_lm_positions, + "masked_lm_labels": masked_lm_labels, + "is_random_next": is_random_next + }) + current_chunk = [] + current_length = 0 + i += 1 + return instances + +def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]: + documents = [[]] + with open(BASEDIR / fn) as f: + for line in f.readlines(): + if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break + if not (line := line.strip()): documents.append([]) + if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens) + documents = [x for x in documents if x] + rng.shuffle(documents) + return documents + +def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]: + instances = [] + for _ in range(getenv('DUPE_FACTOR', 10)): + for di, doc in enumerate(documents): + instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents)) + rng.shuffle(instances) + return instances + +def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict: + input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"]) + input_mask = [1] * len(input_ids) + segment_ids = instance["segment_ids"] + + max_seq_length = getenv('MAX_SEQ_LENGTH', 512) + + assert len(input_ids) <= max_seq_length + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + masked_lm_positions = instance["masked_lm_positions"] + masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"]) + masked_lm_weights = [1.0] * len(masked_lm_ids) + + while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76): + masked_lm_positions.append(0) + masked_lm_ids.append(0) + masked_lm_weights.append(0.0) + + next_sentence_label = 1 if instance["is_random_next"] else 0 + + return { + "input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0), + "input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0), + "segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0), + "masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0), + "masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0), + "masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0), + "next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0), + } + +def process_part(part:int): + tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt") + os.makedirs(BASEDIR / "train" / str(part), exist_ok=True) + for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)): + with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f: + pickle.dump(feature_batch, f) + +def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples + rng = random.Random(getenv('RANDOM_SEED', 
12345)) + + if val: + tqdm.write("Getting samples from dataset") + documents = get_documents(rng, tokenizer, "results4/eval.txt") + instances = get_instances(rng, tokenizer, documents) + + tqdm.write(f"There are {len(instances)} samples in the dataset") + tqdm.write(f"Picking 10000 samples") + + pick_ratio = len(instances) / 10000 + picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)] + for batch in range(10): + yield picks[batch*1000:(batch+1)*1000] + else: + documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500") + instances = get_instances(rng, tokenizer, documents) + + while len(instances) > 0: + batch_size = min(1000, len(instances)) # We batch 1000 samples to one file + batch = instances[:batch_size] + del instances[:batch_size] + yield [instance_to_features(instance, tokenizer) for instance in batch] + +##################### Load files ##################### + +def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl"))) + +@diskcache +def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl"))) + +if __name__ == "__main__": + tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt") + + assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all" + + if sys.argv[1] == "pre-eval": # Generate 10000 eval samples + os.makedirs(BASEDIR / "eval", exist_ok=True) + + for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10): + with open(BASEDIR / f"eval/{i}.pkl", "wb") as f: + pickle.dump(feature_batch, f) + elif sys.argv[1] == "pre-train": + os.makedirs(BASEDIR / "train", exist_ok=True) + if sys.argv[2] == "all": # Use all 500 parts for training generation + process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', os.cpu_count()), chunksize=1) + else: # Use a specific part for training generation + part = int(sys.argv[2]) + os.makedirs(BASEDIR / "train" / str(part), exist_ok=True) + for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))): + with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f: + pickle.dump(feature_batch, f) diff --git a/test/external/mlperf_bert/preprocessing/create_pretraining_data.py b/test/external/mlperf_bert/preprocessing/create_pretraining_data.py new file mode 100644 index 0000000000..54041dfdd4 --- /dev/null +++ b/test/external/mlperf_bert/preprocessing/create_pretraining_data.py @@ -0,0 +1,435 @@ +# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py +# NOTE: This is a direct copy of the original script +# NOTE: With python 3.7.12, pip install tensorflow=1.15.5 +"""Create masked LM/next sentence masked_lm TF examples for BERT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # NOTE: This is a workaround for protobuf issue + +import collections +import random +import tokenization +import tensorflow as tf + + +flags = tf.flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string("input_file", None, + "Input raw text file (or comma-separated list of files).") + +flags.DEFINE_string( + "output_file", None, + "Output TF example file (or comma-separated list of files).") + +flags.DEFINE_string("vocab_file", None, + "The vocabulary file that the BERT 
model was trained on.") + +flags.DEFINE_bool( + "do_lower_case", True, + "Whether to lower case the input text. Should be True for uncased " + "models and False for cased models.") + +flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") + +flags.DEFINE_integer("max_predictions_per_seq", 20, + "Maximum number of masked LM predictions per sequence.") + +flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") + +flags.DEFINE_integer( + "dupe_factor", 10, + "Number of times to duplicate the input data (with different masks).") + +flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") + +flags.DEFINE_float( + "short_seq_prob", 0.1, + "Probability of creating sequences which are shorter than the " + "maximum length.") + + +class TrainingInstance(object): + """A single training instance (sentence pair).""" + + def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, + is_random_next): + self.tokens = tokens + self.segment_ids = segment_ids + self.is_random_next = is_random_next + self.masked_lm_positions = masked_lm_positions + self.masked_lm_labels = masked_lm_labels + + def __str__(self): + s = "" + s += "tokens: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.tokens])) + s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) + s += "is_random_next: %s\n" % self.is_random_next + s += "masked_lm_positions: %s\n" % (" ".join( + [str(x) for x in self.masked_lm_positions])) + s += "masked_lm_labels: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.masked_lm_labels])) + s += "\n" + return s + + def __repr__(self): + return self.__str__() + + +def write_instance_to_example_files(instances, tokenizer, max_seq_length, + max_predictions_per_seq, output_files): + """Create TF example files from `TrainingInstance`s.""" + writers = [] + for output_file in output_files: + writers.append(tf.python_io.TFRecordWriter(output_file)) + + writer_index = 0 + + total_written = 0 + for (inst_index, instance) in enumerate(instances): + input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) + input_mask = [1] * len(input_ids) + segment_ids = list(instance.segment_ids) + assert len(input_ids) <= max_seq_length + + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + masked_lm_positions = list(instance.masked_lm_positions) + masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) + masked_lm_weights = [1.0] * len(masked_lm_ids) + + while len(masked_lm_positions) < max_predictions_per_seq: + masked_lm_positions.append(0) + masked_lm_ids.append(0) + masked_lm_weights.append(0.0) + + next_sentence_label = 1 if instance.is_random_next else 0 + + features = collections.OrderedDict() + features["input_ids"] = create_int_feature(input_ids) + features["input_mask"] = create_int_feature(input_mask) + features["segment_ids"] = create_int_feature(segment_ids) + features["masked_lm_positions"] = create_int_feature(masked_lm_positions) + features["masked_lm_ids"] = create_int_feature(masked_lm_ids) + features["masked_lm_weights"] = create_float_feature(masked_lm_weights) + features["next_sentence_labels"] = create_int_feature([next_sentence_label]) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + + writers[writer_index].write(tf_example.SerializeToString()) + 
writer_index = (writer_index + 1) % len(writers) + + total_written += 1 + + if inst_index < 20: + tf.logging.info("*** Example ***") + tf.logging.info("tokens: %s" % " ".join( + [tokenization.printable_text(x) for x in instance.tokens])) + + for feature_name in features.keys(): + feature = features[feature_name] + values = [] + if feature.int64_list.value: + values = feature.int64_list.value + elif feature.float_list.value: + values = feature.float_list.value + tf.logging.info( + "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) + + for writer in writers: + writer.close() + + tf.logging.info("Wrote %d total instances", total_written) + + +def create_int_feature(values): + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) + return feature + + +def create_float_feature(values): + feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) + return feature + + +def create_training_instances(input_files, tokenizer, max_seq_length, + dupe_factor, short_seq_prob, masked_lm_prob, + max_predictions_per_seq, rng): + """Create `TrainingInstance`s from raw text.""" + all_documents = [[]] + + # Input file format: + # (1) One sentence per line. These should ideally be actual sentences, not + # entire paragraphs or arbitrary spans of text. (Because we use the + # sentence boundaries for the "next sentence prediction" task). + # (2) Blank lines between documents. Document boundaries are needed so + # that the "next sentence prediction" task doesn't span between documents. + for input_file in input_files: + with tf.gfile.GFile(input_file, "r") as reader: + while True: + line = tokenization.convert_to_unicode(reader.readline()) + if not line: + break + line = line.strip() + + # Empty lines are used as document delimiters + if not line: + all_documents.append([]) + tokens = tokenizer.tokenize(line) + if tokens: + all_documents[-1].append(tokens) + + # Remove empty documents + all_documents = [x for x in all_documents if x] + rng.shuffle(all_documents) + + vocab_words = list(tokenizer.vocab.keys()) + instances = [] + for _ in range(dupe_factor): + for document_index in range(len(all_documents)): + instances.extend( + create_instances_from_document( + all_documents, document_index, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) + + rng.shuffle(instances) + return instances + + +def create_instances_from_document( + all_documents, document_index, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, vocab_words, rng): + """Creates `TrainingInstance`s for a single document.""" + document = all_documents[document_index] + + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + + # We *usually* want to fill up the entire sequence since we are padding + # to `max_seq_length` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `max_seq_length` is a hard limit. + target_seq_length = max_num_tokens + if rng.random() < short_seq_prob: + target_seq_length = rng.randint(2, max_num_tokens) + + # We DON'T just concatenate all of the tokens from a document into a long + # sequence and choose an arbitrary split point because this would make the + # next sentence prediction task too easy. 
Instead, we split the input into + # segments "A" and "B" based on the actual "sentences" provided by the user + # input. + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = rng.randint(1, len(current_chunk) - 1) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + # Random next + is_random_next = False + if len(current_chunk) == 1 or rng.random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + # This should rarely go for more than one iteration for large + # corpora. However, just to be careful, we try to make sure that + # the random document is not the same as the document + # we're processing. + for _ in range(10): + random_document_index = rng.randint(0, len(all_documents) - 1) + if random_document_index != document_index: + break + + random_document = all_documents[random_document_index] + random_start = rng.randint(0, len(random_document) - 1) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = [] + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + + tokens.append("[SEP]") + segment_ids.append(0) + + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + (tokens, masked_lm_positions, + masked_lm_labels) = create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) + instance = TrainingInstance( + tokens=tokens, + segment_ids=segment_ids, + is_random_next=is_random_next, + masked_lm_positions=masked_lm_positions, + masked_lm_labels=masked_lm_labels) + instances.append(instance) + current_chunk = [] + current_length = 0 + i += 1 + + return instances + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def create_masked_lm_predictions(tokens, masked_lm_prob, + max_predictions_per_seq, vocab_words, rng): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + cand_indexes.append(i) + + rng.shuffle(cand_indexes) + + output_tokens = list(tokens) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indexes: + continue + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if rng.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep 
original + if rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + + masked_lm_positions = [] + masked_lm_labels = [] + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + + return (output_tokens, masked_lm_positions, masked_lm_labels) + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): + """Truncates a pair of sequences to a maximum sequence length.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if rng.random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) + + input_files = [] + for input_pattern in FLAGS.input_file.split(","): + input_files.extend(tf.gfile.Glob(input_pattern)) + + tf.logging.info("*** Reading from input files ***") + for input_file in input_files: + tf.logging.info(" %s", input_file) + + rng = random.Random(FLAGS.random_seed) + instances = create_training_instances( + input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, + FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, + rng) + + output_files = FLAGS.output_file.split(",") + tf.logging.info("*** Writing to output files ***") + for output_file in output_files: + tf.logging.info(" %s", output_file) + + write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, + FLAGS.max_predictions_per_seq, output_files) + + +if __name__ == "__main__": + flags.mark_flag_as_required("input_file") + flags.mark_flag_as_required("output_file") + flags.mark_flag_as_required("vocab_file") + tf.app.run() diff --git a/test/external/mlperf_bert/preprocessing/external_test_preprocessing_part.py b/test/external/mlperf_bert/preprocessing/external_test_preprocessing_part.py new file mode 100644 index 0000000000..2869543fcb --- /dev/null +++ b/test/external/mlperf_bert/preprocessing/external_test_preprocessing_part.py @@ -0,0 +1,82 @@ +# USAGE: +# 1. Download raw text data with `wikipedia_download.py` + +# 2. Install python==3.7.12 and tensorflow==1.15.5 +# Run `create_pretraining_data.py` to create TFRecords on specific part (This will take some time) +# Command: python3 create_pretraining_data.py --input_file=/path/to/part-00XXX-of-00500 --vocab_file=/path/to/vocab.txt \ +# --output_file=/path/to/output.tfrecord --max_seq_length=512 --max_predictions_per_seq=76 +# +# 2.1 For eval: --input_file=/path/to/eval.txt and +# Command: python3 pick_eval_samples.py --input_tfrecord=/path/to/eval.tfrecord --output_tfrecord=/path/to/output_eval.tfrecord + +# 3. 
Run `wikipedia.py` to preprocess the data with tinygrad (Use python > 3.7) +# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train XXX (NOTE: part number needs to match part of step 2) +# This will output to /path/to/basedir/train/XXX +# +# 3.1 For eval: +# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-eval +# This will output to /path/to/basedir/eval + +# 4. Run this script to verify the correctness of the preprocessing script for specific part +# Command: python3 external_test_preprocessing_part.py --preprocessed_part_dir=/path/to/basedir/part --tf_records=/path/to/output.tfrecord +import os, argparse, pickle +from tqdm import tqdm + +# This is a workaround for protobuf issue +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +import tensorflow as tf + +tf.compat.v1.enable_eager_execution() +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + +def _parse_function(proto, max_seq_length, max_predictions_per_seq): + feature_description = { + 'input_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64), + 'input_mask': tf.io.FixedLenFeature([max_seq_length], tf.int64), + 'segment_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64), + 'masked_lm_positions': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), + 'masked_lm_ids': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), + 'masked_lm_weights': tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32), + 'next_sentence_labels': tf.io.FixedLenFeature([1], tf.int64), + } + return tf.io.parse_single_example(proto, feature_description) + +def load_dataset(file_path, max_seq_length=512, max_predictions_per_seq=76): + dataset = tf.data.TFRecordDataset(file_path) + parse_function = lambda proto: _parse_function(proto, max_seq_length, max_predictions_per_seq) # noqa: E731 + return dataset.map(parse_function) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Verify the correctness of the preprocessing script for specific part", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--preprocessed_part_dir", type=str, default=None, + help="Path to dir with preprocessed samples from `wikipedia.py`") + parser.add_argument("--tf_records", type=str, default=None, + help="Path to TFRecords file from `create_pretraining_data.py` (Reference implementation)") + parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it as 512") + parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it as 76") + parser.add_argument("--is_eval", type=bool, default=False, help="Whether to run eval or train preprocessing") + args = parser.parse_args() + + assert os.path.isdir(args.preprocessed_part_dir), f"The specified directory {args.preprocessed_part_dir} does not exist." + assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist." 
+ + preprocessed_samples = [] + for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1].split(".")[0]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501 + with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f: + samples = pickle.load(f) + preprocessed_samples.extend(samples) + + dataset = load_dataset(args.tf_records, args.max_seq_length, args.max_predictions_per_seq) + tf_record_count = sum(1 for _ in dataset) + assert tf_record_count == len(preprocessed_samples), f"Samples in reference: {tf_record_count} != Preprocessed samples: {len(preprocessed_samples)}" + print(f"Total samples in the part: {tf_record_count}") + + for i, (reference_example, preprocessed_sample) in tqdm(enumerate(zip(dataset, preprocessed_samples)), desc="Checking samples", total=len(preprocessed_samples)): # noqa: E501 + feature_keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights", "next_sentence_labels"] + for key in feature_keys: + reference_example_feature = reference_example[key].numpy() + assert (reference_example_feature == preprocessed_sample[key]).all(), \ + f"{key} are not equal at index {i}\nReference: {reference_example_feature}\nPreprocessed: {preprocessed_sample[key]}" diff --git a/test/external/mlperf_bert/preprocessing/pick_eval_samples.py b/test/external/mlperf_bert/preprocessing/pick_eval_samples.py new file mode 100644 index 0000000000..d9f0b64405 --- /dev/null +++ b/test/external/mlperf_bert/preprocessing/pick_eval_samples.py @@ -0,0 +1,127 @@ +# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/pick_eval_samples.py +# NOTE: This is a direct copy of the original script +"""Script for picking certain number of sampels. +""" + +import argparse +import time +import logging +import collections +import tensorflow as tf + +parser = argparse.ArgumentParser( + description="Eval sample picker for BERT.") +parser.add_argument( + '--input_tfrecord', + type=str, + default='', + help='Input tfrecord path') +parser.add_argument( + '--output_tfrecord', + type=str, + default='', + help='Output tfrecord path') +parser.add_argument( + '--num_examples_to_pick', + type=int, + default=10000, + help='Number of examples to pick') +parser.add_argument( + '--max_seq_length', + type=int, + default=512, + help='The maximum number of tokens within a sequence.') +parser.add_argument( + '--max_predictions_per_seq', + type=int, + default=76, + help='The maximum number of predictions within a sequence.') +args = parser.parse_args() + +max_seq_length = args.max_seq_length +max_predictions_per_seq = args.max_predictions_per_seq +logging.basicConfig(level=logging.INFO) + +def decode_record(record): + """Decodes a record to a TensorFlow example.""" + name_to_features = { + "input_ids": + tf.FixedLenFeature([max_seq_length], tf.int64), + "input_mask": + tf.FixedLenFeature([max_seq_length], tf.int64), + "segment_ids": + tf.FixedLenFeature([max_seq_length], tf.int64), + "masked_lm_positions": + tf.FixedLenFeature([max_predictions_per_seq], tf.int64), + "masked_lm_ids": + tf.FixedLenFeature([max_predictions_per_seq], tf.int64), + "masked_lm_weights": + tf.FixedLenFeature([max_predictions_per_seq], tf.float32), + "next_sentence_labels": + tf.FixedLenFeature([1], tf.int64), + } + + example = tf.parse_single_example(record, name_to_features) + + # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 
+ # So cast all int64 to int32. + for name in list(example.keys()): + t = example[name] + if t.dtype == tf.int64: + t = tf.to_int32(t) + example[name] = t + + return example + + +def create_int_feature(values): + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) + return feature + + +def create_float_feature(values): + feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) + return feature + + +if __name__ == '__main__': + tic = time.time() + tf.enable_eager_execution() + + d = tf.data.TFRecordDataset(args.input_tfrecord) + num_examples = 0 + records = [] + for record in d: + records.append(record) + num_examples += 1 + + writer = tf.python_io.TFRecordWriter(args.output_tfrecord) + i = 0 + pick_ratio = num_examples / args.num_examples_to_pick + num_examples_picked = 0 + for i in range(args.num_examples_to_pick): + example = decode_record(records[int(i * pick_ratio)]) + features = collections.OrderedDict() + features["input_ids"] = create_int_feature( + example["input_ids"].numpy()) + features["input_mask"] = create_int_feature( + example["input_mask"].numpy()) + features["segment_ids"] = create_int_feature( + example["segment_ids"].numpy()) + features["masked_lm_positions"] = create_int_feature( + example["masked_lm_positions"].numpy()) + features["masked_lm_ids"] = create_int_feature( + example["masked_lm_ids"].numpy()) + features["masked_lm_weights"] = create_float_feature( + example["masked_lm_weights"].numpy()) + features["next_sentence_labels"] = create_int_feature( + example["next_sentence_labels"].numpy()) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + writer.write(tf_example.SerializeToString()) + num_examples_picked += 1 + + writer.close() + toc = time.time() + logging.info("Picked %d examples out of %d samples in %.2f sec", + num_examples_picked, num_examples, toc - tic) diff --git a/test/external/mlperf_bert/preprocessing/tokenization.py b/test/external/mlperf_bert/preprocessing/tokenization.py new file mode 100644 index 0000000000..2d7fc8f159 --- /dev/null +++ b/test/external/mlperf_bert/preprocessing/tokenization.py @@ -0,0 +1,415 @@ +# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/tokenization.py +# NOTE: This is a direct copy of the original script +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata + +from absl import flags +import six +import tensorflow.compat.v1 as tf + +FLAGS = flags.FLAGS + +flags.DEFINE_bool( + "preserve_unused_tokens", False, + "If True, Wordpiece tokenization will not be applied to words in the vocab." +) + +_UNUSED_TOKEN_RE = re.compile("^\\[unused\\d+\\]$") + + +def preserve_token(token, vocab): + """Returns True if the token should forgo tokenization and be preserved.""" + if not FLAGS.preserve_unused_tokens: + return False + if token not in vocab: + return False + return bool(_UNUSED_TOKEN_RE.search(token)) + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. 
The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): # noqa: F821 + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. 
+ if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): # noqa: F821 + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with tf.gfile.GFile(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + if token not in vocab: + vocab[token] = len(vocab) + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, vocab=self.vocab) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + if preserve_token(token, self.vocab): + split_tokens.append(token) + continue + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True, vocab=tuple()): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + vocab: A container of tokens to not mutate during tokenization. + """ + self.do_lower_case = do_lower_case + self.vocab = vocab + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if preserve_token(token, self.vocab): + split_tokens.append(token) + continue + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. 
+ """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False