minor onnx dropout cleanup (#10891)

we should consider removing numpy random and test it similar to test_randomness, unless how seed works is part of spec?
This commit is contained in:
chenyu
2025-06-20 10:18:34 -04:00
committed by GitHub
parent e94ac6e20c
commit 3f29c7edda

View File

@@ -613,9 +613,13 @@ def get_onnx_ops():
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
if seed is not None:
rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device)
else:
rand = data.rand_like(requires_grad=False)
mask = rand >= ratio
return data * mask / (1.0 - ratio), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}