# Test-only patch touching test/test_pickle.py and test/test_schedule.py:
#   * test_pickle_numpy now passes dtype=dtypes.float32 when building the Tensor.
#   * run_tensor_ast passes r.shape as a second argument to .cast() before .tolist().
#   * test_permute_rewrite drops the hand-written UOp graph and its swizzle_cnt
#     check; it builds the expression from Tensors and compares the result of
#     run_tensor_ast against the same expression evaluated with NumPy.
# NOTE(review): line breaks and leading whitespace were rebuilt from a copy that
# had been collapsed onto two lines (2-space indentation assumed). Confirm the
# indentation against the target tree, or apply with whitespace tolerance.
diff --git a/test/test_pickle.py b/test/test_pickle.py
index 94c0a30303..aed5de8cd6 100644
--- a/test/test_pickle.py
+++ b/test/test_pickle.py
@@ -104,7 +104,7 @@ class TestPickle(unittest.TestCase):
     assert ref_value == vt2.tolist()
 
   def test_pickle_numpy(self):
-    t = Tensor(np.array([1,2,3,4.]))
+    t = Tensor(np.array([1,2,3,4.]), dtype=dtypes.float32)
     st = pickle.dumps(t)
     t2:Tensor = pickle.loads(st)
     np.testing.assert_equal(t.numpy(), t2.numpy())
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 99c8f1d999..d69e9dbaa6 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1778,7 +1778,7 @@ def run_tensor_ast(r:Tensor):
   sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
   si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
   run_schedule([si])
-  return output.realized.as_buffer().cast(output.dtype.fmt).tolist()
+  return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
 
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
@@ -1831,33 +1831,12 @@ class TestSwizzle(unittest.TestCase):
     self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
 
   def test_permute_rewrite(self):
-    sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
-      x1:=UOp(Ops.BUFFER, dtypes.float, arg=(1, ('METAL', 16384, dtypes.float)), src=()),
-      x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
-        x1,
-        UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
-          UOp(Ops.ADD, dtypes.float, arg=None, src=(
-            UOp(Ops.ADD, dtypes.float, arg=None, src=(
-              UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8)), src=(
-                  UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                    UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                      x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 16384, dtypes.float)), src=()),
-                        x2,)),)),
-                    UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 256, dtypes.float)), src=()),
-                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),
-              UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                  UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 16, dtypes.float)), src=()),
-                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
-            UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
-              x11,)),)),)),)),))
-    ret = swizzle_rewrite(sink)
-    self.assertEqual(swizzle_cnt(ret), 0)
+    x = Tensor.randn(4, 4, 16).realize()
+    y = Tensor.randn(4, 1, 16).realize()
+    z = Tensor.randn(4, 4, 1).realize()
+    t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
+    t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
+    np.testing.assert_allclose(run_tensor_ast(t), t_np, atol=1e-6, rtol=1e-3)
 
   @unittest.expectedFailure
   def test_fuse_conv2_relu_bw(self):