Wikipedia preprocessing script (#4229)

* Preprocessing script

* short seq prob

* comments + env vars

* Add preprocessing reference. Add test

* lint fix + add eval test support

* whitespaces

* point to commit

* comment

* rename

* better comments
Committed by Elias Wahl on 2024-04-23 16:28:01 +02:00 (via GitHub)
parent 759b4f41c3
commit 69341144ba
5 changed files with 1470 additions and 0 deletions

extra/datasets/wikipedia.py (new file)

@@ -0,0 +1,411 @@
# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
# This is a modified version of the original script:
# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
# ENV VARS:
# MAX_SEQ_LENGTH - Maximum sequence length
# MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
# RANDOM_SEED - Random seed
# DUPE_FACTOR - Number of times to duplicate the input data with different masks
# MASKED_LM_PROB - Probability of masking a token
# SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
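# BASEDIR - Directory containing the extracted text (results4/) and vocab.txt (defaults to ./wiki next to this script)
# NUM_WORKERS - Number of worker processes when preprocessing all parts (defaults to os.cpu_count())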
import os, sys, pickle, random, unicodedata
from pathlib import Path
import numpy as np
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from tinygrad.helpers import diskcache, getenv
BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")
################### Tokenization #####################
def _is_whitespace(char:str) -> bool:
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
return unicodedata.category(char) == "Zs"
def _is_control(char:str) -> bool:
if char == "\t" or char == "\n" or char == "\r":
return False
return unicodedata.category(char).startswith("C")
def _is_punctuation(char:str) -> bool:
# range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
# range(58, 65) -> : ; < = > ? @
# range(91, 97) -> [ \ ] ^ _ `
# range(123, 127) -> { | } ~
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
return True
return unicodedata.category(char).startswith("P")
def _is_chinese_char(cp:int) -> bool:
if ((cp >= 0x4E00 and cp <= 0x9FFF) or
(cp >= 0x3400 and cp <= 0x4DBF) or
(cp >= 0x20000 and cp <= 0x2A6DF) or
(cp >= 0x2A700 and cp <= 0x2B73F) or
(cp >= 0x2B740 and cp <= 0x2B81F) or
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or
(cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _run_split_on_punc(text:str) -> list[str]:
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
return [text]
start_new_word = True
output = []
for i in range(len(text)):
if _is_punctuation(char := text[i]):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
return ["".join(x) for x in output]
def _run_strip_accents(text:str) -> str:
output = []
for char in unicodedata.normalize("NFD", text):
if unicodedata.category(char) != "Mn":
output.append(char)
return "".join(output)
def _clean_text(text:str) -> str:
output = []
for char in text:
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
output.append(" " if _is_whitespace(char) else char)
return "".join(output)
def _tokenize_chinese_chars(text:str) -> str:
output = []
for char in text:
cp = ord(char)
if _is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def whitespace_tokenize(text):
if not (text := text.strip()): return []
return text.split()
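# Greedy longest-match-first WordPiece tokenization, e.g. "unaffable" -> ["un", "##aff", "##able"] with the standard BERT vocab (see the copied tokenization.py docstring)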
def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
output_tokens = []
for token in text.strip().split():
chars = list(token)
if len(chars) > 200:
output_tokens.append("[UNK]")
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0: substr = "##" + substr
if substr in vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad: output_tokens.append("[UNK]")
else: output_tokens.extend(sub_tokens)
return output_tokens
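# Tokenizer mirrors the reference FullTokenizer: a basic pass (clean text, lower-case, strip accents, split punctuation) followed by WordPiece, with lower-casing always applied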
class Tokenizer:
def __init__(self, vocab_file):
self.vocab = {}
with open(vocab_file) as f:
for line in f:
line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line
if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
def tokenize(self, text:str) -> list[str]:
# BasicTokenizer
split_tokens = []
for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))):
split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
split_tokens = " ".join(split_tokens).strip().split()
# WordpieceTokenizer
tokens = []
for token in split_tokens:
tokens.extend(_wordpiece_tokenize(token, self.vocab))
return tokens
def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids]
##################### Feature transformation #####################
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
cand_indices = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indices.append(i)
rng.shuffle(cand_indices)
output_tokens = list(tokens)
num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))
masked_lms = []
covered_indices = set()
for index in cand_indices:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indices:
continue
covered_indices.add(index)
masked_token = None
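# Same masking scheme as the reference script: 80% of the time replace with [MASK], 10% keep the original token, 10% pick a random vocab token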
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
if rng.random() < 0.5:
masked_token = tokens[index]
else:
masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]
output_tokens[index] = masked_token
masked_lms.append((index, tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x[0])
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p[0])
masked_lm_labels.append(p[1])
return output_tokens, masked_lm_positions, masked_lm_labels
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]
target_seq_length = max_num_tokens
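# As in the reference script, occasionally target a shorter sequence (SHORT_SEQ_PROB) to reduce the length mismatch between pre-training and fine-tuning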
if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
target_seq_length = rng.randint(2, max_num_tokens)
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(doc):
segment = doc[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(doc) - 1 or current_length >= target_seq_length:
if current_chunk:
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
for _ in range(10):
random_document_index = rng.randint(0, len(documents) - 1)
if random_document_index != di:
break
random_document = documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
instances.append({
"tokens": tokens,
"segment_ids": segment_ids,
"masked_lm_positions": masked_lm_positions,
"masked_lm_labels": masked_lm_labels,
"is_random_next": is_random_next
})
current_chunk = []
current_length = 0
i += 1
return instances
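# get_documents expects the same input format as the reference script:
# (1) one sentence per line (sentence boundaries are needed for next sentence prediction)
# (2) blank lines delimit documents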
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
documents = [[]]
with open(BASEDIR / fn) as f:
for line in f.readlines():
if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break
if not (line := line.strip()): documents.append([])
if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
documents = [x for x in documents if x]
rng.shuffle(documents)
return documents
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
instances = []
for _ in range(getenv('DUPE_FACTOR', 10)):
for di, doc in enumerate(documents):
instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents))
rng.shuffle(instances)
return instances
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
input_mask = [1] * len(input_ids)
segment_ids = instance["segment_ids"]
max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = instance["masked_lm_positions"]
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76):
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance["is_random_next"] else 0
return {
"input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
"input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
"segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
"masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
"masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
"masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
"next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
}
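# Every feature above carries a leading batch dimension of 1: token-level arrays are (1, MAX_SEQ_LENGTH), masked-LM arrays are (1, MAX_PREDICTIONS_PER_SEQ), next_sentence_labels is (1, 1)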
def process_part(part:int):
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0): # generator: converts raw text to batches (lists of feature dicts) of masked NSP samples
rng = random.Random(getenv('RANDOM_SEED', 12345))
if val:
tqdm.write("Getting samples from dataset")
documents = get_documents(rng, tokenizer, "results4/eval.txt")
instances = get_instances(rng, tokenizer, documents)
tqdm.write(f"There are {len(instances)} samples in the dataset")
tqdm.write(f"Picking 10000 samples")
pick_ratio = len(instances) / 10000
picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
for batch in range(10):
yield picks[batch*1000:(batch+1)*1000]
else:
documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
instances = get_instances(rng, tokenizer, documents)
while len(instances) > 0:
batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
batch = instances[:batch_size]
del instances[:batch_size]
yield [instance_to_features(instance, tokenizer) for instance in batch]
##################### Load files #####################
def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl")))
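# @diskcache (tinygrad.helpers) memoizes the result on disk, so the glob over the per-part train directories only runs once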
@diskcache
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))
if __name__ == "__main__":
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval | pre-train <part>|all"
if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
os.makedirs(BASEDIR / "eval", exist_ok=True)
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
elif sys.argv[1] == "pre-train":
os.makedirs(BASEDIR / "train", exist_ok=True)
if sys.argv[2] == "all": # Use all 500 parts for training generation
process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', os.cpu_count()), chunksize=1)
else: # Use a specific part for training generation
part = int(sys.argv[2])
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
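# Example invocations (paths are illustrative; BASEDIR must contain results4/ and vocab.txt):
#   BASEDIR=/path/to/wiki python3 wikipedia.py pre-eval        # writes eval/0.pkl ... eval/9.pkl (10 batches of 1000 samples)
#   BASEDIR=/path/to/wiki python3 wikipedia.py pre-train 27    # writes train/27/27_<i>.pkl (up to 1000 samples per file)
#   BASEDIR=/path/to/wiki python3 wikipedia.py pre-train all   # all 500 parts in parallel via process_map
# Reading a batch back (a sketch of assumed downstream usage, not part of this script):
#   import pickle
#   with open("/path/to/wiki/eval/0.pkl", "rb") as f: batch = pickle.load(f)  # list of feature dicts
#   batch[0]["input_ids"].shape  # (1, MAX_SEQ_LENGTH)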

create_pretraining_data.py (new file, reference copy)

@@ -0,0 +1,435 @@
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
# NOTE: This is a direct copy of the original script
# NOTE: With python 3.7.12, pip install tensorflow==1.15.5
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # NOTE: This is a workaround for protobuf issue
import collections
import random
import tokenization
import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
"""Create TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(tf.python_io.TFRecordWriter(output_file))
writer_index = 0
total_written = 0
for (inst_index, instance) in enumerate(instances):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writers[writer_index].write(tf_example.SerializeToString())
writer_index = (writer_index + 1) % len(writers)
total_written += 1
if inst_index < 20:
tf.logging.info("*** Example ***")
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
values = []
if feature.int64_list.value:
values = feature.int64_list.value
elif feature.float_list.value:
values = feature.float_list.value
tf.logging.info(
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
for writer in writers:
writer.close()
tf.logging.info("Wrote %d total instances", total_written)
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
with tf.gfile.GFile(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indexes.append(i)
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Reading from input files ***")
for input_file in input_files:
tf.logging.info(" %s", input_file)
rng = random.Random(FLAGS.random_seed)
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
output_files = FLAGS.output_file.split(",")
tf.logging.info("*** Writing to output files ***")
for output_file in output_files:
tf.logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("output_file")
flags.mark_flag_as_required("vocab_file")
tf.app.run()

external_test_preprocessing_part.py (new file)

@@ -0,0 +1,82 @@
# USAGE:
# 1. Download raw text data with `wikipedia_download.py`
# 2. Install python==3.7.12 and tensorflow==1.15.5
#    Run `create_pretraining_data.py` to create TFRecords for a specific part (this will take some time)
#    Command: python3 create_pretraining_data.py --input_file=/path/to/part-00XXX-of-00500 --vocab_file=/path/to/vocab.txt \
#             --output_file=/path/to/output.tfrecord --max_seq_length=512 --max_predictions_per_seq=76
#
# 2.1 For eval: use --input_file=/path/to/eval.txt, then
#     Command: python3 pick_eval_samples.py --input_tfrecord=/path/to/eval.tfrecord --output_tfrecord=/path/to/output_eval.tfrecord
# 3. Run `wikipedia.py` to preprocess the data with tinygrad (use python > 3.7)
#    Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train XXX (NOTE: the part number must match the part used in step 2)
#    This will output to /path/to/basedir/train/XXX
#
# 3.1 For eval:
#     Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-eval
#     This will output to /path/to/basedir/eval
# 4. Run this script to verify the correctness of the preprocessing for a specific part
#    Command: python3 external_test_preprocessing_part.py --preprocessed_part_dir=/path/to/basedir/part --tf_records=/path/to/output.tfrecord
import os, argparse, pickle
from tqdm import tqdm
# This is a workaround for protobuf issue
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def _parse_function(proto, max_seq_length, max_predictions_per_seq):
feature_description = {
'input_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([max_seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([max_seq_length], tf.int64),
'masked_lm_positions': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_ids': tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights': tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
'next_sentence_labels': tf.io.FixedLenFeature([1], tf.int64),
}
return tf.io.parse_single_example(proto, feature_description)
def load_dataset(file_path, max_seq_length=512, max_predictions_per_seq=76):
dataset = tf.data.TFRecordDataset(file_path)
parse_function = lambda proto: _parse_function(proto, max_seq_length, max_predictions_per_seq) # noqa: E731
return dataset.map(parse_function)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify the correctness of the preprocessing script for specific part",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--preprocessed_part_dir", type=str, default=None,
help="Path to dir with preprocessed samples from `wikipedia.py`")
parser.add_argument("--tf_records", type=str, default=None,
help="Path to TFRecords file from `create_pretraining_data.py` (Reference implementation)")
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it as 512")
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it as 76")
parser.add_argument("--is_eval", type=bool, default=False, help="Whether to run eval or train preprocessing")
args = parser.parse_args()
assert os.path.isdir(args.preprocessed_part_dir), f"The specified directory {args.preprocessed_part_dir} does not exist."
assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."
preprocessed_samples = []
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1].split(".")[0]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
samples = pickle.load(f)
preprocessed_samples.extend(samples)
dataset = load_dataset(args.tf_records, args.max_seq_length, args.max_predictions_per_seq)
tf_record_count = sum(1 for _ in dataset)
assert tf_record_count == len(preprocessed_samples), f"Samples in reference: {tf_record_count} != Preprocessed samples: {len(preprocessed_samples)}"
print(f"Total samples in the part: {tf_record_count}")
for i, (reference_example, preprocessed_sample) in tqdm(enumerate(zip(dataset, preprocessed_samples)), desc="Checking samples", total=len(preprocessed_samples)): # noqa: E501
feature_keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights", "next_sentence_labels"]
for key in feature_keys:
reference_example_feature = reference_example[key].numpy()
assert (reference_example_feature == preprocessed_sample[key]).all(), \
f"{key} are not equal at index {i}\nReference: {reference_example_feature}\nPreprocessed: {preprocessed_sample[key]}"

pick_eval_samples.py (new file, reference copy)

@@ -0,0 +1,127 @@
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/pick_eval_samples.py
# NOTE: This is a direct copy of the original script
"""Script for picking certain number of sampels.
"""
import argparse
import time
import logging
import collections
import tensorflow as tf
parser = argparse.ArgumentParser(
description="Eval sample picker for BERT.")
parser.add_argument(
'--input_tfrecord',
type=str,
default='',
help='Input tfrecord path')
parser.add_argument(
'--output_tfrecord',
type=str,
default='',
help='Output tfrecord path')
parser.add_argument(
'--num_examples_to_pick',
type=int,
default=10000,
help='Number of examples to pick')
parser.add_argument(
'--max_seq_length',
type=int,
default=512,
help='The maximum number of tokens within a sequence.')
parser.add_argument(
'--max_predictions_per_seq',
type=int,
default=76,
help='The maximum number of predictions within a sequence.')
args = parser.parse_args()
max_seq_length = args.max_seq_length
max_predictions_per_seq = args.max_predictions_per_seq
logging.basicConfig(level=logging.INFO)
def decode_record(record):
"""Decodes a record to a TensorFlow example."""
name_to_features = {
"input_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"input_mask":
tf.FixedLenFeature([max_seq_length], tf.int64),
"segment_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"masked_lm_positions":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_ids":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_weights":
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
"next_sentence_labels":
tf.FixedLenFeature([1], tf.int64),
}
example = tf.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature
if __name__ == '__main__':
tic = time.time()
tf.enable_eager_execution()
d = tf.data.TFRecordDataset(args.input_tfrecord)
num_examples = 0
records = []
for record in d:
records.append(record)
num_examples += 1
writer = tf.python_io.TFRecordWriter(args.output_tfrecord)
i = 0
pick_ratio = num_examples / args.num_examples_to_pick
num_examples_picked = 0
for i in range(args.num_examples_to_pick):
example = decode_record(records[int(i * pick_ratio)])
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(
example["input_ids"].numpy())
features["input_mask"] = create_int_feature(
example["input_mask"].numpy())
features["segment_ids"] = create_int_feature(
example["segment_ids"].numpy())
features["masked_lm_positions"] = create_int_feature(
example["masked_lm_positions"].numpy())
features["masked_lm_ids"] = create_int_feature(
example["masked_lm_ids"].numpy())
features["masked_lm_weights"] = create_float_feature(
example["masked_lm_weights"].numpy())
features["next_sentence_labels"] = create_int_feature(
example["next_sentence_labels"].numpy())
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
num_examples_picked += 1
writer.close()
toc = time.time()
logging.info("Picked %d examples out of %d samples in %.2f sec",
num_examples_picked, num_examples, toc - tic)

tokenization.py (new file, reference copy)

@@ -0,0 +1,415 @@
# https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/cleanup_scripts/tokenization.py
# NOTE: This is a direct copy of the original script
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
from absl import flags
import six
import tensorflow.compat.v1 as tf
FLAGS = flags.FLAGS
flags.DEFINE_bool(
"preserve_unused_tokens", False,
"If True, Wordpiece tokenization will not be applied to words in the vocab."
)
_UNUSED_TOKEN_RE = re.compile("^\\[unused\\d+\\]$")
def preserve_token(token, vocab):
"""Returns True if the token should forgo tokenization and be preserved."""
if not FLAGS.preserve_unused_tokens:
return False
if token not in vocab:
return False
return bool(_UNUSED_TOKEN_RE.search(token))
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode): # noqa: F821
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode): # noqa: F821
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with tf.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
if token not in vocab:
vocab[token] = len(vocab)
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, vocab=self.vocab)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
if preserve_token(token, self.vocab):
split_tokens.append(token)
continue
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, vocab=tuple()):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
vocab: A container of tokens to not mutate during tokenization.
"""
self.do_lower_case = do_lower_case
self.vocab = vocab
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if preserve_token(token, self.vocab):
split_tokens.append(token)
continue
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False