assert view on uops without shape [pr] (#7898)

* assert view on uops without shape [pr]

* lint
This commit is contained in:
qazal
2024-11-25 07:43:50 -05:00
committed by GitHub
parent a49ca0c2ff
commit e8777cb8db
2 changed files with 9 additions and 7 deletions

View File

@@ -118,7 +118,8 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
new_shape, new_input_shape = swizzle_shapes[0]
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
ret = root.replace(src=new_src)
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:

View File

@@ -355,11 +355,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, st:ShapeTracker) -> UOp:
if self.st is None: return self
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
if st.contiguous and self.base.st == st: return self.base
return UOp(Ops.VIEW, self.dtype, (self,), st)
def view(self, new_st:ShapeTracker) -> UOp:
assert self.op is not Ops.STORE, "STORE must stay base"
assert self.st is not None, f"must have shape {self}"
if new_st.contiguous and self.base.st == new_st: return self.base
return UOp(Ops.VIEW, self.dtype, (self,), new_st)
def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg))
# *** uop Buffer stuff ***
@@ -1202,7 +1202,8 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])