mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix buffer init and skip test_swizzle_failure_permute [pr] (#8732)
* fix buffer init and skip test_swizzle_failure_permute [pr] * replace preload with just load * add
This commit is contained in:
@@ -1940,19 +1940,19 @@ class TestSwizzle(unittest.TestCase):
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skip("this swizzle can't be decided after the ADD")
|
||||
def test_swizzle_failure_permute(self):
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
@@ -1971,13 +1971,13 @@ class TestSwizzle(unittest.TestCase):
|
||||
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
|
||||
x15,)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
x10,)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
||||
Reference in New Issue
Block a user