From 0c0d07d3307236baefd2a77c9cc5a39e95f0efea Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 20 Feb 2026 22:35:31 -0500 Subject: [PATCH] delete forced_reshape [pr] (#14926) --- tinygrad/schedule/rangeify.py | 4 ++-- tinygrad/uop/ops.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index fa1ee0249d..03fd1eba62 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -360,12 +360,12 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True): do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs) return buf.after(do_store.barrier()) -# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape +# collapse any BUFFERIZE to single input BUFFERIZE def flatten_bufferize(x:UOp): if len(x.src) == 2: return None ret = x.replace(src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:])))) rngs = x.src[1:] - ret = ret.forced_reshape(x.shape) + ret = ret.reshape(x.shape) if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs): sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs]) ret = ret.shrink(tuple([(0,x) for x in sym_shape])) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index e42dde9577..387c12c2a9 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -597,7 +597,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return ret # in these four, if the shape doesn't change we can return self - def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False) #def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True) #def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True) #def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True) @@ -1434,8 +1433,6 @@ pm_pyrender_extra = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+ (f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None), - # TODO: fix forced_reshape - (UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None), (UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"), # NOTE: CMPNE doesn't work cause there's no __rne__ # NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering