simplify priority

This commit is contained in:
George Hotz
2025-11-06 07:57:59 -08:00
parent b9b68bf437
commit 6809ff8fe1
2 changed files with 29 additions and 27 deletions

View File

@@ -20,16 +20,8 @@ def linearize(u:UOp) -> list[UOp]:
# this will cause ranges to be placed late and ends to be placed early
run_count = prod([int(r.vmax)+1 for r in u.ranges])
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
priority = [0] + [priorities[x][1] for x in consumers[u]]
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.BARRIER: priority.append(-1500)
# ranges are scheduled as late as possible so anything that can be outside is
# if u.op is Ops.RANGE: priority = [2000]
if u.op is Ops.END: priority = [-1000]
# move defines and consts to the top
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
priorities[u] = (run_count, min(priority))
# simple priority
priorities[u] = (run_count, 0)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}

View File

@@ -9,6 +9,22 @@ class FastEnum(IntEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# ** 1 -- defines/consts **
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
# this is for symbolic shapes
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# ** 2 -- non op uops **
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
SENTINEL = auto()
@@ -32,33 +48,28 @@ class Ops(FastEnum):
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
# this is for symbolic shapes
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# reduce
# reduce (movement)
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
# optimization helper ops
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# ** 3 -- load/store **
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
# load/store before math
LOAD = auto(); STORE = auto() # noqa: E702
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
# ** 4 -- math **
# tensor core math op, not elementwise
WMMA = auto()
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# BinaryOps
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
@@ -69,12 +80,11 @@ class Ops(FastEnum):
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
# ** 5 -- control flow / other **
# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702