onnx: honor an explicit constant_value of 0 in Pad; simplify SAME_* output size

Pad: `constant_value or value` treats constant_value=0 as falsy and falls
back to `value`. Select on `constant_value is None` instead, so an explicit
0 is used. The new test passes constant_value=0 together with value=3 and
expects the result to equal padding with 0.

Auto-pad helper: the output-size expression no longer branches on auto_pad
and no longer reads k. NOTE(review): presumably only SAME_UPPER/SAME_LOWER
can reach this line, since VALID and NOTSET return above it -- confirm that
no other auto_pad value is possible before relying on this being a pure
simplification.

NOTE(review): leading whitespace inside the hunks below was reconstructed
assuming 2-space indentation -- verify against the target files.

diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py
index 76a5042385..171aa90022 100644
--- a/test/external/external_test_onnx_ops.py
+++ b/test/external/external_test_onnx_ops.py
@@ -65,6 +65,15 @@ class TestMainOnnxOps(TestOnnxOps):
     outputs = ["y"]
     self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4)
 
+  def test_pad_constant_value_zero(self):
+    from tinygrad.nn.onnx import onnx_ops
+    Pad = onnx_ops["Pad"]
+    x = Tensor.arange(4).reshape(1, 1, 2, 2).float()
+    pads = [0, 0, 1, 1, 0, 0, 1, 1]
+    out = Pad(x, pads, constant_value=0, value=3)
+    expected = x.pad((pads[3], pads[7], pads[2], pads[6], pads[1], pads[5], pads[0], pads[4]), value=0)
+    self.assertEqual(out.tolist(), expected.tolist())
+
   def test_gather(self):
     # test const negative indices
     inputs = {
diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py
index 43eb9a9aa7..e59a1a270d 100644
--- a/tinygrad/nn/onnx.py
+++ b/tinygrad/nn/onnx.py
@@ -507,7 +507,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     if auto_pad == "VALID": return [0]*(len(k_)*2)
     i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
     if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
-    o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
+    o_ = [((i - 1) // s + 1) for i,s in zip(i_, s_)]
     return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
 
   def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
@@ -709,7 +709,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
 
   def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
           mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
-    value = _resolve_const(constant_value or value)
+    value = _resolve_const(value if constant_value is None else constant_value)
     axes = axes or list(range(x.ndim))
     real_pads = [0] * (x.ndim*2)
     for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]