Files
tinygrad/tinygrad/codegen/linearizer.py
chenyu 7d049fc20c move getting 0 and min value of a dtype to dtype.py (#5328)
cleanup getting base case for reduce ops
[run_process_replay]
2024-07-08 10:51:56 -04:00

526 lines
32 KiB
Python

from __future__ import annotations
from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
import itertools, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.renderer import Program
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
  """ Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.

  * If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
  * If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
  * If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert.
    A size axis created by a split (and any dim it shifts right) is itself checked against its own max size and split
    again if needed, so a fitted size never silently exceeds its max_sizes entry.

  Keyword arguments:
  prefix -- the prefix to use for the size Variable names.
  off -- the starting index for the size Variable names.
  dims -- the global or local dims of the full shape.
  max_sizes -- the maximum values for each size in (x, y, z) order; None means effectively unbounded (one axis per dim).
  reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).

  Returns (idxs, loop_idxs, sizes):
  idxs -- one index expression per entry of dims, rebuilt from the size variables.
  loop_idxs -- the size Variables, excluding any that are constant NumNodes.
  sizes -- the extent of each size axis padded with 1s to len(max_sizes), or None when max_sizes is empty.
  """
  # check the edge cases on max_sizes
  if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
  assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
  if len(max_sizes) == 0: return [], [], None

  # initialize the map of dims to size with a single dim in each size axis
  # each entry is (dim_idx, dim, dim_max); dim_max is the bound used for fitting (dim.max+1 for symbolic dims)
  # TODO: support sint properly
  size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]

  # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
  # TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal performance
  if reverse_dims: size_dims = size_dims[::-1]

  # ensure that the initial dims fit the valid size axes. The loop bound is re-read every iteration on purpose:
  # splitting a dim inserts a new size axis, and that axis must also be checked against its own max size.
  size_idx = 0
  while size_idx < min(len(max_sizes), len(size_dims)):
    dim_idx, dim, dim_max = size_dims[size_idx][0]
    if dim_max > (max_sz:=max_sizes[size_idx]):
      # the dim is too large: split it into dim//factor (this axis) and factor (next axis), if possible
      assert isinstance(dim, int), "variable shape too large for size"
      for factor in range(2, int(dim**0.5)+1):
        if dim % factor == 0 and dim // factor <= max_sz:
          size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
          break
      assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
    size_idx += 1

  # compress the extra dims, collapsing them onto the left-most valid size axis
  cur_size_idx = 0
  while len(size_dims) > len(max_sizes):
    if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
      size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
    elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
    else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")

  # construct the final dim idx variables from the portions of the size variables
  sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
  size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
  for size_idx, size_var in enumerate(size_vars):
    # peel each dim packed into this size axis off the size variable; a dim that was split across several axes
    # accumulates each piece weighted by the extent already covered (idxs[dim_idx].max+1)
    for dim_idx, dim, _ in size_dims[size_idx]:
      idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
      size_var //= dim

  # pad the final sizes array to the proper length if necessary
  return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
def expand_idx(node:Node) -> Union[Variable, NumNode]:
  """Return the first Variable of `node` whose name starts with "_uidx" (an upcast/expand index), else NumNode(0)."""
  for candidate in node.vars():
    if candidate.expr.startswith("_uidx"): return candidate
  return NumNode(0)
def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
  """Return the expand variable of each node, in order.

  Only the first node that uses a given expand variable keeps it; later repeats are replaced with NumNode(0)
  so that each expand variable is enumerated once.
  """
  already_seen: List[Union[Variable, NumNode]] = []
  result: List[Union[Variable, NumNode]] = []
  for node in nodes:
    var = expand_idx(node)
    result.append(NumNode(0) if var in already_seen else var)
    already_seen.append(var)
  return tuple(result)
def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
  """Yield every integer assignment of `idxs` (each ranging over [v.min, v.max]).

  The FIRST entry varies fastest: itertools.product varies its last argument fastest, so the ranges are fed
  in reverse and each produced tuple is flipped back.
  """
  ranges_last_first = [list(range(v.min, v.max + 1)) for v in reversed(idxs)]
  for combo in itertools.product(*ranges_last_first):
    yield tuple(reversed(combo))
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
  """Convert a flat element index into (x, y) image coordinates, returned together with the unchanged `valid`.

  Uses 4 elements per pixel and base_shape[1] pixels per row.
  NOTE(review): presumably base_shape is (height, width, 4) of the image dtype -- confirm against ImageDType.
  """
  row_width = base_shape[1]
  idx = (idxy // 4) % row_width
  idy = idxy // (4 * row_width)
  # TODO: bring back the valid removal logic (correct!)
  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
  return (idx, idy), valid
# Enumerate a Node over every value of its expand variables (from min to max), substituting constants.
# The earliest variable in `idxs` changes fastest (same ordering as iter_idxs).
@functools.lru_cache(maxsize=None)
def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=None) -> List[Node]:
  """Return one substituted copy of `node` per assignment of `idxs` (default: the node's own expand variable).

  NumNode entries of `idxs` still contribute to the enumeration but are not substituted. Results are memoized,
  so callers share the returned list.
  """
  expand_vars = (expand_idx(node),) if idxs is None else idxs
  expanded: List[Node] = []
  for assignment in iter_idxs(expand_vars):
    substitution = {var: NumNode(val) for var, val in zip(expand_vars, assignment) if isinstance(var, Variable)}
    expanded.append(node.substitute(substitution))
  return expanded
def variable_to_uop(x, ctx=None) -> UOp:
  """Lower a plain int (as a dtypes.int constant) or a symbolic Node (via render_ops with `ctx`) into a UOp."""
  return UOp.const(dtypes.int, x) if isinstance(x, int) else x.render(render_ops, ctx)
# Dispatch table passed to Node.render: each entry lowers one symbolic Node class into a UOp.
# `ctx` maps Variable names to already-built UOps; a Variable missing from it becomes a DEFINE_VAR.
render_ops: Dict[Type, Callable[..., UOp]] = {
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
  Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx) * variable_to_uop(self.b, ctx),
  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx) // variable_to_uop(self.b, ctx),
  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx) % variable_to_uop(self.b, ctx),
  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  # sums and conjunctions fold left: the first term is rendered with `ops`, the rest are combined one by one
  SumNode: lambda self, ops, ctx: functools.reduce(
    lambda total, term: total + variable_to_uop(term, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)),
  AndNode: lambda self, ops, ctx: functools.reduce(
    lambda total, term: total * variable_to_uop(term, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)),
}
class Linearizer(Kernel):
  """Lowers a Kernel's LazyOp AST into a UOpGraph (`self.uops`) and, via `to_program`, into a rendered Program.

  `linearize` drives the process: it defines the buffer and index UOps, renders one block per reduceop and then the
  final store block, and at the end restores the `sts`/`group_for_reduces`/`upcasted` values it saved beforehand.
  """
  def get_reduce_acc(self, reduceop:LazyOp):
    """Return the initial accumulator value for `reduceop`: 0 for SUM, the dtype's minimum for MAX.

    Any other reduce op falls through and returns None.
    """
    if reduceop.op is ReduceOps.SUM: return dtypes.as_const(0, reduceop.dtype)
    if reduceop.op is ReduceOps.MAX: return dtypes.min(reduceop.dtype)

  # NOTE: once images are loaded, we uop them as their base float
  def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt

  def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Tuple[UOp, ...]=(), loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
    """Build (or reuse from `self.load_cache`) the UOps that read buffer `i` at `idxs`.

    Returns one UOp per assignment of the expand variables of `idxs`, in `iter_idxs` order.
    - With `acc`, DEFINE_ACC UOps are emitted instead of loads, initialised with `get_reduce_acc(acc)` and
      sourced on `loop_ctx`.
    - ConstBuffers become constants, wrapped in a WHERE on the valid when the valid is neither always 0 nor always 1.
    - Image buffers load a vec(4); when a scalar is wanted, the lane is selected with GEP/CMPLT/WHERE.
    - `barrier` UOps are appended to the sources of each LOAD.
    - When `get_float4_upcast_dim` reports exactly one dim whose expansion has 2 or 4 entries and the index is
      aligned, a single vector UOp is made and each returned entry is a GEP into it.
    """
    buf = self.bufs[i]
    localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
    const = buf.val if isinstance(buf, ConstBuffer) else None

    expand_vars = expand_idxs(idxs)

    dim, amt = None, 1
    # float 4 grouping
    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
      dim, amt = upcast_dim[0], len(float4_expand)
      g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
      # do not use float4 if idx is not aligned
      if g_idx != (g_idx//amt*amt): dim, amt = None, 1
    if dim is None:
      g_idx, g_valid = self.sts[i].expr_idxs(idxs)
    # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)

    if amt > 1: localtype = localtype.vec(amt)
    e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment

    ret = []
    invalid_value = 0
    acc_count = 0
    for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
      # a never-valid access reads the invalid value at index 0 instead of touching memory
      this_const, idx = (invalid_value, NumNode(0)) if valid.max == 0 else (const, idx)
      valid_uop = UOp.const(dtypes.bool, valid.b) if valid.min == valid.max else valid.render(render_ops, self.loop_uops)
      # the cache key covers the reduceop (for accs), the dtype, the source (const value or buffer) and the rendered idx/valid
      key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
      if key not in self.load_cache:
        if acc is not None:
          self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
          acc_count += 1
        elif this_const is not None:
          self.load_cache[key] = UOp.const(localtype, this_const)
          if valid.min == 0 and valid.max == 1:
            self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_uop, self.load_cache[key], UOp.const(localtype, invalid_value))
        elif isinstance(buf.dtype, ImageDType):
          buf_uop = self.buf_uops[i]
          assert buf_uop is not None, f"buffer {i} wasn't UOped"
          image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
          rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
          valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
          self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
          if localtype == localtype.scalar():
            # scalar read from a vec(4) texel: build a WHERE chain that picks lane idx%4
            idx_small = idx%4
            res = idx_small.render(render_ops, self.loop_uops)
            out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
            for ix in range(idx_small.max, idx_small.min, -1):
              rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
              sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
              out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
            self.load_cache[key] = out
        else:
          buf_uop = self.buf_uops[i]
          assert buf_uop is not None, f"buffer {i} wasn't UOped"
          rendered_idx = idx.render(render_ops, self.loop_uops)
          valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
          self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + barrier)
      ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
    return ret

  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
    """Emit the STORE UOps that write the values in `store` to buffer `i` at the expanded `idxs`.

    - Values are paired with index tuples in `expand_node` enumeration order.
    - When `get_float4_upcast_dim` reports one dim whose expansion has 2 or 4 entries, values that share all other
      indices are VECTORIZEd into one store; alignment is asserted.
    - `self.late_gate`, if set, is multiplied into each store's valid.
    - A store whose valid has min 1 is unconditional, and one whose valid has max 1 carries the rendered valid as
      a gate. Otherwise no store is emitted.
    """
    buf = self.bufs[i]
    buf_uop = self.buf_uops[i]
    assert buf_uop is not None, f"buffer {i} wasn't UOped"

    expand_vars = expand_idxs(idxs)
    _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
    store_offset = dict(zip(_idxs, store))

    # float4 grouping
    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
      grouped_store_offset = defaultdict(list)
      for k in store_offset:
        _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
        grouped_store_offset[_idx].append(store_offset[k])
      store_offset_new = {}
      for k,grouped in grouped_store_offset.items():
        amt = len(grouped)
        idx, valid = self.sts[i].expr_idxs(k)
        assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
        store_offset_new[k] = UOp(UOps.VECTORIZE, buf.dtype.vec(amt), tuple(grouped))
      store_offset = store_offset_new

    stores = []
    for _idx, var in store_offset.items():
      idx, valid = self.sts[i].expr_idxs(_idx)
      if isinstance(buf.dtype, ImageDType):
        image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
        rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
      else:
        rendered_idx = idx.render(render_ops, self.loop_uops)
      if self.late_gate is not None: valid *= self.late_gate
      # TODO: let UPat check this once it's fast
      if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
      elif valid.max == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
    return stores

  # render loop
  def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
    """Create one RANGE UOp per non-constant Variable in `xx`, register each in `self.loop_uops` by name and return them.

    Each RANGE spans [x.min, x.max+1) and carries arg (depth, position in xx, reduce).
    """
    new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
      UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
      UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
    self.loop_uops.update(new_loops)
    return tuple(new_loops.values())

  def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
    """Compute the index lists used to load each local-aliased buffer, grouped by the LazyOp that owns the alias.

    Indices along axes where the buffer's real stride is 0 are zeroed. With a tensor core, the local and upcast
    positions are replaced by thread-derived indices. In that case `local_idxs` and `upcast_idxs` are also
    overwritten IN PLACE with the accumulator mapping, so the caller's lists change.
    Returns a defaultdict: op -> [(buffer index, local buffer index, idxs), ...].
    """
    def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
      # split one "_uidx_tc" Variable into per-size thread indices, then build one index per alias list:
      # entries i > 0 pick local_idxs[i-1], other entries pick thread_idxs[-i-1]; an alias starting with 0 yields NumNode(0)
      replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
      for s in local_sizes:
        thread_idxs.append(thread_idx % s)
        thread_idx //= s
      for alias in aliases:
        full_var, full_var_sz = NumNode(0), 1
        if alias[0] != 0:
          for i in alias:
            next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
            full_var += next_var * full_var_sz
            full_var_sz *= next_var.max+1
        replace_idxs.append(full_var)
      return replace_idxs

    # compute local aliases
    alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
    for op, local_alias in self.local_alias.items():
      for i in local_alias:
        localbuf_idx = self.bufs.index(local_alias[i])
        buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
        if (tc:=self.tensor_core):
          min_alias_idx = min(local_alias.keys())
          replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
          for n in range(len(tc.threads)):
            buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
          for n in range(tc.num_upcasts()):
            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
        alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
    # modify idxs if necessary for TC
    if (tc:=self.tensor_core):
      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
      for n in range(len(tc.threads)):
        local_idxs[n] = replace_acc_idxs[n] # replace locals
      for n in range(len(replace_acc_idxs)-len(tc.threads)):
        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
    return alias_buf_idxs

  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
                      global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
                      alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
    """Render one reduceop: its reduce loop, its accumulators and its reduce body (a WMMA chain with tensor cores,
    otherwise the early AST).

    When `self.group_for_reduces` is set, it also emits the local-memory store, BARRIER and late reduce stage.
    - Mutates `accs[reduceop]` and `loaded_buffers`, and sets `self.late_gate` when the device has locals.
    - Upcasts the mid-reduce axes, which changes `self.upcasted` and `self.group_for_reduces`. Those counters are
      put back here only when this is not the last reduceop; `linearize` restores its saved copies at the end.
    - Returns (local_idxs, upcast_idxs) for the following block. The local list keeps the first `local_dims`
      entries followed by one NumNode(0) per remaining grouped axis.
    """
    # reduce loop
    loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)

    # define accumulator - modify idxs if necessary for TC
    out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
    accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)

    # store local aliases
    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]

    if (tc:=self.tensor_core):
      # run tensor cores AST
      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
      def upcast_strides(buf:int):
        # (stride, size) per upcast axis after the tensor-core ones; stride-0 axes get stride 0
        strides, next_ = [], 1
        for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
          strides.append((0 if stride == 0 else next_, sz))
          next_ *= 1 if stride == 0 else sz
        return strides
      upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
      # vectorize initial accs
      wmmas = [UOp(UOps.VECTORIZE, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
               for x in range(0, len(accs[reduceop]), wmma_sz[2])]
      for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
        ops = (UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
               UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
               wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
        # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
        wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
      # phi the last wmmas back to accs
      accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
                        for z, acc in enumerate(accs[reduceop])]
    else:
      assert not locals_to_store, "storing locals isn't supported here"

      # load earlybufs
      loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
                                                global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})

      # replace each accumulator with WHERE(valid, acc, 0) wherever the full buffer's valid is not constant
      # NOTE(review): the masked value is 0 for every reduce op; for MAX that is not the identity -- confirm intended
      def gate_acc(r, idxs): return [
        UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
        for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
      local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}

      # run early AST (with reduce)
      self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])

    # end the reduce loop
    self.load_cache.clear()

    # end the local loop, do the local reduce
    if self.group_for_reduces:
      fake_global_idxs = [x*0 for x in global_idxs]
      stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
      barrier = UOp(UOps.BARRIER, None, tuple(stores))
      if self.opts.has_local:
        fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
        fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
        # late stores are gated on this index expression being < 1 (applied in global_store)
        self.late_gate = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1)

      # create new late reduce local loops and replace local_idxs that have been used
      end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
      local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]

      # if any group_for_reduce items aren't reduces, upcast them here
      for j in self.upcast_in_mid_reduce_axes:
        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
        self.upcast()
        self.group_for_reduces -= 1
        local_idxs = local_idxs[:-1]
        end_local_idxs = end_local_idxs[:-1]
        # regenerate upcast_idxs
        upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]

      # NOTE: this structure is the same as the reduce op above

      # late reduce loop
      loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)

      # define late accumulator
      accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)

      # load localbufs
      loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))

      # there's no AST here (and there's no shape for the reduce LazyOp)
      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
                     accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])

      # end the late reduce loop
      self.load_cache.clear()

      if reduceop is not self.reduceops[-1]:
        # undo the counter changes made by the mid-reduce upcasts above, then pass the result on through local memory
        for j in self.upcast_in_mid_reduce_axes:
          self.upcasted -= 1
          self.group_for_reduces += 1
        assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
        fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
        stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
        barrier = UOp(UOps.BARRIER, None, tuple(stores))
        accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))
    return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs

  # class-level count of kernels named with each function name; used to suffix repeated names so they stay unique
  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
  def linearize(self) -> Linearizer:
    """Populate `self.uops` (plus `self.name`, `self.global_size`, `self.local_size`) for the current applied opts.

    Returns self. Does nothing when `applied_opts` equals `applied_opts_cache` from the previous run. The values of
    `sts`, `group_for_reduces` and `upcasted` that rendering temporarily changes are saved first and restored at the end.
    """
    # no new opts and we already ran? skip relinearizing
    if self.applied_opts == self.applied_opts_cache: return self

    # late alias the tensor core buffers
    if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
      alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
      for op, tc_bufs in self.bufs_for_tensor_core.items():
        for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)

    # save backups
    sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted

    # uops
    self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
    self.loop_uops: Dict[str, UOp] = {}
    self.late_gate = None

    # add global buffers
    for i,buf in enumerate(self.bufs):
      if isinstance(buf, MemBuffer):
        self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
                               buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                               (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
    # define local buffers
    for aliases in self.local_alias.values():
      for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
                                                                           (), (lb.name, self.sts[self.bufs.index(lb)].size))
    # add a local buffer for multistage reduce. # TODO: use local alias
    if self.group_for_reduces:
      for i in range(len(self.reduceops)):
        # TODO: the strides of this can be controlled
        self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
        temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
        self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
        self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))

    # kernel name (before late upcast)
    self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
                (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
                colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

    # name the function something unique
    Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
    suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else ""
    self.name = self.name+colored(suffix, 'BLACK')

    # define indexes
    gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
    global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
                                                                      self.opts.global_max, self.opts.has_local)
    local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
                                                                   self.opts.local_max if self.opts.has_local else (), False)
    upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]

    # render global and local as specials or a loop
    if self.opts.has_local:
      self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
      if not self.dont_use_locals:
        self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
    else:
      self.global_size, self.local_size = None, None
      self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)

    # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)

    # parse AST
    self.load_cache: Dict[str, UOp] = {}
    loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
    accs: Dict[LazyOp, List[UOp]] = {}

    # render reduceops by depth
    for reduceop in self.reduceops:
      self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
    stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)

    # only the final stores are needed to define the full UOps graph
    self.uops:UOpGraph = UOpGraph(flatten(stores))

    # maybe graph the uops
    if DEBUG >= 5: self.uops.print()
    if getenv("GRAPHUOPS"): self.uops.graph()

    # restore backups
    self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup

    # set cache and return
    self.applied_opts_cache = self.applied_opts[:]
    return self

  def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
                   alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
                   loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
    """Render one block of `outputs`. A block may contain at most one reduceop.

    - With a reduceop: render it through `render_reduceop` and return `[accs[r]]`. For the last reduceop,
      `local_idxs` and `upcast_idxs` are overwritten IN PLACE with the indices it returns.
    - Without a reduceop: load the late (non-early, non-local) buffers, parse each output's source, and return
      one list of STORE UOps per output buffer.
    """
    reduceops = dedup(x for x in outputs if x.op in ReduceOps)
    assert len(reduceops) <= 1, "max one reduceop per block"
    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
    fake_reduce_idxs = [x*0 for x in reduce_idxs]

    if len(reduceops) != 0:
      # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
      nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
                                          global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])

      # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
      # been rewritten with fake end_local_idxs.
      if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
      return [accs[r]]

    # load latebufs
    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
    # run late AST (without the store)
    store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
    return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]

  def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
    """Recursively lower LazyOp `x` into a list of UOps.

    - BufferOps return the already loaded UOps from `loaded_buffers`.
    - A ReduceOp with no `reduce_acc` returns its existing accumulators, selected by `offs` when `offs` is non-empty.
    - With `reduce_acc`, SUM/MAX fold each value into `reduce_acc[off]` IN PLACE. Every accumulator that changed
      is then replaced by a PHI of (initial acc, new value).
    - `cache` memoizes results per LazyOp within one call tree.
    NOTE(review): the CAST/BITCAST recursion does not forward `cache`, and an empty `offs` list behaves like None.
    """
    if cache is None: cache = {}
    if x in cache: return cache[x]
    if x.op in BufferOps: return loaded_buffers[x.arg]
    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
      return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
                  self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
    if x.op in ReduceOps and reduce_acc is None:
      return [accs[x][i] for i in offs] if offs else accs[x]

    values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
    if x.op in ops:
      assert reduce_acc is not None
      ret: List[UOp] = []
      acc, input_acc = reduce_acc, reduce_acc[:]
      for val, off in zip(zip(*values), cast(List[int], offs)):
        acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
        ret.append(acc[off])
      for off in range(len(acc)):
        if input_acc[off] != acc[off]:
          acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
    else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
    cache[x] = ret
    return ret

  def to_program(self) -> Program:
    """Linearize (if needed), render source code with `self.opts` and package it as a Program.

    The FLOP and memory estimates are the minimum of the LazyOp-level estimate and the UOp-level count multiplied
    by prod(global_size) * prod(local_size). When RUN_PROCESS_REPLAY is set, the inputs and rendered source are
    also written to the "process_replay" disk cache.
    """
    self.linearize()
    info = get_lazyop_info(self.ast[0])
    src = self.opts.render(name:=to_function_name(self.name), self.uops)
    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
    ops, mem = self.uops.flops_mem()
    run_count = prod((self.global_size or []) + (self.local_size or []))
    # NOTE: we use min here to ignore the indexing FLOPS
    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
                   self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))