better test for reduce swizzle + don't use double dtype [pr] (#8586)

* better test_permute_rewrite

* use float32
This commit is contained in:
qazal
2025-01-13 05:02:21 -05:00
committed by GitHub
parent df59b072db
commit 7562cc0399
2 changed files with 8 additions and 29 deletions

View File

@@ -1778,7 +1778,7 @@ def run_tensor_ast(r:Tensor):
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
run_schedule([si])
return output.realized.as_buffer().cast(output.dtype.fmt).tolist()
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
@@ -1831,33 +1831,12 @@ class TestSwizzle(unittest.TestCase):
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
x1:=UOp(Ops.BUFFER, dtypes.float, arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
x = Tensor.randn(4, 4, 16).realize()
y = Tensor.randn(4, 1, 16).realize()
z = Tensor.randn(4, 4, 1).realize()
t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
np.testing.assert_allclose(run_tensor_ast(t), t_np, atol=1e-6, rtol=1e-3)
@unittest.expectedFailure
def test_fuse_conv2_relu_bw(self):