bunch up ops for lines [pr] (#8030)

This commit is contained in:
George Hotz
2024-12-04 17:03:01 +08:00
committed by GitHub
parent e9c5b23ba1
commit 1e06aefde7

View File

@@ -93,43 +93,31 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto()
CONTIGUOUS = auto()
PRELOAD = auto()
SINK = auto(); CONTIGUOUS = auto(); PRELOAD = auto() # noqa: E702
# MetaOps
COPY = auto()
EMPTY = auto()
BUFFER_VIEW = auto()
COPY = auto(); EMPTY = auto(); BUFFER_VIEW = auto() # noqa: E702
# blocks in linearizer
BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
EXPAND = auto()
CONTRACT = auto()
VIEW = auto()
DEFINE_GLOBAL = auto()
BUFFER = auto()
DEFINE_VAR = auto()
DEFINE_LOCAL = auto()
DEFINE_ACC = auto()
VALID = auto()
SPECIAL = auto()
NOOP = auto()
# misc ops
EXPAND = auto(); CONTRACT = auto() # noqa: E702
VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
# reduce
REDUCE_AXIS = auto()
# helper ops
GEP = auto()
VECTORIZE = auto()
GEP = auto(); VECTORIZE = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
# load/store before math
LOAD = auto()
STORE = auto()
LOAD = auto(); STORE = auto() # noqa: E702
# early INDEX
INDEX = auto()
@@ -149,17 +137,10 @@ class Ops(FastEnum):
BIND = auto()
# control flow ops
BARRIER = auto()
RANGE = auto()
IF = auto()
# ops that are not graph nodes
ENDRANGE = auto()
ENDIF = auto()
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
# consts last!
VCONST = auto()
CONST = auto()
VCONST = auto(); CONST = auto() # noqa: E702
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -190,7 +171,6 @@ def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
visisted.add(u)
return all(can_pad(x.base, edges, visisted) for x in u.src)
# With True as the default, this matches the old symbolic behavior
def resolve(x, default:bool=True):
if not isinstance(x, UOp): return bool(x)