From ba1183314aed66ebf926d9714fab02e90c33b344 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:42:12 -0500 Subject: [PATCH] const_like can return a valid [pr] (#8005) * const_like can return a valid [pr] * fixup --- tinygrad/ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d21f7044bf..e8884e1cb4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -319,7 +319,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs) def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) - def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) + def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -369,8 +369,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self def view(self, new_st:ShapeTracker) -> UOp: assert self.st is not None and self.base.st is not None, f"must have shape {self}" - if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): - return UOp.const_with_shape(self.dtype, 0, new_st.shape) + if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return self.const_like(0) if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base return UOp(Ops.VIEW, self.dtype, (self.base,), new_st) def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))