diff --git a/tinygrad/codegen/late/linearizer.py b/tinygrad/codegen/late/linearizer.py
index 24ba9ff005..894a87b9a4 100644
--- a/tinygrad/codegen/late/linearizer.py
+++ b/tinygrad/codegen/late/linearizer.py
@@ -20,16 +20,8 @@ def linearize(u:UOp) -> list[UOp]:
     # this will cause ranges to be placed late and ends to be placed early
     run_count = prod([int(r.vmax)+1 for r in u.ranges])

-    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
-    priority = [0] + [priorities[x][1] for x in consumers[u]]
-    if u.op is Ops.LOAD: priority.append(-1000)
-    if u.op is Ops.BARRIER: priority.append(-1500)
-    # ranges are scheduled as late as possible so anything that can be outside is
-    # if u.op is Ops.RANGE: priority = [2000]
-    if u.op is Ops.END: priority = [-1000]
-    # move defines and consts to the top
-    if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
-    priorities[u] = (run_count, min(priority))
+    # simple priority
+    priorities[u] = (run_count, 0)

   # number the uops in "ideal" order
   nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index 4264fc2fac..39c15a208e 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -9,6 +9,22 @@ class FastEnum(IntEnum):

 # the order of these Ops controls the order of the toposort
 class Ops(FastEnum):
+  # ** 1 -- defines/consts **
+
+  # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
+  DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
+
+  # this is for symbolic shapes
+  DEFINE_VAR = auto(); BIND = auto() # noqa: E702
+
+  # consts. VCONST is a vectorized const
+  VCONST = auto(); CONST = auto() # noqa: E702
+
+  # this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
+  SPECIAL = auto()
+
+  # ** 2 -- non op uops **
+
   # uops that aren't rendered
   NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
   SENTINEL = auto()
@@ -32,33 +48,28 @@
   RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
   MULTI = auto() # MULTI is really a movement op

-  # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
-  DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
-
-  # this is for symbolic shapes
-  DEFINE_VAR = auto(); BIND = auto() # noqa: E702
-
-  # this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
-  SPECIAL = auto()
-
-  # reduce
+  # reduce (movement)
   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702

   # optimization helper ops
   UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702

-  # UnaryOps
-  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
+  # ** 3 -- load/store **
+
+  # INDEX is a BinaryOp similar to ADD, but it operates on pointers
+  INDEX = auto()

   # load/store before math
   LOAD = auto(); STORE = auto() # noqa: E702
   ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN

+  # ** 4 -- math **
+
   # tensor core math op, not elementwise
   WMMA = auto()

-  # INDEX is a BinaryOp similar to ADD, but it operates on pointers
-  INDEX = auto()
+  # UnaryOps
+  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702

   # BinaryOps
   ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
@@ -69,12 +80,11 @@
   # TernaryOps
   WHERE = auto(); MULACC = auto() # noqa: E702

+  # ** 5 -- control flow / other **
+
   # control flow ops
   BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702

-  # consts. VCONST is a vectorized const
-  VCONST = auto(); CONST = auto() # noqa: E702
-
   # CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
   CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
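
Why the linearizer simplification still produces a sane order: with the per-op priority hacks removed, every uop in a block gets the identical key (run_count, 0), so sorted(lst, key=lambda x: (priorities[x],)+x.tuplize) falls through to x.tuplize, and the new declaration order of Ops (defines/consts first, control flow last) takes over the job the -1000/-1500/-2000 special cases used to do. Below is a minimal sketch of that fallback; the toy enum and the assumption that tuplize leads with the op's enum value are illustrative stand-ins, not tinygrad's actual definitions (though the "order of these Ops controls the order of the toposort" comment points the same way).

from enum import IntEnum, auto

# toy enum in the new declaration order: defines/consts first, control flow last
class Ops(IntEnum):
  DEFINE_GLOBAL = auto(); CONST = auto()  # group 1 -- defines/consts
  LOAD = auto(); STORE = auto()           # group 3 -- load/store
  ADD = auto(); MUL = auto()              # group 4 -- math
  RANGE = auto(); END = auto()            # group 5 -- control flow

# stand-ins for uops: (name, op). assume tuplize leads with the op's enum value.
uops = [("r", Ops.RANGE), ("a", Ops.ADD), ("g", Ops.DEFINE_GLOBAL),
        ("l", Ops.LOAD), ("c", Ops.CONST), ("s", Ops.STORE)]

# the new "simple priority": identical for every uop in the block
priorities = {name: (1, 0) for name, _ in uops}

# same shape as the linearizer's key: (priorities[x],) + x.tuplize
order = sorted(uops, key=lambda x: (priorities[x[0]], x[1].value))
print([name for name, _ in order])  # ['g', 'c', 'l', 's', 'a', 'r']

With equal priorities the enum value alone decides placement, which is why moving VCONST/CONST up into group 1 can replace the old "move defines and consts to the top" special case.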