mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
multi reduce linearizer tests start (#4529)
* test_end_local * test_early_end_local * todos * mean+std * skip no locals
This commit is contained in:
@@ -97,6 +97,45 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len(mutable_bufs) == len(stores) == 2
|
||||
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
|
||||
|
||||
def test_end_local(self):
|
||||
if not (opts:=Device[Device.DEFAULT].renderer).has_local or not opts.has_shared: self.skipTest("device does not support locals")
|
||||
load = MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker.from_shape((32,)))
|
||||
store = MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker.from_shape((1,)))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load),), arg=(0,)),), arg=store),
|
||||
|
||||
load_t = Tensor.full(load.st.shape, 1).contiguous().realize()
|
||||
k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1]
|
||||
self.assertEqual(k.uops.uops[-1].uop, UOps.ENDIF)
|
||||
self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops.uops[-1]))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_early_end_local(self):
|
||||
shape, output_shape = (32,), (1,)
|
||||
load0 = MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape(shape))
|
||||
load1 = MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape(shape))
|
||||
store = MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(
|
||||
LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load0),)),LazyOp(op=BufferOps.LOAD, arg=load1), )),)),), arg=store),
|
||||
|
||||
load0_t = Tensor.randn(shape).realize()
|
||||
load1_t = Tensor.randn(shape).realize()
|
||||
k = helper_linearizer_ast(ast, [load0_t, load1_t], wanna_output=[(load0_t.numpy().sum() + load1_t.numpy()).sum()])[1]
|
||||
self.assertEqual(len(endifs:=[x for x in k.uops.uops if x.uop is UOps.ENDIF]), len(ifs:=[x for x in k.uops.uops if x.uop is UOps.IF]))
|
||||
self.assertEqual(len(barriers:=[x for x in k.uops.uops if x.uop is UOps.BARRIER]), 3)
|
||||
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])-1].uop, UOps.STORE)
|
||||
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])+1], barriers[1])
|
||||
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])+2].uop, UOps.LOAD)
|
||||
self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0]))
|
||||
self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0]))
|
||||
self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1]))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mean_std_multireduce(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
|
||||
|
||||
x = Tensor.randn(15, 25, 35).realize()
|
||||
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std()])
|
||||
|
||||
def test_load_dedup(self):
|
||||
# for different leaves in the AST, the same loads may occur.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user