match lazy view in uop try 2 (#7905)

* match lazy view in uop

* reswizzle

* p2

* assert count

* empty

* smaller diff
This commit is contained in:
qazal
2024-11-26 07:31:50 -05:00
committed by GitHub
parent ea57c52b99
commit cab461c2b5

View File

@@ -357,9 +357,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.op is not Ops.STORE, "STORE must stay base"
assert self.st is not None, f"must have shape {self}"
if new_st.contiguous and self.base.st == new_st: return self.base
return UOp(Ops.VIEW, self.dtype, (self,), new_st)
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg))
# *** uop Buffer stuff ***
@@ -1203,7 +1205,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])