new (post) group for reduce (#11837)

* new (post) group for reduce

* fixes

* leave if

* fix locals

* size

* no vectorized buf

* image fixes

* don't track that

* fix ptx

* name buffer with reduce range

* remove unused in lowerer

* yay DEFINE_REG refactor
This commit is contained in:
George Hotz
2025-08-25 18:03:00 -07:00
committed by GitHub
parent ac3449b0c8
commit 215818379b
9 changed files with 56 additions and 56 deletions

View File

@@ -18,6 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
@dataclass
class RewriteStep:
@@ -70,6 +71,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# expand
ret.append(RewriteStep(sym+expander, name="expander"))
# add locals
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
# ** devectorizer (full_graph_rewrite) **
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

View File

@@ -1,7 +1,7 @@
import math
import math, functools, operator
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.shape.view import get_contraction
from tinygrad.renderer import Renderer
@@ -83,7 +83,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
if r.op is not Ops.RANGE: continue
try:
ii = (global_dims+local_dims).index(r.arg[0]%1000)
if r.arg[0] < 2000 and r.arg[1] == AxisType.GROUP_REDUCE: continue
if r.arg[1] == AxisType.REDUCE: continue
subs[r] = idxs[ii]
except ValueError: continue
return s.substitute(subs)
@@ -104,7 +104,24 @@ def fix_store_unroll(x:UOp):
if len(store_expand) == 0: return None
return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1)
def fix_group_for_reduce(x:UOp):
  """Split a REDUCE that has GROUP_REDUCE ranges into a two-stage reduce through a LOCAL buffer.

  Stage 1 reduces only the plain REDUCE ranges, bufferizes the partial results into local
  memory indexed by the GROUP_REDUCE (and any upstream LOCAL) ranges, then stage 2 reads
  that buffer back and does the final reduce over fresh REDUCE ranges. Returns None when
  x has no GROUP_REDUCE ranges, leaving the pattern matcher to try other rules.
  """
  # split x's range srcs: reduce_gfr = GROUP_REDUCE ranges, reduce_r = everything else
  reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE)
  if len(reduce_gfr) == 0: return None
  # NOTE: if there's other locals here, we need them in the buffer too
  upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]
  # do only the non grouped reduces early
  ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
  # fresh REDUCE ranges for stage 2; arg[0]+100 presumably keeps range ids unique — TODO confirm offset convention
  reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
  # bufferize the partial reduce into LOCAL address space (buffer named by the first grouped range id),
  # then index it back with the stage-2 reduce ranges in place of the grouped ranges
  buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=(AddrSpace.LOCAL, reduce_gfr[0].arg[0])).index(*upstream_locals, *reduce_loop)
  # gate with an if on the store + do the final reduce
  # the IF condition is true only where every grouped range index == 0, so one thread-group slice does stage 2
  buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
  return buf.reduce(*reduce_loop, arg=x.arg)
pm_add_gpudims = PatternMatcher([
# add gpudims must be last
(UPat(Ops.SINK, name="s"), add_gpudims),
# rewrite UPCAST/UNROLL range to something to be expanded
(UPat(Ops.RANGE, name="r"),
@@ -113,4 +130,6 @@ pm_add_gpudims = PatternMatcher([
# fix REDUCEs with UNROLLs
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
(UPat(Ops.STORE, name="x"), fix_store_unroll),
# fix group for reduce
(UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])

View File

@@ -232,17 +232,17 @@ def no_vectorized_alu(alu:UOp):
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
def no_vectorized_acc(acc:UOp, c:UOp):
  """Devectorize an index into a vector-dtype accumulator (DEFINE_REG).

  Rewrites acc[0] with vector dtype into a PTRCAT of scalar-dtype indexes acc[0..count-1].
  Returns None for scalar accumulators (nothing to do). Only index 0 is supported.
  """
  if acc.dtype.count == 1: return None
  assert c.arg == 0, "this only supports index 0"
  # same register, but a scalar element type with `count` slots in the same address space
  new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
  # concatenate the per-element pointers back into the original vector pointer dtype
  return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))
def no_vectorized_buf(buf:UOp, idx:UOp):
  """Devectorize an index into a vector-dtype buffer (DEFINE_LOCAL / DEFINE_REG).

  Rebuilds the buffer with a scalar element type (size scaled by the old vector count)
  and turns the single vector index into a broadcast gather: element i of the result
  reads new_buf[idx*cnt + i]. Returns None when the buffer dtype is already scalar.
  """
  # NOTE: this should work for define reg too
  if (cnt:=buf.dtype.count) == 1: return None
  # scalar-element pointer covering the same storage: size*cnt scalar slots, same address space
  new_buf = buf.replace(dtype=buf.dtype.base.scalar().ptr(cast(PtrDType, buf.dtype).size*cnt, cast(PtrDType, buf.dtype).addrspace))
  # vectorized index: base offset idx*cnt plus the lane offsets (0..cnt-1)
  return new_buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").index(UPat.var("idx")), no_vectorized_buf),
])
pm_render = PatternMatcher([

View File

@@ -1,7 +1,7 @@
# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
@@ -50,9 +50,11 @@ def do_expand(root:UOp):
if root.op is Ops.IF or src.op is Ops.IF:
# for the first arg of IF, just pass them through ignoring UNROLLS
new_srcs.append(src)
elif (root.op is Ops.STORE and i >= 2) or (root.op is Ops.REDUCE and i >= 1):
elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1):
# for any range args of STORE/REDUCE, pass them through
new_srcs.append(src)
elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
@@ -84,7 +86,7 @@ expander = PatternMatcher([
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded

View File

@@ -1,8 +1,6 @@
# the job of the lowerer is to do indexing
import functools, operator
from typing import cast
from dataclasses import dataclass
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
# ***** indexing *****
@@ -50,15 +48,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
stored = subblock(ctx, real_new_idxs, x.src[1])
used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
ret = buf.index(idx, valid).store(stored, *used_ranges)
# insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE
if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \
any(ctx.axis_types[x.arg[0]%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges):
ret = ret.barrier()
range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg[0]%1000] == AxisType.GROUP_REDUCE]
if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret))
return ret
return buf.index(idx, valid).store(stored, *used_ranges)
def fixup_wmma(ctx:IndexContext, x:UOp):
if x.tag is not None: return None

View File

@@ -10,7 +10,7 @@ from tinygrad.uop.spec import type_verify, ast_spec
from tinygrad.device import Device
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.renderer import Renderer
from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
@@ -464,8 +464,7 @@ class Kernel:
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if i in changed)
grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if i in changed)
axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.GROUP_REDUCE, AxisType.UNROLL) if i in changed)
if (tc := self.tensor_core) and self.use_tensor_cores == 1:
# get reduce/upcast axes for the tensor cores
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
@@ -490,23 +489,6 @@ class Kernel:
return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
ret = ret.replace(arg = (op.arg[0], axes))
if self.group_for_reduces and grouped_axes:
local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes)
# NOTE: start with UPCAST at the end so it has stride 1 and can merge
base_shape = tuple([self.full_shape[i] for i in slocal] + [self.full_shape[i] for i in sgroup] + [self.full_shape[i] for i in supcast])
permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast])
local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
local_size = st.real_size()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
return ret
self.finalized = True
fixed_ast = fixup_ast(self.ast)

View File

@@ -128,7 +128,8 @@ fix_kernel_ops = view_left_through_load+PatternMatcher([
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# no ImageDType after index
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.INDEX}, name="x"),
lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])

View File

@@ -119,7 +119,7 @@ string_rewrite = PatternMatcher([
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
(UPat(Ops.DEFINE_LOCAL, name="x"),
lambda ctx, x: [f".shared .align 16 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
lambda ctx, x: [f".shared .align 16 .b8 local{x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, local{x.arg}[0];"]),
(UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
@@ -215,7 +215,7 @@ class PTXRenderer(Renderer):
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
if prefix: r[u] = ssa(prefix, u, dtype)

View File

@@ -1,6 +1,6 @@
from typing import Any
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, RANGEIFY
from tinygrad.schedule.multi import multi_pm
@@ -332,17 +332,15 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
def bufferize_to_store(x:UOp):
  """Lower a BUFFERIZE uop into a STORE on a concrete buffer (global BUFFER or DEFINE_LOCAL).

  NOTE(review): this span is rendered diff residue — both the pre-change lines
  (prod(shape)/AddrSpace arg) and post-change lines (size/tuple arg) are present
  below. Consult the actual repository file for the canonical single version.
  """
  # the ranges after src[0] define the buffer's shape; vmax+1 is each range's extent
  rngs = x.src[1:]
  shape = tuple([int(r.vmax+1) for r in rngs])
  # pre-change: x.arg was an AddrSpace (or None -> GLOBAL)
  sdtype = x.dtype.ptr(size=prod(shape), addrspace=AddrSpace.GLOBAL if not isinstance(x.arg, AddrSpace) else x.arg)
  assert prod(shape) > 0, f"no zero sized buffers {shape}"
  size = prod(shape)
  assert size > 0, f"no zero sized buffers {shape}"
  # post-change: x.arg is a (AddrSpace, name) tuple (or not a tuple -> GLOBAL)
  sdtype = x.dtype.ptr(size=size, addrspace=AddrSpace.GLOBAL if not isinstance(x.arg, tuple) else x.arg[0])
  # ASSIGN source: store directly through the existing INDEX target, no new buffer
  if x.src[0].op is Ops.ASSIGN:
    assign_target, assign_src = x.src[0].src
    assert assign_target.op is Ops.INDEX
    return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
  # pre-change buffer creation: locals named by a fresh unique id
  if sdtype.addrspace == AddrSpace.GLOBAL:
    buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
  else:
    # TODO: how to dedup this
    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=UOp.unique().arg)
  # post-change buffer creation: locals named by x.arg[1] (presumably the reduce range id — confirm in repo)
  if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg, size, x.dtype)
  else: buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
  # store through the shaped/indexed buffer, then reshape the result back to x's logical shape
  return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
pm_add_buffers = pm_mops+PatternMatcher([
@@ -393,11 +391,15 @@ rangeify_codegen = PatternMatcher([
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# TODO: this can be moved into codegen
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
# TODO: hack for group for reduce
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
])
def split_store(x:UOp):