mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
swizzle store [pr] (#7964)
* swizzle store [pr]
* assign extra swizzle
* now arg is optional
* extra
This commit is contained in:
@@ -120,6 +120,8 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
new_shape, new_input_shape = swizzle_shapes[0]
|
||||
new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
|
||||
ret = root.replace(src=new_src)
|
||||
# update the ASSIGN offset to match the new shape
|
||||
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+ShapeTracker.from_shape(new_input_shape),)
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
@@ -129,9 +131,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# ASSIGN can override st
|
||||
# ASSIGN with offset swizzles STORE
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
|
||||
lambda a,b,st: None if a.arg is None else UOp.store(b, (a.arg+st.arg).to_uop(), a.replace(arg=None))),
|
||||
lambda a,b,st: None if a.arg is None else apply_swizzle(UOp.store(b, st, a.replace(arg=None)), a.arg)),
|
||||
# non contiguous VIEW on a reduce creates a new VIEW
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
|
||||
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
|
||||
|
||||
@@ -360,7 +360,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def base(self) -> UOp:
  """Return the single source of this UOp when it is a VIEW wrapping a non-BUFFER source, otherwise this UOp itself."""
  # a VIEW with exactly one source that is not a BUFFER unwraps to that source
  wraps_non_buffer = self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER
  if wraps_non_buffer: return self.src[0]
  # every other UOp is its own base
  return self
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
assert self.op is not Ops.STORE, "STORE must stay base"
|
||||
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
|
||||
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
|
||||
|
||||
Reference in New Issue
Block a user