delete forced_reshape [pr] (#14926)

This commit is contained in:
chenyu
2026-02-20 22:35:31 -05:00
committed by GitHub
parent 5b6fcd1cda
commit 0c0d07d330
2 changed files with 2 additions and 5 deletions

View File

@@ -360,12 +360,12 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store.barrier())
# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape
# collapse any BUFFERIZE to single input BUFFERIZE
def flatten_bufferize(x:UOp):
if len(x.src) == 2: return None
ret = x.replace(src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
rngs = x.src[1:]
ret = ret.forced_reshape(x.shape)
ret = ret.reshape(x.shape)
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))

View File

@@ -597,7 +597,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
# in these four, if the shape doesn't change we can return self
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
@@ -1434,8 +1433,6 @@ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
# TODO: fix forced_reshape
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# NOTE: CMPNE doesn't work cause there's no __rne__
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering