mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
remove Ops.REDUCE (#7541)
This commit is contained in:
@@ -352,8 +352,8 @@ def do_expand(root:UOp):
|
||||
new_srcs.append(src.src[0].gep(tuple(lst)))
|
||||
else:
|
||||
# non-EXPAND input
|
||||
if (root.op is Ops.IF) or (root.op is Ops.REDUCE and i != 0):
|
||||
# for the first arg of IF and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
|
||||
if root.op is Ops.IF:
|
||||
# for the first arg of IF, just pass them through ignoring EXPANDS
|
||||
new_srcs.append(src)
|
||||
elif src.dtype.count > 1:
|
||||
# put any input dtype > 1 grouped together
|
||||
@@ -404,7 +404,7 @@ expander = PatternMatcher([
|
||||
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
|
||||
Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# vectorize DEFINE_ACC
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
|
||||
|
||||
@@ -120,7 +120,6 @@ class Ops(FastEnum):
|
||||
NOOP = auto()
|
||||
|
||||
# reduce
|
||||
REDUCE = auto()
|
||||
REDUCE_AXIS = auto()
|
||||
|
||||
# ReduceOps
|
||||
@@ -346,7 +345,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:Ops, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#C4A484",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488", **{x:"#ffffc0" for x in GroupOp.ALU}}
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}}
|
||||
|
||||
# ** API spec
|
||||
|
||||
|
||||
Reference in New Issue
Block a user