assert op is not store in view (#7679)

* assert op is not store in view

* update view spec

* hotfix: nit

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
ignaciosica
2024-11-14 11:17:18 -03:00
committed by GitHub
parent 43040c0e24
commit 1419d8e58a

View File

@@ -344,7 +344,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) != 0 else self
def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
def view(self, st:ShapeTracker):
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
# *** uop Buffer stuff ***
@@ -761,7 +763,7 @@ spec = PatternMatcher([
# TODO: confirm the args of both of these are shapetrackers
(UPat(Ops.VIEW, src=()), lambda: True),
(UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),