From 975c318dbc30fa6a8253dbddb950ef535de5c3ad Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 19 Feb 2025 08:17:27 -0500 Subject: [PATCH] bert use int32 for input ids (#9173) The original data was int32 for these fields. Storing them as float might have caused precision issues. --- examples/mlperf/dataloader.py | 12 ++++++------ examples/mlperf/helpers.py | 12 ++++++------ extra/models/bert.py | 3 ++- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index e49edcf771..2bde3fe5f1 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -170,13 +170,13 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir def process_batch_bert(data: List[dict]) -> dict[str, Tensor]: return { - "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), - "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float, device="CLANG"), - "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), - "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), - "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), + "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"), + "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"), + "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"), + "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"), + "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), 
dtype=dtypes.int32, device="CLANG"), "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), - "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"), + "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"), } def load_file(file: str): diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index 19e82ef956..f3aa16be12 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -224,11 +224,11 @@ def get_mlperf_bert_model(): def get_fake_data_bert(BS:int): return { - "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32, device="CLANG"), - "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float, device="CLANG"), - "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32, device="CLANG"), - "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"), - "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"), + "input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"), + "input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"), + "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"), + "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"), + "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"), "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"), - "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32, device="CLANG"), + "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CLANG"), } diff --git a/extra/models/bert.py b/extra/models/bert.py index 01a136c8e4..a9d2b07f70 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -71,7 +71,8 @@ class BertForPretraining: 
seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) - return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss + # TODO: is it okay that next_sentence_loss is half here? + return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float() def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info