mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
some permutes are reshapes
This commit is contained in:
@@ -29,7 +29,7 @@ NOCONV = int(os.getenv("NOCONV", "0"))
|
||||
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1
|
||||
MERGE_ELEMENTWISE_OPS, MERGE_ONE_CONV_INTO_ELEMENTWISE, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_RESHAPE_OPS = OPT>=2, OPT>=2, OPT>=2, OPT>=2
|
||||
SHUFFLE_MOVEMENT_OPS = OPT>=3
|
||||
SHUFFLE_SLICE_OPS = OPT>=4 # NOTE: 0/0 is NaN if you slice, so this can change the output
|
||||
SHUFFLE_PAD_OPS = OPT>=4 # NOTE: 0/0 is NaN if you pad, so this can change the output
|
||||
|
||||
# **** enumerate supported devices ****
|
||||
|
||||
@@ -71,8 +71,7 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
|
||||
|
||||
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
|
||||
|
||||
dashed = (optype == LoadOps and getattr(ret, "_backing", None) is not None) \
|
||||
or (getattr(ret, "st", None) is not None and not ret.st.contiguous)
|
||||
dashed = (optype == LoadOps and getattr(ret, "_backing", None) is not None) or (getattr(ret, "st", None) is not None and not ret.st.contiguous)
|
||||
|
||||
for x in inp:
|
||||
if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
|
||||
@@ -241,8 +240,11 @@ class LazyBuffer:
|
||||
# two permutes in a row is one permute
|
||||
if op == MovementOps.PERMUTE and x.realized is None and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
|
||||
|
||||
# some permutes are actually just reshapes
|
||||
if op == MovementOps.PERMUTE and ShapeTracker(x.shape).movement_op(op, arg).contiguous: return x.movement_op(MovementOps.RESHAPE, tuple(x.shape[i] for i in arg))
|
||||
|
||||
# TODO: SHUFFLE_SLICE_OPS is okay if it's a shrink
|
||||
if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_SLICE_OPS or op != MovementOps.SLICE):
|
||||
if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op not in [MovementOps.SLICE, MovementOps.PAD]):
|
||||
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
|
||||
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
|
||||
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
|
||||
|
||||
Reference in New Issue
Block a user