mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
don't use empty in bert fake data (#13661)
somehow jit does not count empty as input
This commit is contained in:
@@ -223,13 +223,13 @@ def get_mlperf_bert_model():
|
||||
|
||||
def get_fake_data_bert(BS:int):
|
||||
return {
|
||||
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CPU"),
|
||||
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CPU"),
|
||||
"input_ids": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
"input_mask": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
"segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
"masked_lm_positions": Tensor.zeros((BS, 76), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
"masked_lm_ids": Tensor.zeros((BS, 76), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
"masked_lm_weights": Tensor.zeros((BS, 76), dtype=dtypes.float32, device="CPU").contiguous(),
|
||||
"next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.int32, device="CPU").contiguous(),
|
||||
}
|
||||
|
||||
def find_matches(match_quality_matrix:np.ndarray, high_threshold:float=0.5, low_threshold:float=0.4, allow_low_quality_matches:bool=False) -> np.ndarray:
|
||||
|
||||
Reference in New Issue
Block a user