mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
const_like can return a valid [pr] (#8005)
* const_like can return a valid [pr] * fixup
This commit is contained in:
@@ -319,7 +319,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
if count == 1: return self
|
||||
@@ -369,8 +369,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
|
||||
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return self.const_like(0)
|
||||
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
|
||||
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
Reference in New Issue
Block a user