don't use empty in bert fake data (#13661)

somehow jit does not count empty as input
This commit is contained in:
chenyu
2025-12-12 15:59:50 -05:00
committed by GitHub
parent 316da9f7ff
commit fcaed1e1dd

View File

@@ -223,13 +223,13 @@ def get_mlperf_bert_model():
def get_fake_data_bert(BS:int):
return {
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CPU"),
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CPU"),
"input_ids": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
"input_mask": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
"segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
"masked_lm_positions": Tensor.zeros((BS, 76), dtype=dtypes.int32, device="CPU").contiguous(),
"masked_lm_ids": Tensor.zeros((BS, 76), dtype=dtypes.int32, device="CPU").contiguous(),
"masked_lm_weights": Tensor.zeros((BS, 76), dtype=dtypes.float32, device="CPU").contiguous(),
"next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.int32, device="CPU").contiguous(),
}
def find_matches(match_quality_matrix:np.ndarray, high_threshold:float=0.5, low_threshold:float=0.4, allow_low_quality_matches:bool=False) -> np.ndarray: