mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix(onnx): unwrap list/tuple value in Pad op (#13500)
* fix(onnx): unwrap list/tuple value in Pad op * add regression test for Pad list value * remove trailing whitespace * use _resolve_const for Pad constant_value
This commit is contained in:
@@ -58,6 +58,18 @@ class TestOnnxModel(unittest.TestCase):
|
||||
print(cls, _LABELS[cls])
|
||||
assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible"
|
||||
|
||||
def test_pad_list_value(self):
|
||||
from tinygrad.nn.onnx import onnx_ops
|
||||
from tinygrad import Tensor
|
||||
Pad = onnx_ops['Pad']
|
||||
x = Tensor([1, 2, 3])
|
||||
out = Pad(x, pads=[0, 1], value=[-float('inf')])
|
||||
assert out.shape == (4,)
|
||||
assert out.numpy()[-1] == -float('inf')
|
||||
out2 = Pad(x, pads=[1, 0], constant_value=[5.0])
|
||||
assert out2.shape == (4,)
|
||||
assert out2.numpy()[0] == 5.0
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only run on METAL")
|
||||
class TestHuggingFaceOnnxModels(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user