diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 1131fee81e..f0d1d2be28 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,11 +1,13 @@ from __future__ import annotations -import itertools -from typing import Optional, List, Tuple, cast, Dict, Union +import itertools, functools +from dataclasses import replace +from collections import defaultdict +from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore -from tinygrad.dtype import dtypes, ImageDType, DType -from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction +from tinygrad.dtype import dtypes, ImageDType +from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction, to_function_name from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import sint from tinygrad.shape.view import strides_for_shape @@ -46,14 +48,6 @@ class TensorCoreOptions: elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False self.axes, self.axes_exist = tuple(axes), tuple(axes_exist) -@dataclass(frozen=True) -class LocalBuffer: - name: str - size: int - dtype: DType = dtypes.float32 - realized: None = None - def __str__(self): return f"localbuffer<{self.name}[{self.size}]>" - class Kernel: def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None): self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer @@ -69,14 +63,14 @@ class Kernel: self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast]) loadops = [BufferOps.LOAD, BufferOps.CONST] - self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops]) + self.bufs: List[Union[MemBuffer, ConstBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops]) # get earlybufs, before any reduceops self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps] self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 # create new shapetrackers inside this kernel, we will permute them - self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)] + self.sts: List[ShapeTracker] = [x.st for x in self.bufs] # move all reduce axes to the end reduce = list(enumerate(zip(self.full_shape, self.output_shape))) @@ -109,7 +103,7 @@ class Kernel: # things downstream of the AST ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \ - self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index + self.reduceops, self.outbufs, self.vars, self.bufs, self.earlybufs, self.full_buf_index ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam # parameters for optimizations @@ -620,3 +614,85 @@ class Kernel: will_delete_shape = local_sz == self.full_shape[axis] self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) if will_delete_shape: deleted_shape += 1 + + # **** kernel outputs **** + + kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) + @functools.cached_property + def name(self) -> str: + # kernel name (before late upcast) + name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \ + (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \ + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) + + # name the function something unique + Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1 + suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else "" + return name+colored(suffix, 'BLACK') + + def get_optimized_ast(self) -> Tuple[LazyOp, ...]: + # set the shapetrackers to the optimized ones, fixup reduceop + # transformed to the final LazyOp + @functools.lru_cache(None) + def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp: + if op.op in BufferOps: + idx = self.bufs.index(op.arg) + arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx])) + elif op.op in ReduceOps: + arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len) if self.full_shape[i] != self.sts[0].shape[i]) + if op in self.bufs_for_tensor_core and (tc := self.tensor_core): + rsrc = op.src[0] + if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0] + assert rsrc.op is BinaryOps.MUL + + def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1): + wd = self.global_dims + tcd = self.shape_len-self.upcasted + assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong" + assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong" + new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd + permaxis = list(range(wd)) + for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd)) + permaxis += list(range(wd+len(warp_dims), tcd)) + for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd)) + permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims))) + return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape) + + if self.opts.device == "AMD": + reduce_axes = [self.shape_len-self.upcasted] + upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1) + fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0))) + fix_st2 = None + elif self.opts.device == "METAL": + reduce_axes = [self.shape_len-self.upcasted] + upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1) + fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2))) + fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3))) + elif self.opts.device in {"CUDA", "NV"}: + reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1] + upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2) + # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float + fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2), + ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5))) + fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2), + ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2))) + else: + raise RuntimeError("unsupported device for tensor cores") + + assert apply_to_st is None, "double tensor core? not supported" + wmma_sz = [prod(l) for l in tc.thread_local_sizes] + wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), self.opts.device, upcast_axis, tuple(reduce_axes)) + ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg) + new_reduce_axes = tuple(i for i in arg if i not in reduce_axes) + return LazyOp(op.op, (ret,), new_reduce_axes) if len(new_reduce_axes) else ret + if self.group_for_reduces: + start = LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg) + sts = ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])) # noqa: E501 + local_buffer = MemBuffer(-1, start.dtype, sts) + local_store = LazyOp(BufferOps.STORE, (start,), local_buffer) + local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer) + return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))) + else: + arg = op.arg + return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg) + return tuple(fixup_ast(x) for x in self.ast) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index dfddf96dbd..9b5269c27f 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -1,15 +1,13 @@ from __future__ import annotations -from typing import List, Tuple, cast, Optional, Any, Dict, Final, DefaultDict +from typing import List, Tuple, cast, Optional, Any, Dict import functools -from dataclasses import replace -from collections import defaultdict -from tinygrad.codegen.kernel import LocalBuffer, Kernel +from tinygrad.codegen.kernel import Kernel from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType -from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MemBuffer, BinaryOps, get_lazyop_info +from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, get_lazyop_info from tinygrad.codegen.uops import UOp, UOpGraph, UOps from tinygrad.renderer import Program -from tinygrad.helpers import to_function_name, colored, DEBUG, getenv, prod +from tinygrad.helpers import to_function_name, DEBUG, getenv, prod # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker def variable_to_uop(x, ctx=None) -> UOp: @@ -59,10 +57,8 @@ class Lowerer(Kernel): if x.op is BufferOps.CONST: dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0)) - if isinstance(self.bufs[x.arg.idx], LocalBuffer): - # TODO: this should come from somewhere else - lb = self.bufs[x.arg.idx] - buf = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype), (), (lb.name, lb.size)) + if x.arg.idx == -1: + buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size)) else: buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), (x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs))) @@ -85,106 +81,16 @@ class Lowerer(Kernel): UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)), UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2]) - src = (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg) - return UOp(UOps.REDUCE, dtype, src, x.op) + return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op) return UOp.alu(x.op, *in_uops) - kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) def linearize(self) -> Lowerer: - sts_backup, bufs_backup = self.sts, self.bufs - - self.uop_cache: Dict[LazyOp, UOp] = {} - - # kernel name (before late upcast) - self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \ - (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \ - colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) - if DEBUG >= 4: print(self.name) - - # name the function something unique - Lowerer.kernel_cnt[(function_name := to_function_name(self.name))] += 1 - suffix = f"{'n'+str(Lowerer.kernel_cnt[function_name]-1)}" if Lowerer.kernel_cnt[function_name] > 1 else "" - self.name = self.name+colored(suffix, 'BLACK') - - self.idxs = [] - # add a local buffer for multistage reduce. - if self.group_for_reduces: - for i in range(len(self.reduceops)): - # TODO: the strides of this can be controlled - self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501 - temp_dtype = cast(LazyOp, self.reduceop).dtype - self.bufs.append(LocalBuffer(f"temp{i if len(self.reduceops) > 1 else ''}", self.sts[-1].size, - temp_dtype.base if isinstance(temp_dtype, ImageDType) else temp_dtype)) - - # set the shapetrackers to the optimized ones, fixup reduceop - # transformed to the final LazyOp - @functools.lru_cache(None) - def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp: - if op.op in BufferOps: - idx = self.bufs.index(op.arg) - arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx])) - elif op.op in ReduceOps: - arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len) if self.full_shape[i] != self.sts[0].shape[i]) - if op in self.bufs_for_tensor_core and (tc := self.tensor_core): - rsrc = op.src[0] - if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0] - assert rsrc.op is BinaryOps.MUL - - def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1): - wd = self.global_dims - tcd = self.shape_len-self.upcasted - assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong" - assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong" - new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd - permaxis = list(range(wd)) - for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd)) - permaxis += list(range(wd+len(warp_dims), tcd)) - for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd)) - permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims))) - return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape) - - if self.opts.device == "AMD": - reduce_axes = [self.shape_len-self.upcasted] - upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1) - fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0))) - fix_st2 = None - elif self.opts.device == "METAL": - reduce_axes = [self.shape_len-self.upcasted] - upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1) - fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2))) - fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3))) - elif self.opts.device in {"CUDA", "NV"}: - reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1] - upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2) - # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float - fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2), - ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5))) - fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2), - ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2))) - else: - raise RuntimeError("unsupported device for tensor cores") - - assert apply_to_st is None, "double tensor core? not supported" - wmma_sz = [prod(l) for l in tc.thread_local_sizes] - wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), self.opts.device, upcast_axis, tuple(reduce_axes)) - ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg) - new_reduce_axes = tuple(i for i in arg if i not in reduce_axes) - return LazyOp(op.op, (ret,), new_reduce_axes) if len(new_reduce_axes) else ret - if self.group_for_reduces: - start = LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg) - local_buffer = MemBuffer(-1, start.dtype, self.sts[-1]) - local_store = LazyOp(BufferOps.STORE, (start,), local_buffer) - local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer) - return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))) - else: - arg = op.arg - return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg) - modified_ast = tuple(fixup_ast(x) for x in self.ast) - + modified_ast = self.get_optimized_ast() if DEBUG >= 4: from tinygrad.engine.graph import print_tree for mast in modified_ast: print_tree(mast) + self.idxs = [] if self.opts.has_local: # define indexes global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0) @@ -218,10 +124,9 @@ class Lowerer(Kernel): for a in range(self.first_reduce, self.first_reduce+self.group_for_reduces): self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(self.full_shape[a])), (1000+a, True)) + self.uop_cache: Dict[LazyOp, UOp] = {} self.uops:UOpGraph = UOpGraph([self.to_uop(x) for x in modified_ast], self.opts) - self.sts, self.bufs = sts_backup, bufs_backup - # maybe graph the uops if DEBUG >= 5: self.uops.print() if getenv("GRAPHUOPS"):