mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix onnx Pad constant_value=None (#14271)
also removed a dead branch in _resolve_pool_pads
This commit is contained in:
9
test/external/external_test_onnx_ops.py
vendored
9
test/external/external_test_onnx_ops.py
vendored
@@ -65,6 +65,15 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
outputs = ["y"]
|
||||
self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4)
|
||||
|
||||
def test_pad_constant_value_zero(self):
|
||||
from tinygrad.nn.onnx import onnx_ops
|
||||
Pad = onnx_ops["Pad"]
|
||||
x = Tensor.arange(4).reshape(1, 1, 2, 2).float()
|
||||
pads = [0, 0, 1, 1, 0, 0, 1, 1]
|
||||
out = Pad(x, pads, constant_value=0, value=3)
|
||||
expected = x.pad((pads[3], pads[7], pads[2], pads[6], pads[1], pads[5], pads[0], pads[4]), value=0)
|
||||
self.assertEqual(out.tolist(), expected.tolist())
|
||||
|
||||
def test_gather(self):
|
||||
# test const negative indices
|
||||
inputs = {
|
||||
|
||||
@@ -507,7 +507,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if auto_pad == "VALID": return [0]*(len(k_)*2)
|
||||
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
|
||||
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
|
||||
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
|
||||
o_ = [((i - 1) // s + 1) for i,s in zip(i_, s_)]
|
||||
return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
|
||||
|
||||
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
|
||||
@@ -709,7 +709,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
|
||||
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
|
||||
mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
|
||||
value = _resolve_const(constant_value or value)
|
||||
value = _resolve_const(value if constant_value is None else constant_value)
|
||||
axes = axes or list(range(x.ndim))
|
||||
real_pads = [0] * (x.ndim*2)
|
||||
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
|
||||
|
||||
Reference in New Issue
Block a user