mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
minor onnx dropout cleanup (#10891)
we should consider removing numpy random and test it similar to test_randomness, unless how seed works is part of spec?
This commit is contained in:
@@ -613,9 +613,13 @@ def get_onnx_ops():
|
||||
|
||||
# Reimplemented here because you need legacy RNG for passing ONNX tests.
|
||||
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
|
||||
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
|
||||
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
|
||||
return data * mask * (1/(1.0 - ratio)), mask
|
||||
if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
|
||||
if seed is not None:
|
||||
rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device)
|
||||
else:
|
||||
rand = data.rand_like(requires_grad=False)
|
||||
mask = rand >= ratio
|
||||
return data * mask / (1.0 - ratio), mask
|
||||
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
|
||||
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
|
||||
Dropout = {6:Dropout_6, 7:Dropout_7}
|
||||
|
||||
Reference in New Issue
Block a user