Patch summary (free text ahead of the first "diff --git" header).

NOTE(review): this patch arrived with its newlines collapsed; the line breaks
and the leading indentation inside the hunks below were reconstructed from the
hunk headers and the visible nesting of the code. Confirm the whitespace
against the target tree before applying.

test/test_linearizer.py
  Adds a "buildup to fused indexing" section to TestLinearizer with two tests.
  * test_arange_expanded (skipped when CI is set): builds an arange as a
    ReduceOps.SUM over a BufferOps.CONST of 1 whose ShapeTracker is a masked
    two-view tracker, reduced along axis 3, then subtracts a constant 1. The
    result is stored to buffer 0 and checked through helper_linearizer_ast
    against np.arange(16384) broadcast to shape (4, 16384, 256, 1). The same
    values are also checked against Tensor.arange(...).expand(...) under
    Context(DEBUG=0, NOOPT=0).
  * test_indexing_multireduce (skipped when CI is set and Device.DEFAULT is one
    of PTX, AMD, NV): reuses the same arange expression, compares it with a
    loaded int buffer of indices via LazyOp.eq, casts the comparison to the
    dataset dtype, multiplies it with the loaded dataset and sums over axis 1.
    The expected value is dataset.numpy()[idxs.numpy()] reshaped to
    (4, 1, 256, 1).

tinygrad/ops.py
  Adds three helpers to LazyOp.
  * ne(x)   wraps BinaryOps.CMPNE on (self, x).
  * eq(x)   is written as the negation (UnaryOps.NEG via __neg__) of ne(x).
            NOTE(review): presumably NEG acts as a logical "not" on the
            comparison result here; confirm for the dtype CMPNE produces.
  * const(val, dtype, shape) is a staticmethod returning a BufferOps.CONST
            LazyOp whose ConstBuffer uses a scalar ShapeTracker reshaped to
            all-ones of len(shape) and expanded to shape.

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 61e6d718cf..4322c24ead 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -122,6 +122,47 @@ class TestLinearizer(unittest.TestCase):
     y_tiny = x.var(axis=2, correction=0)
     np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
 
+  # *** buildup to fused indexing
+  @unittest.skipIf(CI, "very slow because of recomputing")
+  def test_arange_expanded(self):
+    # Tensor.arange(16384) expanded such that output shape is (4, 16384, 256, 1)
+    # basically it's pushing the expand through this reduce:
+    tiny = Tensor.arange(16384).reshape(16384, 1).expand(4, 16384, 256).reshape(4, 16384, 256, 1)
+    real_arange = np.broadcast_to(np.arange(16384).reshape(16384, 1), (4, 16384, 256)).reshape(4, 16384, 256, 1)
+    # NOTE: this is stupidly recomputing because it's not fused, but it proves a point.
+    arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
+      View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
+    arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
+    arange_axis = (3,)
+    arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
+    output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
+    out = arange-LazyOp.const(1, dtypes.int, output_shape)
+    store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, dtypes.int, st=ShapeTracker.from_shape(output_shape)))
+    helper_linearizer_ast((store, ), [], wanna_output=[real_arange])
+    with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
+
+  @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
+  def test_indexing_multireduce(self):
+    arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
+      View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
+    # TODO: do this arange broadcast in the scheduler
+    arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
+    arange_axis = (3,)
+    arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
+    arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
+    arange = arange-LazyOp.const(1, dtypes.int, arange_out_shape)
+    # p2: the indexing
+    dataset = Tensor.rand(16384, 256).realize()
+    data1 = MemBuffer(1, dataset.dtype, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape))
+    idxs = Tensor([0,3,5,6]).realize()
+    data2 = MemBuffer(2, dtypes.int, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape))
+    reduce_input = LazyOp(BufferOps.LOAD, (), data1)*LazyOp(UnaryOps.CAST, (arange.eq(LazyOp(BufferOps.LOAD, (), data2)),), dataset.dtype)
+    out = LazyOp(ReduceOps.SUM, (reduce_input, ), (1,))
+    output_shape = tuple(1 if i in out.arg else s for i,s in enumerate(arange_out_shape))
+    store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape)))
+    real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
+    helper_linearizer_ast((store, ), [dataset, idxs], wanna_output=[real_index])
+
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
   def test_end_local(self):
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 60d6ac7625..a32caf8554 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -90,7 +90,12 @@ class LazyOp:
   def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
   def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
   def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
+  def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
+  def eq(self, x:LazyOp): return -self.ne(x)
   def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
+  @staticmethod
+  def const(val, dtype:DType, shape:Tuple[sint, ...]):
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
 
 # **************** independent FlopCounter ****************
 