mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
delete forced_reshape [pr] (#14926)
This commit is contained in:
@@ -360,12 +360,12 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
||||
return buf.after(do_store.barrier())
|
||||
|
||||
# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape
|
||||
# collapse any BUFFERIZE to single input BUFFERIZE
|
||||
def flatten_bufferize(x:UOp):
|
||||
if len(x.src) == 2: return None
|
||||
ret = x.replace(src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
|
||||
rngs = x.src[1:]
|
||||
ret = ret.forced_reshape(x.shape)
|
||||
ret = ret.reshape(x.shape)
|
||||
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
|
||||
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
|
||||
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
||||
|
||||
@@ -597,7 +597,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
|
||||
# in these four, if the shape doesn't change we can return self
|
||||
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
|
||||
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
|
||||
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
|
||||
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
|
||||
@@ -1434,8 +1433,6 @@ pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
|
||||
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
|
||||
# TODO: fix forced_reshape
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
|
||||
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
|
||||
# NOTE: CMPNE doesn't work cause there's no __rne__
|
||||
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
|
||||
|
||||
Reference in New Issue
Block a user