mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
new (post) group for reduce (#11837)
* new (post) group for reduce
* fixes
* leave if
* fix locals
* size
* no vectorized buf
* image fixes
* don't track that
* fix ptx
* name buffer with reduce range
* remove unused in lowerer
* yay DEFINE_REG refactor
@@ -18,6 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
 from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
 from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
 from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
+from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen

 @dataclass
 class RewriteStep:
@@ -70,6 +71,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
   # expand
   ret.append(RewriteStep(sym+expander, name="expander"))

+  # add locals
+  ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
+
   # ** devectorizer (full_graph_rewrite) **
   # remove reduce
   ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
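The new "add local buffers" step slots between the expander and the devectorizer. A toy sketch of how an ordered rewrite pipeline like this consumes its step list (invented names and string "programs", not tinygrad's RewriteStep API):

# a toy rewrite pipeline: each step is a (name, fn) pair applied in order,
# mirroring how the RewriteStep list above is consumed front to back
from typing import Callable

def run_pipeline(expr: str, steps: list[tuple[str, Callable[[str], str]]]) -> str:
  for name, fn in steps:
    expr = fn(expr)  # each pass sees the full result of the previous one
  return expr

steps = [
  ("expander",          lambda e: e.replace("UNROLL", "VECTORIZE")),
  ("add local buffers", lambda e: e.replace("REDUCE", "LOCAL_BUF+REDUCE")),
  ("remove_reduce",     lambda e: e.replace("REDUCE", "loop")),
]
print(run_pipeline("UNROLL REDUCE", steps))  # VECTORIZE LOCAL_BUF+loop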
@@ -1,7 +1,7 @@
-import math
+import math, functools, operator
 from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
 from tinygrad.helpers import all_int, partition, flatten, prod, dedup
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, AddrSpace
 from tinygrad.shape.view import get_contraction
 from tinygrad.renderer import Renderer

@@ -83,7 +83,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
     if r.op is not Ops.RANGE: continue
     try:
       ii = (global_dims+local_dims).index(r.arg[0]%1000)
-      if r.arg[0] < 2000 and r.arg[1] == AxisType.GROUP_REDUCE: continue
+      if r.arg[1] == AxisType.REDUCE: continue
       subs[r] = idxs[ii]
     except ValueError: continue
   return s.substitute(subs)
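add_gpudims maps each surviving RANGE onto a global/local index by position, now skipping only plain REDUCE ranges (GROUP_REDUCE ranges map to real local dims). A rough standalone sketch of the lookup-and-skip logic, with made-up (axis_id, kind) tuples in place of RANGE UOps:

# toy version of the substitution loop: ranges are (axis_id, kind) tuples;
# "reduce" kinds stay as loops, everything else maps positionally to an idx
global_dims, local_dims = [0, 1], [2]
idxs = ["gidx0", "gidx1", "lidx0"]

subs = {}
for r in [(0, "loop"), (2, "group_reduce"), (7, "reduce")]:
  if r[1] == "reduce": continue          # REDUCE stays a real loop
  try: subs[r] = idxs[(global_dims + local_dims).index(r[0] % 1000)]
  except ValueError: continue            # range not bound to a gpu dim
print(subs)  # {(0, 'loop'): 'gidx0', (2, 'group_reduce'): 'lidx0'}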
@@ -104,7 +104,24 @@ def fix_store_unroll(x:UOp):
   if len(store_expand) == 0: return None
   return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1)
+
+def fix_group_for_reduce(x:UOp):
+  reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE)
+  if len(reduce_gfr) == 0: return None
+
+  # NOTE: if there's other locals here, we need them in the buffer too
+  upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]
+
+  # do only the non grouped reduces early
+  ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
+  reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
+  buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=(AddrSpace.LOCAL, reduce_gfr[0].arg[0])).index(*upstream_locals, *reduce_loop)
+
+  # gate with an if on the store + do the final reduce
+  buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
+  return buf.reduce(*reduce_loop, arg=x.arg)
+
 pm_add_gpudims = PatternMatcher([
   # add gpudims must be last
   (UPat(Ops.SINK, name="s"), add_gpudims),
   # rewrite UPCAST/UNROLL range to something to be expanded
   (UPat(Ops.RANGE, name="r"),
@@ -113,4 +130,6 @@ pm_add_gpudims = PatternMatcher([
   # fix REDUCEs with UNROLLs
   (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
   (UPat(Ops.STORE, name="x"), fix_store_unroll),
+  # fix group for reduce
+  (UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
 ])
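fix_group_for_reduce is the heart of the commit: it splits one reduce into two stages, with every thread reducing its non-grouped ranges into a slot of a LOCAL buffer indexed by the GROUP_REDUCE ranges, then a gated final reduce looping back over those slots. A numpy sketch of the arithmetic it computes, assuming a sum over 8 elements grouped across 4 slots:

import numpy as np

x = np.arange(8.0)              # data to reduce
GROUP = 4                       # size of the GROUP_REDUCE range
partial = x.reshape(GROUP, -1).sum(axis=1)   # stage 1: each "thread" sums its slice
                                             # (the store into the LOCAL buffer)
total = partial.sum()                        # stage 2: gated final reduce over slots
assert total == x.sum()
print(total)  # 28.0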
@@ -232,17 +232,17 @@ def no_vectorized_alu(alu:UOp):
   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
   return UOp(Ops.VECTORIZE, alu.dtype, alus)

-def no_vectorized_acc(acc:UOp, c:UOp):
-  if acc.dtype.count == 1: return None
-  assert c.arg == 0, "this only supports index 0"
-  new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
-  return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))
+def no_vectorized_buf(buf:UOp, idx:UOp):
+  # NOTE: this should work for define reg too
+  if (cnt:=buf.dtype.count) == 1: return None
+  new_buf = buf.replace(dtype=buf.dtype.base.scalar().ptr(cast(PtrDType, buf.dtype).size*cnt, cast(PtrDType, buf.dtype).addrspace))
+  return new_buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))

 devectorize = PatternMatcher([
   # no ALU on vectorized dtypes
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
-  (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").index(UPat.var("idx")), no_vectorized_buf),
 ])

 pm_render = PatternMatcher([
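no_vectorized_buf replaces one vector-wide access with cnt scalar lanes; the scalar addresses are idx*cnt + lane for each lane. A plain-Python sketch of that address expansion (arithmetic only, not the UOp API):

# expanding a vec-4 access at vector index 3 into scalar lane addresses:
# new address = idx*cnt + lane, matching idx.broadcast(cnt)*cnt + (0..cnt-1)
cnt, idx = 4, 3
lanes = [idx * cnt + lane for lane in range(cnt)]
print(lanes)  # [12, 13, 14, 15] -- contiguous scalar slots for one vec-4 element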
@@ -1,7 +1,7 @@
 # this converts a lowerer program into a vectorized program

 import functools, itertools, operator
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
@@ -50,9 +50,11 @@ def do_expand(root:UOp):
     if root.op is Ops.IF or src.op is Ops.IF:
       # for the first arg of IF, just pass them through ignoring UNROLLS
       new_srcs.append(src)
-    elif (root.op is Ops.STORE and i >= 2) or (root.op is Ops.REDUCE and i >= 1):
+    elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1):
       # for any range args of STORE/REDUCE, pass them through
       new_srcs.append(src)
+    elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
+      new_srcs.append(src)
     elif src.dtype.count > 1:
       # put any input dtype > 1 grouped together
       new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
@@ -84,7 +86,7 @@ expander = PatternMatcher([
   (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
    lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
          Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
   (UPat(Ops.CONTRACT, name="con"), do_contract),
   # BARRIERs aren't actually expanded
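The nested-UNROLL rule above merges two UNROLLs by concatenating their axis args; expansion then enumerates the cartesian product of all unrolled axes. A small sketch with invented (axis, size) args:

import itertools
# nested UNROLLs with args ((axis, size), ...) merge by concatenation:
inner_arg, outer_arg = ((0, 2),), ((1, 3),)
merged = inner_arg + outer_arg          # ((0, 2), (1, 3)) -> one UNROLL
# expansion enumerates the cartesian product of all unrolled axes
points = list(itertools.product(*[range(sz) for _, sz in merged]))
print(len(points), points[:3])  # 6 [(0, 0), (0, 1), (0, 2)]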
@@ -1,8 +1,6 @@
 # the job of the lowerer is to do indexing
-import functools, operator
-from typing import cast
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, AddrSpace, PtrDType
+from tinygrad.dtype import dtypes
 from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite

 # ***** indexing *****
@@ -50,15 +48,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):

   stored = subblock(ctx, real_new_idxs, x.src[1])
   used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
-  ret = buf.index(idx, valid).store(stored, *used_ranges)
-
-  # insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE
-  if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \
-    any(ctx.axis_types[x.arg[0]%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges):
-    ret = ret.barrier()
-    range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg[0]%1000] == AxisType.GROUP_REDUCE]
-    if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret))
-  return ret
+  return buf.index(idx, valid).store(stored, *used_ranges)

 def fixup_wmma(ctx:IndexContext, x:UOp):
   if x.tag is not None: return None
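The deleted branch used to gate the final store in the lowerer: a BARRIER when ending a LOCAL/GROUP_REDUCE range, then an IF whose condition ANDs together x.eq(0) for every GROUP_REDUCE range, so exactly one thread per group runs the follow-up. A sketch of that guard construction with plain booleans (hypothetical per-thread values):

import functools, operator
# one guard per GROUP_REDUCE range: true only where that range index is 0
lidx = {"g0": 0, "g1": 2}                     # hypothetical per-thread range values
range_gates = [v == 0 for v in lidx.values()]
gate = functools.reduce(operator.and_, range_gates)
print(gate)  # False: this thread is not the (0, 0) thread of its group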
@@ -10,7 +10,7 @@ from tinygrad.uop.spec import type_verify, ast_spec
 from tinygrad.device import Device
 from tinygrad.codegen.opt.tc import TensorCore
 from tinygrad.renderer import Renderer
-from tinygrad.dtype import ImageDType, AddrSpace
+from tinygrad.dtype import ImageDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape, get_contraction
@@ -464,8 +464,7 @@ class Kernel:
     if op.op is Ops.REDUCE_AXIS:
       reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
       changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
-      axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if i in changed)
-      grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if i in changed)
+      axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.GROUP_REDUCE, AxisType.UNROLL) if i in changed)
       if (tc := self.tensor_core) and self.use_tensor_cores == 1:
         # get reduce/upcast axes for the tensor cores
         tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
@@ -490,23 +489,6 @@ class Kernel:
         return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

       ret = ret.replace(arg = (op.arg[0], axes))
-      if self.group_for_reduces and grouped_axes:
-        local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
-        slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes)
-        # NOTE: start with UPCAST at the end so it has stride 1 and can merge
-        base_shape = tuple([self.full_shape[i] for i in slocal] + [self.full_shape[i] for i in sgroup] + [self.full_shape[i] for i in supcast])
-        permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast])
-        local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
-        local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
-        st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
-        local_size = st.real_size()
-        local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
-        local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
-        grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
-        if op is self.reduceops[-1]: return grouped_reduce
-        st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
-        return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
-
     return ret
   self.finalized = True
   fixed_ast = fixup_ast(self.ast)
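With the ShapeTracker-based staging deleted here, the kernel no longer builds the local buffer itself; it just folds GROUP_REDUCE axes into the REDUCE_AXIS arg and lets fix_group_for_reduce stage them later. A toy sketch of that axis merge (hypothetical axis-type map, not the Kernel API):

# before: REDUCE and UNROLL axes went in `axes`, GROUP_REDUCE in `grouped_axes`
# and got their own local-buffer round trip; after: one combined tuple
axis_types = {0: "GLOBAL", 1: "GROUP_REDUCE", 2: "REDUCE", 3: "UNROLL"}
changed = (1, 2, 3)
axes = tuple(i for i, t in axis_types.items()
             if t in ("REDUCE", "GROUP_REDUCE", "UNROLL") and i in changed)
print(axes)  # (1, 2, 3): the grouped axis now rides along in REDUCE_AXIS's arg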
@@ -128,7 +128,8 @@ fix_kernel_ops = view_left_through_load+PatternMatcher([
   (UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
    lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
   # no ImageDType after index
-  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.INDEX}, name="x"),
+   lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
   # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
   (UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
 ])
@@ -119,7 +119,7 @@ string_rewrite = PatternMatcher([
     ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
   (UPat(Ops.DEFINE_LOCAL, name="x"),
-   lambda ctx, x: [f".shared .align 16 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
+   lambda ctx, x: [f".shared .align 16 .b8 local{x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, local{x.arg}[0];"]),
   (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
   (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
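Since DEFINE_LOCAL's arg is now a range-derived number rather than a name ("name buffer with reduce range"), the PTX renderer prefixes it with "local" to form a valid identifier. A sketch of the rendered lines for a hypothetical 16-float buffer with arg 2:

# hypothetical DEFINE_LOCAL with numeric arg 2, 16 floats (itemsize 4)
arg, size, itemsize, reg = 2, 16, 4, "%rd3"
decl = f".shared .align 16 .b8 local{arg}[{size*itemsize}];"
addr = f"mov.u64 {reg}, local{arg}[0];"
print(decl)  # .shared .align 16 .b8 local2[64];
print(addr)  # mov.u64 %rd3, local2[0];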
@@ -215,7 +215,7 @@ class PTXRenderer(Renderer):
           [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
         r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
       prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
-                       Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
+                       Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
                        Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
       if prefix: r[u] = ssa(prefix, u, dtype)
@@ -1,6 +1,6 @@
 from typing import Any
 from dataclasses import dataclass, field
-from tinygrad.dtype import dtypes, PtrDType, AddrSpace
+from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, RANGEIFY
 from tinygrad.schedule.multi import multi_pm
@@ -332,17 +332,15 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
 def bufferize_to_store(x:UOp):
   rngs = x.src[1:]
   shape = tuple([int(r.vmax+1) for r in rngs])
-  sdtype = x.dtype.ptr(size=prod(shape), addrspace=AddrSpace.GLOBAL if not isinstance(x.arg, AddrSpace) else x.arg)
-  assert prod(shape) > 0, f"no zero sized buffers {shape}"
+  size = prod(shape)
+  assert size > 0, f"no zero sized buffers {shape}"
+  sdtype = x.dtype.ptr(size=size, addrspace=AddrSpace.GLOBAL if not isinstance(x.arg, tuple) else x.arg[0])
   if x.src[0].op is Ops.ASSIGN:
     assign_target, assign_src = x.src[0].src
     assert assign_target.op is Ops.INDEX
     return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
-  if sdtype.addrspace == AddrSpace.GLOBAL:
-    buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
-  else:
-    # TODO: how to dedup this
-    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=UOp.unique().arg)
+  if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg, size, x.dtype)
+  else: buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
   return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)

 pm_add_buffers = pm_mops+PatternMatcher([
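BUFFERIZE's arg changed shape here: locals now carry a tuple (addrspace, id) instead of a bare AddrSpace, so the DEFINE_LOCAL gets a stable id derived from the reduce range rather than a fresh unique. A sketch of the dispatch with stand-in types (not tinygrad's UOp constructors):

from enum import Enum, auto
class AddrSpace(Enum):
  GLOBAL = auto()
  LOCAL = auto()

def pick_buffer(arg):
  # global bufferize: arg is a buffer number; local: arg is (AddrSpace, stable_id)
  if not isinstance(arg, tuple): return ("new_buffer", arg)
  return ("define_local", arg[1])        # id comes from the reduce range now

print(pick_buffer(7))                       # ('new_buffer', 7)
print(pick_buffer((AddrSpace.LOCAL, 102)))  # ('define_local', 102)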
@@ -393,11 +391,15 @@ rangeify_codegen = PatternMatcher([
   # add loads to non ptr indexes
   # TODO: this can be moved into codegen?
   (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
-   lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None),
+   lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
+
   # TODO: this can be moved into codegen
   (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
    lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
+
+  # TODO: hack for group for reduce
+  (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
+   lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
 ])

 def split_store(x:UOp):
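The "hack for group for reduce" rule re-nests the gate: IF(gate, LOAD(src, barrier)) becomes LOAD(src, IF(gate, barrier)), so the IF guards the barrier dependency rather than the loaded value. A tree sketch of the same rewrite on plain tuples:

# nodes are (op, *srcs); rewrite IF(gate, LOAD(src, barrier))
# into LOAD(src, IF(gate, barrier)) so the gate wraps only the barrier
def rewrite(node):
  if node[0] == "IF" and node[2][0] == "LOAD":
    gate, (_, src, barrier) = node[1], node[2]
    return ("LOAD", src, ("IF", gate, barrier))
  return node

tree = ("IF", "gate0", ("LOAD", "lbuf", "BARRIER"))
print(rewrite(tree))  # ('LOAD', 'lbuf', ('IF', 'gate0', 'BARRIER'))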