scan op work

This commit is contained in:
George Hotz
2025-11-17 18:09:17 -08:00
parent e4fead8a86
commit efc7d5f1b6
3 changed files with 40 additions and 3 deletions

View File

@@ -71,13 +71,29 @@ class TestOuterScan(unittest.TestCase):
ref.realize()
return vec, mats, ref
def test_uop_fold_matmul(self):
vec, mats, ref = self._test_scan()
# 3 matmuls with FOLD
i = UOp.range(3, -100, AxisType.OUTER)
out = Tensor.empty(1, 10)
phi = Tensor(i.eq(0).where(vec.uop, out.uop))
comp = phi @ mats[i]
store = out.uop.store(comp.uop).end(i)
out = Tensor(out.uop.after(store))
out.realize()
# TODO: testing allclose
assert Tensor.allclose(ref[2], out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_uop_scan_matmul(self):
vec, mats, ref = self._test_scan()
# 3 matmuls with SCAN
i = UOp.range(3, -100, AxisType.OUTER)
out = Tensor.empty(3, 1, 10)
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
comp = phi @ mats[i]
store = out[i].uop.store(comp.uop).end(i)
out = Tensor(out.uop.after(store))
out.realize()
@@ -85,6 +101,16 @@ class TestOuterScan(unittest.TestCase):
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_fold_matmul(self):
vec, mats, ref = self._test_scan()
# 3 matmuls with SCAN
i = UOp.range(3, -100, AxisType.OUTER)
phi = vec._apply_uop(UOp.phi)
comp = phi @ mats[i]
scan = comp._apply_uop(UOp.fold, phi, extra_args=(i,))
scan.realize()
class TestOuterworld(unittest.TestCase):
def test_range_plus_1(self):
t = Tensor.arange(100).reshape(10,10).realize()

View File

@@ -92,6 +92,8 @@ class Ops(FastEnum):
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
PHI = auto(); SCAN = auto(); FOLD = auto()
# errors/placeholders
REWRITE_ERROR = auto(); SENTINEL = auto()

View File

@@ -25,7 +25,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.FOLD: 2}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@@ -219,9 +219,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.PHI | Ops.FOLD:
return self.src[0]._shape
# scan adds dims to the front
case Ops.SCAN:
if self.src[0]._shape is None: return None
return tuple(x.vmax+1 for x in self.src[2:]) + self.src[0]._shape
# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape
@@ -443,6 +448,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def fold(self, *src:UOp, **kwargs): return UOp(Ops.FOLD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def scan(self, *src:UOp, **kwargs): return UOp(Ops.SCAN, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def phi(self, *src:UOp, **kwargs): return UOp(Ops.PHI, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def is_contiguous(self):
# TODO: this is is_realized
if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()