tinygrad (mirror of https://github.com/tinygrad/tinygrad.git)

Commit: don't merge movement ops

The MERGE_MOVEMENT_OPS path is removed: every movement op now gets its own buffer. In exchange, identity movement ops are caught up front as instant nops, back-to-back reshapes or permutes collapse into a single op, and the movement-nop remover walks the now single-parent chain to find its root.
@@ -26,7 +26,7 @@ OPT = int(os.getenv("OPT", "1"))
 NOCONV = int(os.getenv("NOCONV", "0"))
 # TODO: movement ops that only change shape are really nops. treat them as such
-MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1
+REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1
 MERGE_ELEMENTWISE_OPS, MERGE_ONE_CONV_INTO_ELEMENTWISE, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_RESHAPE_OPS = OPT>=2, OPT>=2, OPT>=2, OPT>=2
 SHUFFLE_MOVEMENT_OPS = OPT>=3
 SHUFFLE_SLICE_OPS = OPT>=4 # NOTE: 0/0 is NaN if you slice, so this can change the output
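The NOTE on the last line is worth seeing concretely: tinygrad's SLICE can pad past the ends of a buffer with zeros, so shuffling a slice ahead of a division changes what the padded region computes. A minimal NumPy sketch of the hazard (np.pad standing in for a padding SLICE):

    import numpy as np

    a, b = np.array([1.0]), np.array([2.0])

    # pad after the divide: the padded element is a plain 0
    after = np.pad(a / b, (0, 1))                       # [0.5, 0.0]

    # shuffle the pad before the divide: that element becomes 0/0 = NaN
    with np.errstate(invalid="ignore"):
        before = np.pad(a, (0, 1)) / np.pad(b, (0, 1))  # [0.5, nan]

    assert not np.array_equal(after, before)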
@@ -71,7 +71,8 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
   top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
-  dashed = optype == LoadOps and getattr(ret, "_backing", None) is not None
+  dashed = (optype == LoadOps and getattr(ret, "_backing", None) is not None) \
+    or (getattr(ret, "st", None) is not None and not ret.st.contiguous)
   for x in inp:
     if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
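The new second clause also dashes any node whose ShapeTracker view is not contiguous, i.e. whose strides no longer match the default row-major strides for its shape. The same notion in NumPy terms, where a transpose keeps the buffer but breaks contiguity:

    import numpy as np

    x = np.zeros((2, 3))
    assert x.flags["C_CONTIGUOUS"]        # row-major strides match the shape
    assert not x.T.flags["C_CONTIGUOUS"]  # same buffer, permuted strides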
@@ -81,11 +82,7 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
     if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
   if nm(ret) not in G.nodes: G.add_node(nm(ret))
-  if getattr(ret, "st", None) is not None and not ret.st.contiguous:
-    #G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
-    G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
-    dashed = True
-  elif optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(inp[0].shape)+"\n"+str(ret.shape)
+  if optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(inp[0].shape)+"\n"+str(ret.shape)
   else: G.nodes[nm(ret)]['label'] = str(ret.shape)
   G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
   G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
@@ -235,6 +232,18 @@ class LazyBuffer:
     # TODO: look into why that copy is needed
     arg = tuple(copy(arg))
 
+    # instant nops
+    if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == x.shape: return x
+    if op == MovementOps.PERMUTE and arg == tuple(range(len(x.shape))): return x
+    if op == MovementOps.SLICE and arg == tuple((0,i) for i in x.shape): return x
+    if op == MovementOps.FLIP and all(x.shape[i] == 1 for i in arg): return x
+
+    # two reshapes in a row is one reshape
+    if op == MovementOps.RESHAPE and not x.realized and x.op.op == MovementOps.RESHAPE: return x.op.src[0].movement_op(op, arg)
+
+    # two permutes in a row is one permute
+    if op == MovementOps.PERMUTE and not x.realized and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
+
+    # TODO: SHUFFLE_SLICE_OPS is okay if it's a shrink
     if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_SLICE_OPS or op != MovementOps.SLICE):
       # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
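Both new fast paths are easy to sanity-check outside tinygrad. A minimal standalone sketch, with shapes and args as plain tuples and is_nop/permute as hypothetical helpers (not tinygrad API):

    from itertools import permutations

    # identity movement ops, mirroring the "instant nop" checks above
    def is_nop(op, arg, shape):
        if op in ("RESHAPE", "EXPAND"): return arg == shape
        if op == "PERMUTE": return arg == tuple(range(len(shape)))
        if op == "SLICE": return arg == tuple((0, s) for s in shape)
        return False

    assert is_nop("PERMUTE", (0, 1, 2), (4, 5, 6))
    assert is_nop("SLICE", ((0, 4), (0, 5)), (4, 5))
    assert not is_nop("RESHAPE", (20,), (4, 5))

    # two permutes in a row compose into one: output axis k of "p1 then p2"
    # reads input axis p1[p2[k]], matching tuple(x.op.arg[i] for i in arg)
    def permute(shape, p): return tuple(shape[i] for i in p)

    shape = (2, 3, 5, 7)
    for p1 in permutations(range(4)):
        for p2 in permutations(range(4)):
            fused = tuple(p1[i] for i in p2)
            assert permute(permute(shape, p1), p2) == permute(shape, fused)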
@@ -244,13 +253,14 @@ class LazyBuffer:
         return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src])
       return replace_with_movement_op(x.op)
 
-    # if a MovementOp is applied to a MovementOp, merge them and use one buffer
-    ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps,
-      LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
+    # create the buffer
+    ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps, LazyOp(op, (x,), arg))
 
     # NOTE: if ret is in the cache, it can already be realized
     if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
-      root = get_lazybuffers(ret.op)[0]
+      # MovementOps aren't stacked any more, they each have one parent, find the root
+      root = x
+      while root.optype == MovementOps: root = root.op.src[0]
      if root.st.contiguous and root != x and prod(ret.st.shape) == prod(root.shape):
        return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
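The replace_with_movement_op rewrite above is sound because movement ops commute with elementwise ops: moving each input and then applying the binary op matches applying the op and moving the result. A quick NumPy check of that identity:

    import numpy as np

    a = np.arange(6.0).reshape(2, 3)
    b = np.ones((2, 3))

    # RESHAPE commutes with an elementwise add
    assert np.array_equal((a + b).reshape(3, 2), a.reshape(3, 2) + b.reshape(3, 2))

    # PERMUTE (transpose) commutes with an elementwise mul
    assert np.array_equal((a * b).T, a.T * b.T)

The nop remover below it then leans on the commit's new invariant: movement ops no longer stack, so each one has a single parent and a plain while loop reaches the chain's root.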
@@ -44,14 +44,12 @@ class View:
 
 class ZeroView:
   def __init__(self, old_shape, arg):
-    expr = ['valid']
     self.shape = []
-    acc = 1
+    expr, acc = ['valid'], 1
     for s,(x,y) in list(zip(old_shape, arg))[::-1]:
       self.shape = [y-x] + self.shape
       base = divmodidx(acc, self.shape[0], len(self.shape) != len(old_shape)) + f"+{x}"
-      if x < 0: expr.append(f"(({base}) >= 0)")
-      if y > s: expr.append(f"(({base}) < {s})")
+      expr += ([f"(({base}) >= 0)"] if x < 0 else []) + ([f"(({base}) < {s})"] if y > s else [])
       acc *= self.shape[0]
     self.expr = 'valid=' + ' && '.join(expr)
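What the generated expr encodes: an output index is valid (it reads stored data rather than an implicit zero) only while the corresponding input index stays in bounds on every padded axis. A simplified model using concrete per-axis indices instead of the divmodidx string expression (valid is a hypothetical helper):

    def valid(idx, old_shape, arg):
        # idx: per-axis output index; arg: (start, end) per axis, where
        # start < 0 or end > size means padding with zeros
        return all(0 <= i + x < s for i, (x, _), s in zip(idx, arg, old_shape))

    # a (3, 3) tensor padded by one on every side, giving a (5, 5) view
    assert valid((1, 1), (3, 3), ((-1, 4), (-1, 4)))      # maps to input (0, 0)
    assert not valid((0, 0), (3, 3), ((-1, 4), (-1, 4)))  # maps to input (-1, -1)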
@@ -66,8 +64,7 @@ def strides_for_shape(shape):
 
 @functools.lru_cache(maxsize=None)
 def view_from_shape(shape:Tuple[int, ...]):
-  if len(shape) == 0: shape = (1,)
-  assert all([isinstance(x, int) for x in shape])
+  assert all([isinstance(x, int) for x in shape]) and len(shape) != 0
   return View(tuple(shape), strides_for_shape(shape))
 
 class ShapeTracker:
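The tightened assert now rejects a zero-length shape outright instead of silently promoting it to (1,). For context, strides_for_shape hands View plain row-major strides: each axis's stride is the product of the sizes to its right. A minimal reimplementation under that assumption:

    def strides_for_shape(shape):
        strides, acc = [], 1
        for s in reversed(shape):
            strides.insert(0, acc)  # stride of this axis is the size of everything to its right
            acc *= s
        return tuple(strides)

    assert strides_for_shape((2, 3, 4)) == (12, 4, 1)
    assert strides_for_shape((1,)) == (1,)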
@@ -103,7 +100,6 @@ class ShapeTracker:
   def reshape(self, *new_shape):
-    assert all([isinstance(x, int) for x in new_shape])
     assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
     if self.shape == new_shape: return
 
     # check if this is adding or removing 1s (only)
     if tuple([x for x in self.shape if x != 1]) == tuple([x for x in new_shape if x != 1]):
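The comparison on the last line is a cheap test for reshapes that only insert or drop size-1 axes; those can keep the existing strides instead of building a new view. Standalone, with only_ones_changed as a hypothetical name for the same check:

    def only_ones_changed(old, new):
        # equal after stripping 1s => the reshape only adds/removes size-1 axes
        return tuple(x for x in old if x != 1) == tuple(x for x in new if x != 1)

    assert only_ones_changed((2, 1, 3), (2, 3))     # dropping a 1
    assert only_ones_changed((2, 3), (1, 2, 3, 1))  # inserting 1s
    assert not only_ones_changed((2, 3), (3, 2))    # a real reshape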