mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
scan op work
This commit is contained in:
@@ -71,13 +71,29 @@ class TestOuterScan(unittest.TestCase):
|
||||
ref.realize()
|
||||
return vec, mats, ref
|
||||
|
||||
def test_uop_fold_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with FOLD
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(1, 10)
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out.uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out.uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref[2], out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
def test_uop_scan_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(3, 1, 10)
|
||||
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out[i].uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
@@ -85,6 +101,16 @@ class TestOuterScan(unittest.TestCase):
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
def test_fold_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
phi = vec._apply_uop(UOp.phi)
|
||||
comp = phi @ mats[i]
|
||||
scan = comp._apply_uop(UOp.fold, phi, extra_args=(i,))
|
||||
scan.realize()
|
||||
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_range_plus_1(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
||||
@@ -92,6 +92,8 @@ class Ops(FastEnum):
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
|
||||
PHI = auto(); SCAN = auto(); FOLD = auto()
|
||||
|
||||
# errors/placeholders
|
||||
REWRITE_ERROR = auto(); SENTINEL = auto()
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.FOLD: 2}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
@@ -219,9 +219,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.PHI | Ops.FOLD:
|
||||
return self.src[0]._shape
|
||||
|
||||
# scan adds dims to the front
|
||||
case Ops.SCAN:
|
||||
if self.src[0]._shape is None: return None
|
||||
return tuple(x.vmax+1 for x in self.src[2:]) + self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
case Ops.KERNEL: return self.arg.ast._shape
|
||||
|
||||
@@ -443,6 +448,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
|
||||
def fold(self, *src:UOp, **kwargs): return UOp(Ops.FOLD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def scan(self, *src:UOp, **kwargs): return UOp(Ops.SCAN, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def phi(self, *src:UOp, **kwargs): return UOp(Ops.PHI, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
|
||||
def is_contiguous(self):
|
||||
# TODO: this is is_realized
|
||||
if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()
|
||||
|
||||
Reference in New Issue
Block a user