From e8777cb8db8b2bebd3341bb698d1e77d79916efd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 25 Nov 2024 07:43:50 -0500 Subject: [PATCH] assert view on uops without shape [pr] (#7898) * assert view on uops without shape [pr] * lint --- tinygrad/engine/schedule.py | 3 ++- tinygrad/ops.py | 13 +++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 49aab9eefa..268cb90203 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -118,7 +118,8 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles] assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}" new_shape, new_input_shape = swizzle_shapes[0] - ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)) + new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src) + ret = root.replace(src=new_src) return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 95322415f2..6ff58105a2 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -355,11 +355,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self - def view(self, st:ShapeTracker) -> UOp: - if self.st is None: return self - assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base" - if st.contiguous and self.base.st == st: return self.base - return UOp(Ops.VIEW, self.dtype, (self,), st) + def view(self, new_st:ShapeTracker) -> UOp: + assert self.op is not Ops.STORE, "STORE must stay base" + assert self.st is not None, f"must have shape {self}" + if new_st.contiguous and self.base.st == new_st: return self.base + return UOp(Ops.VIEW, self.dtype, (self,), new_st) def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg)) # *** uop Buffer stuff *** @@ -1202,7 +1202,8 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda # push VIEW to loads view_left = merge_views+PatternMatcher([ # VIEW before elementwise ops - (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))), + (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), + lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), # early merge VIEW buffer ops (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ])