From d0735d6489abfeb5ea3258a884c78170658e0877 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 30 Nov 2024 08:32:50 -0500 Subject: [PATCH] swizzle store [pr] (#7964) * swizzle store [pr] * assign extra swizzle * now arg is optional * extra --- tinygrad/engine/schedule.py | 6 ++++-- tinygrad/ops.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6e01f38e3f..89baf5a5dd 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -120,6 +120,8 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: new_shape, new_input_shape = swizzle_shapes[0] new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src) ret = root.replace(src=new_src) + # update the ASSIGN offset to match the new shape + if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+ShapeTracker.from_shape(new_input_shape),) return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: @@ -129,9 +131,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: # push VIEW to stores view_right = merge_views+PatternMatcher([ - # ASSIGN can override st + # ASSIGN with offset swizzles STORE (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))), - lambda a,b,st: None if a.arg is None else UOp.store(b, (a.arg+st.arg).to_uop(), a.replace(arg=None))), + lambda a,b,st: None if a.arg is None else apply_swizzle(UOp.store(b, st, a.replace(arg=None)), a.arg)), # non contiguous VIEW on a reduce creates a new VIEW (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), # push a VIEW down to STORE, through a reduce (ONLY reshapes) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ef2c52d59e..757ac1128a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -360,7 +360,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self def view(self, new_st:ShapeTracker) -> UOp: - assert self.op is not Ops.STORE, "STORE must stay base" assert self.st is not None and self.base.st is not None, f"must have shape {self}" if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return UOp.const_with_shape(self.dtype, 0, new_st.shape)