diff --git a/test/test_schedule.py b/test/test_schedule.py
index db7726f9bd..4187ded416 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1930,7 +1930,7 @@ class TestSwizzle(unittest.TestCase):
     base = ShapeTracker.from_shape((32, 16, 1))
     start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
     r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
-    add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
+    add = r.reshape((16, 32, 1)) + UOp.const(r.dtype, 0)
     self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
     to_store = add.permute((1, 0, 2)).contiguous()
     to_store = graph_rewrite(to_store, remove_movement_ops)
@@ -1941,6 +1941,8 @@ class TestSwizzle(unittest.TestCase):
     self.assertEqual(swizzle_cnt(ret), 1)

 def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
+# TODO: we only need valid on ast consts if it's masked, can fold this early to UOp.const
+zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))
 class TestView(unittest.TestCase):
   def test_all_masked_out(self):
     # start with non CONST Ops
@@ -1948,8 +1950,7 @@ class TestView(unittest.TestCase):
     # all masked out, degrades to const 0
     b = a.pad(((0, 10), None))[10:]
     sched = check_schedule(b.contiguous(), 1)
-    # TODO: this VALID can clean up, where do we need st?
-    self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
+    assert zero_pm.match(store_val(sched[-1]), {})
     run_schedule(sched)
     np.testing.assert_equal(b.numpy(), 0)

@@ -1960,7 +1961,7 @@ class TestView(unittest.TestCase):
     assert b.shape == (10, 10)
     sched = check_schedule(b.contiguous(), 1)
     self.assertEqual(sched[-1].ast.full_shape, (10, 10))
-    self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
+    assert zero_pm.match(store_val(sched[-1]), {})
     run_schedule(sched)
     np.testing.assert_equal(b.numpy(), 0)

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 0f7a0c0d5e..023276ee36 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -276,15 +276,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
   @functools.cached_property
-  def st(self) -> Optional[ShapeTracker]:
+  def st(self) -> ShapeTracker|None:
+    # these uops define ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
-    # buffer ops can have a non contiguous shapetracker
-    if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
+    # otherwise we derive the st from sources
     if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
-    # all other ops have a contiguous shapetracker
+    # st_arg on buffer uops defines the ShapeTracker, it's allowed to be non contiguous
+    if self.op in GroupOp.Buffer: return self.st_arg
+    # all other uops have a contiguous ShapeTracker
     from tinygrad.shape.shapetracker import ShapeTracker
+    # only reduceop is allowed to change shape
     return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape)
   @functools.cached_property
   def full_shape(self) -> tuple[sint, ...]:
@@ -292,7 +295,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
   @property
-  def size(self) -> int: return self.arg[-1] if self.op is Ops.BUFFER else unwrap(self.st).size
+  def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size

   # *** uop evaluation ***

@@ -338,8 +341,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
   def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
   def const_like(self, b:ConstLike):
-    if self._device is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
-    return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
+    # constants can optionally have a DEVICE source
+    return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
   def broadcast(self, count:int):
     assert self.dtype.count == 1
     if count == 1: return self
@@ -429,10 +432,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   # *** from LazyBuffer ***

-  @staticmethod
-  def const_with_shape(dtype:DType, val:ConstLike, shape:tuple[sint,...]) -> UOp:
-    from tinygrad.shape.shapetracker import ShapeTracker
-    return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0)
   @staticmethod
   def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
     from tinygrad.shape.shapetracker import ShapeTracker
@@ -506,8 +505,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   buffer_num = itertools.count(0)
   @staticmethod
-  def new_buffer(device:str, size:int, dtype:DType) -> UOp:
-    return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
+  def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
   @property
   def device(self) -> str: return unwrap(self._device)
   @functools.cached_property
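Note (illustrative, not part of the patch): with UOp.const_with_shape removed, the tests now check the lowered form of a masked-out constant directly. In the scheduled AST that form is WHERE(VALID(view), CONST 0, CONST), which is exactly what the new zero_pm pattern encodes. A minimal sketch of the check, using only names that appear in this diff:

    from tinygrad.ops import Ops, UPat

    # matches WHERE(VALID, CONST(0), CONST) -- the lowered shape of an all-masked-out zero
    zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))

    # UPat.match(uop, store) returns the matches found (empty when nothing matches),
    # so the tests can simply assert its truthiness:
    #   assert zero_pm.match(store_val(sched[-1]), {})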