diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index 34ed658e43..6b2285bb85 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -58,6 +58,18 @@ class TestOnnxModel(unittest.TestCase): print(cls, _LABELS[cls]) assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible" + def test_pad_list_value(self): + from tinygrad.nn.onnx import onnx_ops + from tinygrad import Tensor + Pad = onnx_ops['Pad'] + x = Tensor([1, 2, 3]) + out = Pad(x, pads=[0, 1], value=[-float('inf')]) + assert out.shape == (4,) + assert out.numpy()[-1] == -float('inf') + out2 = Pad(x, pads=[1, 0], constant_value=[5.0]) + assert out2.shape == (4,) + assert out2.numpy()[0] == 5.0 + @unittest.skipUnless(Device.DEFAULT == "METAL", "only run on METAL") class TestHuggingFaceOnnxModels(unittest.TestCase): @classmethod diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 4845d27ea5..dd4d373d52 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -710,7 +710,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): - value = constant_value or value + value = _resolve_const(constant_value or value) axes = axes or list(range(x.ndim)) real_pads = [0] * (x.ndim*2) for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]