remove Ops.REDUCE (#7541)

This commit is contained in:
George Hotz
2024-11-05 09:41:28 +08:00
committed by GitHub
parent ab14fc1f5b
commit 075bdb81b3
3 changed files with 5 additions and 7 deletions

View File

@@ -352,8 +352,8 @@ def do_expand(root:UOp):
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-EXPAND input
if (root.op is Ops.IF) or (root.op is Ops.REDUCE and i != 0):
# for the first arg of IF and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
if root.op is Ops.IF:
# for the first arg of IF, just pass them through ignoring EXPANDS
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
@@ -404,7 +404,7 @@ expander = PatternMatcher([
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),

View File

@@ -120,7 +120,6 @@ class Ops(FastEnum):
NOOP = auto()
# reduce
REDUCE = auto()
REDUCE_AXIS = auto()
# ReduceOps
@@ -346,7 +345,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:Ops, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))

View File

@@ -9,9 +9,9 @@ from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#C4A484",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488", **{x:"#ffffc0" for x in GroupOp.ALU}}
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}}
# ** API spec