move ast to the kernel (#5362)

* move ast to the kernel

* locals aren't image

* comment
Author: George Hotz
Date:   2024-07-10 16:22:26 -07:00
Committed by: GitHub
Parent: 245d83a392
Commit: 7a014d5435
2 changed files with 101 additions and 120 deletions
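
In short: kernel naming and the post-optimization AST rewrite move out of Lowerer.linearize() into the Kernel base class, as the cached property Kernel.name and the new Kernel.get_optimized_ast(); the LocalBuffer dataclass is deleted, and grouped-reduce scratch memory becomes an ordinary MemBuffer with idx -1. A minimal sketch of the resulting flow (the driver function and the lowerer's module path are assumptions for illustration, not part of the diff):

    # sketch only: assumes Lowerer lives in tinygrad.codegen.lowerer
    from tinygrad.device import Device
    from tinygrad.codegen.lowerer import Lowerer

    def lower_ast(*ast):  # one or more LazyOp trees, as Kernel.__init__ expects
      k = Lowerer(*ast, opts=Device[Device.DEFAULT].renderer)
      print(k.name)                 # functools.cached_property, now defined on Kernel
      mast = k.get_optimized_ast()  # Tuple[LazyOp, ...] with fixed-up shapetrackers
      return k.linearize()          # linearize() itself now calls get_optimized_ast()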

View File

@@ -1,11 +1,13 @@
 from __future__ import annotations
-import itertools
-from typing import Optional, List, Tuple, cast, Dict, Union
+import itertools, functools
+from dataclasses import replace
+from collections import defaultdict
+from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore
-from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction, to_function_name
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import strides_for_shape
@@ -46,14 +48,6 @@ class TensorCoreOptions:
       elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
     self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
 
-@dataclass(frozen=True)
-class LocalBuffer:
-  name: str
-  size: int
-  dtype: DType = dtypes.float32
-  realized: None = None
-  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
-
 class Kernel:
   def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
@@ -69,14 +63,14 @@ class Kernel:
     self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
     loadops = [BufferOps.LOAD, BufferOps.CONST]
-    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
+    self.bufs: List[Union[MemBuffer, ConstBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
 
     # get earlybufs, before any reduceops
     self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
+    self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
 
     # move all reduce axes to the end
     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -109,7 +103,7 @@
     # things downstream of the AST
     ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
-      self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
+      self.reduceops, self.outbufs, self.vars, self.bufs, self.earlybufs, self.full_buf_index
     ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
 
     # parameters for optimizations
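
With LocalBuffer gone, self.bufs can only hold MemBuffer or ConstBuffer, so the Union narrows and the cast() around self.bufs becomes unnecessary; copy() likewise stops filtering locals out of ret.bufs. The replacement convention, used by the new get_optimized_ast() below and matched in the Lowerer in the second file, is a sentinel buffer index (a schematic comparison, not lines from the diff):

    # before: local scratch was its own dataclass, dtype defaulting to float32
    #   LocalBuffer(name="temp", size=st.size)
    # after: an ordinary MemBuffer; idx == -1 means "local scratch, not a kernel argument"
    #   MemBuffer(-1, reduce_dtype, st)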
@@ -620,3 +614,85 @@
       will_delete_shape = local_sz == self.full_shape[axis]
       self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
       if will_delete_shape: deleted_shape += 1
+
+  # **** kernel outputs ****
+
+  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+  @functools.cached_property
+  def name(self) -> str:
+    # kernel name (before late upcast)
+    name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
+      (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
+      colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+
+    # name the function something unique
+    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+    suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
+    return name+colored(suffix, 'BLACK')
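
kernel_cnt is keyed on the sanitized function name, so distinct kernels whose colored names collide after to_function_name get an "n" suffix. Roughly, with hypothetical names:

    # 1st kernel sanitizing to "r_256_16" -> "r_256_16"
    # 2nd kernel sanitizing to "r_256_16" -> "r_256_16n1"
    # 3rd                                 -> "r_256_16n2"
    # (the suffix is colored 'BLACK', so it stays visually subdued in DEBUG output)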
+
+  def get_optimized_ast(self) -> Tuple[LazyOp, ...]:
+    # set the shapetrackers to the optimized ones, fixup reduceop
+    # transformed to the final LazyOp
+    @functools.lru_cache(None)
+    def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
+      if op.op in BufferOps:
+        idx = self.bufs.index(op.arg)
+        arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
+      elif op.op in ReduceOps:
+        arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len) if self.full_shape[i] != self.sts[0].shape[i])
+        if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
+          rsrc = op.src[0]
+          if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0]
+          assert rsrc.op is BinaryOps.MUL
+          def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
+            wd = self.global_dims
+            tcd = self.shape_len-self.upcasted
+            assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong"
+            assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong"
+            new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):]  # expand the tcd
+            permaxis = list(range(wd))
+            for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd))
+            permaxis += list(range(wd+len(warp_dims), tcd))
+            for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd))
+            permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims)))
+            return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape)
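
fix_st leaves a ShapeTracker's shape unchanged but reorders its traversal into the tensor-core layout: the tensor-core dims (tcd_dims) are split into tcd_expand, a permutation is built where each (x, y) pair in pattern_1/pattern_2 picks axis y from the warp block (x == 0) or from the expanded tcd block (x == 1), and the result is reshaped back. A toy rehearsal of just the permaxis construction, with made-up dims that are not a real tensor-core layout:

    wd, tcd = 1, 2                        # offsets of the warp block and the tcd block
    warp_dims, tcd_dims, tcd_expand = (2,), (4,), (2, 2)
    pattern_1, pattern_2 = ((1, 0),), ((0, 0), (1, 1))
    permaxis = list(range(wd))                                           # globals stay put: [0]
    for x, y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd))  # -> [0, 2]
    permaxis += list(range(wd + len(warp_dims), tcd))                    # empty in this toy case
    for x, y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd))  # -> [0, 2, 1, 3]
    assert sorted(permaxis) == list(range(4))  # a valid permutation of the expanded shape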
if self.opts.device == "AMD":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
fix_st2 = None
elif self.opts.device == "METAL":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
elif self.opts.device in {"CUDA", "NV"}:
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2)))
else:
raise RuntimeError("unsupported device for tensor cores")
assert apply_to_st is None, "double tensor core? not supported"
wmma_sz = [prod(l) for l in tc.thread_local_sizes]
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), self.opts.device, upcast_axis, tuple(reduce_axes))
ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
new_reduce_axes = tuple(i for i in arg if i not in reduce_axes)
return LazyOp(op.op, (ret,), new_reduce_axes) if len(new_reduce_axes) else ret
if self.group_for_reduces:
start = LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
sts = ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])) # noqa: E501
local_buffer = MemBuffer(-1, start.dtype, sts)
local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
else:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
return tuple(fixup_ast(x) for x in self.ast)
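
When group_for_reduces is set, fixup_ast splits a reduce into two stages through the idx -1 scratch buffer: the first-stage reduce (over the late axes in arg) is STOREd to the local MemBuffer, immediately LOADed back, and reduced again over only the grouped axes. Schematically (hand-drawn from the code above, not its literal output):

    # stage 1: start      = SUM(src, arg=late reduce axes)
    # bounce:  local_load = LOAD(STORE(start, MemBuffer(-1, dtype, sts)), MemBuffer(-1, dtype, sts))
    # stage 2: result     = SUM(local_load, arg=range(first_reduce, first_reduce+group_for_reduces))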

View File

@@ -1,15 +1,13 @@
 from __future__ import annotations
-from typing import List, Tuple, cast, Optional, Any, Dict, Final, DefaultDict
+from typing import List, Tuple, cast, Optional, Any, Dict
 import functools
-from dataclasses import replace
-from collections import defaultdict
-from tinygrad.codegen.kernel import LocalBuffer, Kernel
+from tinygrad.codegen.kernel import Kernel
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
-from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MemBuffer, BinaryOps, get_lazyop_info
+from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, get_lazyop_info
 from tinygrad.codegen.uops import UOp, UOpGraph, UOps
 from tinygrad.renderer import Program
-from tinygrad.helpers import to_function_name, colored, DEBUG, getenv, prod
+from tinygrad.helpers import to_function_name, DEBUG, getenv, prod
 
 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker
 def variable_to_uop(x, ctx=None) -> UOp:
@@ -59,10 +57,8 @@ class Lowerer(Kernel):
       if x.op is BufferOps.CONST:
         dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
         return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
-      if isinstance(self.bufs[x.arg.idx], LocalBuffer):
-        # TODO: this should come from somewhere else
-        lb = self.bufs[x.arg.idx]
-        buf = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype), (), (lb.name, lb.size))
+      if x.arg.idx == -1:
+        buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
       else:
         buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
                   (x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
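
The isinstance check on LocalBuffer becomes a test on the sentinel index, and the scratch buffer's dtype and size now travel on the MemBuffer itself (its ShapeTracker supplies the size). The two arg shapes this branch can now emit, per the lines above:

    # DEFINE_LOCAL  arg: ("temp", x.arg.st.size)   -- grouped-reduce scratch, fixed name
    # DEFINE_GLOBAL arg: (x.arg.idx, is_output)    -- is_output = any(x.arg.idx == y.idx for y in self.outbufs)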
@@ -85,106 +81,16 @@ class Lowerer(Kernel):
           UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
           UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
         return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
-      src = (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg)
-      return UOp(UOps.REDUCE, dtype, src, x.op)
+      return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
     return UOp.alu(x.op, *in_uops)
 
-  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self) -> Lowerer:
-    sts_backup, bufs_backup = self.sts, self.bufs
-    self.uop_cache: Dict[LazyOp, UOp] = {}
-
-    # kernel name (before late upcast)
-    self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
-      (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
-      colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
-    if DEBUG >= 4: print(self.name)
-
-    # name the function something unique
-    Lowerer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
-    suffix = f"{'n'+str(Lowerer.kernel_cnt[function_name]-1)}" if Lowerer.kernel_cnt[function_name] > 1 else ""
-    self.name = self.name+colored(suffix, 'BLACK')
-
-    self.idxs = []
-
-    # add a local buffer for multistage reduce.
-    if self.group_for_reduces:
-      for i in range(len(self.reduceops)):
-        # TODO: the strides of this can be controlled
-        self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
-        temp_dtype = cast(LazyOp, self.reduceop).dtype
-        self.bufs.append(LocalBuffer(f"temp{i if len(self.reduceops) > 1 else ''}", self.sts[-1].size,
-                                     temp_dtype.base if isinstance(temp_dtype, ImageDType) else temp_dtype))
-
-    # set the shapetrackers to the optimized ones, fixup reduceop
-    # transformed to the final LazyOp
-    @functools.lru_cache(None)
-    def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
-      if op.op in BufferOps:
-        idx = self.bufs.index(op.arg)
-        arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
-      elif op.op in ReduceOps:
-        arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len) if self.full_shape[i] != self.sts[0].shape[i])
-        if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
-          rsrc = op.src[0]
-          if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0]
-          assert rsrc.op is BinaryOps.MUL
-          def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
-            wd = self.global_dims
-            tcd = self.shape_len-self.upcasted
-            assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong"
-            assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong"
-            new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):]  # expand the tcd
-            permaxis = list(range(wd))
-            for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd))
-            permaxis += list(range(wd+len(warp_dims), tcd))
-            for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd))
-            permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims)))
-            return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape)
-          if self.opts.device == "AMD":
-            reduce_axes = [self.shape_len-self.upcasted]
-            upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
-            fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
-            fix_st2 = None
-          elif self.opts.device == "METAL":
-            reduce_axes = [self.shape_len-self.upcasted]
-            upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
-            fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
-            fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
-          elif self.opts.device in {"CUDA", "NV"}:
-            reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
-            upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
-            # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
-            fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
-              ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
-            fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
-              ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2)))
-          else:
-            raise RuntimeError("unsupported device for tensor cores")
-          assert apply_to_st is None, "double tensor core? not supported"
-          wmma_sz = [prod(l) for l in tc.thread_local_sizes]
-          wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), self.opts.device, upcast_axis, tuple(reduce_axes))
-          ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
-          new_reduce_axes = tuple(i for i in arg if i not in reduce_axes)
-          return LazyOp(op.op, (ret,), new_reduce_axes) if len(new_reduce_axes) else ret
-        if self.group_for_reduces:
-          start = LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
-          local_buffer = MemBuffer(-1, start.dtype, self.sts[-1])
-          local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
-          local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
-          return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
-      else:
-        arg = op.arg
-      return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
-
-    modified_ast = tuple(fixup_ast(x) for x in self.ast)
+    modified_ast = self.get_optimized_ast()
 
     if DEBUG >= 4:
       from tinygrad.engine.graph import print_tree
       for mast in modified_ast: print_tree(mast)
 
+    self.idxs = []
     if self.opts.has_local:
       # define indexes
       global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
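
Since linearize() now consumes get_optimized_ast(), the DEBUG >= 4 branch prints exactly what that method returned, one tree per output. A hypothetical way to see it (any op that compiles a kernel works; this invocation is illustrative, not from the commit):

    # DEBUG=4 python3 -c "from tinygrad import Tensor; (Tensor.rand(16,16) @ Tensor.rand(16,16)).realize()"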
@@ -218,10 +124,9 @@
     for a in range(self.first_reduce, self.first_reduce+self.group_for_reduces):
       self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(self.full_shape[a])), (1000+a, True))
 
+    self.uop_cache: Dict[LazyOp, UOp] = {}
     self.uops:UOpGraph = UOpGraph([self.to_uop(x) for x in modified_ast], self.opts)
-    self.sts, self.bufs = sts_backup, bufs_backup
 
     # maybe graph the uops
     if DEBUG >= 5: self.uops.print()
     if getenv("GRAPHUOPS"):