From efc7d5f1b6ecdc3ce30e6619de9bfaf4b7e25360 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 17 Nov 2025 18:09:17 -0800 Subject: [PATCH] scan op work --- test/test_outerworld.py | 28 +++++++++++++++++++++++++++- tinygrad/uop/__init__.py | 2 ++ tinygrad/uop/ops.py | 13 +++++++++++-- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/test/test_outerworld.py b/test/test_outerworld.py index 2d40cb05fb..c1170acde1 100644 --- a/test/test_outerworld.py +++ b/test/test_outerworld.py @@ -71,13 +71,29 @@ class TestOuterScan(unittest.TestCase): ref.realize() return vec, mats, ref + def test_uop_fold_matmul(self): + vec, mats, ref = self._test_scan() + + # 3 matmuls with FOLD + i = UOp.range(3, -100, AxisType.OUTER) + out = Tensor.empty(1, 10) + phi = Tensor(i.eq(0).where(vec.uop, out.uop)) + comp = phi @ mats[i] + store = out.uop.store(comp.uop).end(i) + out = Tensor(out.uop.after(store)) + out.realize() + + # TODO: testing allclose + assert Tensor.allclose(ref[2], out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}" + def test_uop_scan_matmul(self): vec, mats, ref = self._test_scan() # 3 matmuls with SCAN i = UOp.range(3, -100, AxisType.OUTER) out = Tensor.empty(3, 1, 10) - comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i] + phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) + comp = phi @ mats[i] store = out[i].uop.store(comp.uop).end(i) out = Tensor(out.uop.after(store)) out.realize() @@ -85,6 +101,16 @@ class TestOuterScan(unittest.TestCase): # TODO: testing allclose assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}" + def test_fold_matmul(self): + vec, mats, ref = self._test_scan() + + # 3 matmuls with FOLD + i = UOp.range(3, -100, AxisType.OUTER) + phi = vec._apply_uop(UOp.phi) + comp = phi @ mats[i] + scan = comp._apply_uop(UOp.fold, phi, extra_args=(i,)) + scan.realize() + class TestOuterworld(unittest.TestCase): def test_range_plus_1(self): t = 
Tensor.arange(100).reshape(10,10).realize() diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 322cd2323f..5066a3fa30 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -92,6 +92,8 @@ class Ops(FastEnum): # reduce REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() + PHI = auto(); SCAN = auto(); FOLD = auto() + # errors/placeholders REWRITE_ERROR = auto(); SENTINEL = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8c86029e81..c0f43c40c2 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -25,7 +25,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2} -range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1} +range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.FOLD: 2} # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) @@ -219,9 +219,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) # passthrough ops - case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.PHI | Ops.FOLD: return self.src[0]._shape + # scan adds dims to the front + case Ops.SCAN: + if self.src[0]._shape is None: return None + return tuple(x.vmax+1 for x in self.src[2:]) + self.src[0]._shape + # ops with custom handling case Ops.KERNEL: return self.arg.ast._shape @@ -443,6 +448,10 @@ class UOp(OpMixin, 
metaclass=UOpMetaClass): return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) + def fold(self, *src:UOp, **kwargs): return UOp(Ops.FOLD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) + def scan(self, *src:UOp, **kwargs): return UOp(Ops.SCAN, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) + def phi(self, *src:UOp, **kwargs): return UOp(Ops.PHI, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) + def is_contiguous(self): # TODO: this is is_realized if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()