const_like can return a valid [pr] (#8005)

* const_like can return a valid [pr]

* fixup
This commit is contained in:
qazal
2024-12-03 05:42:12 -05:00
committed by GitHub
parent 4e91533419
commit ba1183314a

View File

@@ -319,7 +319,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -369,8 +369,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return self.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))