mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
assert view on uops without shape [pr] (#7898)
* assert view on uops without shape [pr] * lint
This commit is contained in:
@@ -118,7 +118,8 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
||||
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
||||
new_shape, new_input_shape = swizzle_shapes[0]
|
||||
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
|
||||
new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
|
||||
ret = root.replace(src=new_src)
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
|
||||
@@ -355,11 +355,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@property
|
||||
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def view(self, st:ShapeTracker) -> UOp:
|
||||
if self.st is None: return self
|
||||
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
|
||||
if st.contiguous and self.base.st == st: return self.base
|
||||
return UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
assert self.op is not Ops.STORE, "STORE must stay base"
|
||||
assert self.st is not None, f"must have shape {self}"
|
||||
if new_st.contiguous and self.base.st == new_st: return self.base
|
||||
return UOp(Ops.VIEW, self.dtype, (self,), new_st)
|
||||
def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
@@ -1202,7 +1202,8 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# VIEW before elementwise ops
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
# early merge VIEW buffer ops
|
||||
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user