mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
merge kernel and optimizer (#2200)
* merge kernel and optimizer
* linearize is reentrant
* move global/local size
* clean up linearizer copy
* remove unneeded lin copies
* stop linearizing twice
* oops, that should be None
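The "linearize is reentrant" point is easiest to see as code. A minimal sketch, assuming a tinygrad checkout at this commit; the imports and the schedule/Linearizer calls are taken from the test changes in the diff below, not verified against a live install:

    from tinygrad.tensor import Tensor
    from tinygrad.ops import LoadOps
    from tinygrad.codegen.linearizer import Linearizer

    # build a tiny kernel the same way the updated tests do
    si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
    lin = Linearizer(si.ast)
    lin.hand_coded_optimizations()

    # reentrant: a second linearize() hits applied_opts_cache and yields identical uops
    uops1 = lin.linearize().uops
    uops2 = lin.linearize().uops
    assert tuple(uops1) == tuple(uops2)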
@@ -59,9 +59,9 @@ if __name__ == "__main__":
# benchmark the programs
choices = []
for lin in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
choices.append((tm, gflops, lin))
choices.append((tm, gflops, lin.linearize()))

# print all kernels
if DEBUG >= 1: print(f" kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")

@@ -17,7 +17,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.codegen.kernel import Opt, OptOps

INNER = 256
class PolicyNet:

@@ -10,7 +10,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.codegen.kernel import Opt, OptOps

# more stuff
from tinygrad.codegen.linearizer import Linearizer

@@ -1,3 +1,4 @@
import random
from tqdm import tqdm
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.features.search import actions
@@ -17,11 +18,17 @@ def test_rebuild(lin):

if __name__ == "__main__":
ast_strs = load_worlds(False, False, False)
random.shuffle(ast_strs)
ast_strs = ast_strs[:2000]
for ast_str in tqdm(ast_strs):
lin = ast_str_to_lin(ast_str)
#if not lin.apply_tensor_cores():
lin.hand_coded_optimizations()
test_rebuild(lin)
# confirm linearize can be called twice
uops1 = lin.linearize().uops
uops2 = lin.linearize().uops
assert tuple(uops1) == tuple(uops2), f"uops mismatch {lin.colored_shape()}"

print(len(tactions), len(actions))
print(sorted(list(tactions)))

@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.codegen.kernel import Opt, OptOps

from extra.optimization.helpers import lin_to_feats, MAX_DIMS

@@ -1,9 +1,10 @@
import unittest
import numpy as np

from tinygrad.helpers import BEAM
from tinygrad.helpers import BEAM, Timing
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d

class TestBeamSearch(unittest.TestCase):
def setUp(self):
@@ -22,3 +23,12 @@ class TestBeamSearch(unittest.TestCase):
a.assign(a+1)
actual = a.numpy()
np.testing.assert_allclose(actual, desired)

def test_conv_beam(self):
c = Conv2d(3, 16, (3,3))
x = Tensor.rand(1,3,32,32)
with Timing():
c(x).realize()

if __name__ == '__main__':
unittest.main()

@@ -2,7 +2,7 @@ import numpy as np
import unittest, os

from tinygrad.codegen.kernel import tensor_cores
from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device
from tinygrad.tensor import Tensor

@@ -12,5 +12,5 @@ class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
rawbufs = [Device[Device.DEFAULT].buffer(si.out.st.size(), si.out.dtype)] + [Device[Device.DEFAULT].buffer(x.st.size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10, should_copy=False)
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')

@@ -1,13 +1,24 @@
from __future__ import annotations
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
from copy import deepcopy
import itertools
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen
from tinygrad.shape.shapetracker import ShapeTracker
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.shape.view import View, strides_for_shape
from dataclasses import dataclass
from enum import Enum, auto

class OptOps(Enum):
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value

@dataclass(frozen=True, order=True)
class Opt:
op: OptOps
axis: Optional[int] = None
amt: Optional[int] = None
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"

@dataclass(frozen=True)
class TensorCore:
@@ -22,7 +33,6 @@ class TensorCore:
arch: Optional[str] = None
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"

# TODO(TC): doesn't belong here!!!
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
@@ -66,16 +76,25 @@ class Kernel:
self.reduceop = reduceops[0] if reduceops else None

# create new shapetrackers inside this kernel, we will permute them
self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])

self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)
# extract things from the buffers
self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs))

# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

# parameters
# create the (permuted) shapetrackers
self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]

# move all reduce axes to the end
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
self.reshape_and_permute(None, permute)

# parameters for optimization
self.applied_opts: List[Opt] = []
self.group_for_reduce: List[int] = []
self.upcasted: int = 0
self.local_dims: int = 0
@@ -83,18 +102,39 @@ class Kernel:
self.tensor_core: Optional[TensorCore] = None
self.dont_use_locals: bool = False

self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
# group simplifies
self.simplify_ones()
self.simplify_merge_adjacent()

# cache
self.applied_opts_cache: Optional[List[Opt]] = None

def copy(self):
return deepcopy(self)
ret = type(self).__new__(type(self))

# base linearizer params
ret.opts, ret.ast = self.opts, self.ast

# things downstream of the AST
# NOTE: we copy bufs for local buffers and sts for optimizations
ret.info, ret.reduceop, ret.bufs, ret.mem_estimate, ret.earlybufs, ret.full_buf_index, ret.sts = \
self.info, self.reduceop, self.bufs[:], self.mem_estimate, self.earlybufs, self.full_buf_index, self.sts[:]

# parameters for optimizations
ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals

# uncached since linearize didn't run
ret.applied_opts_cache = None

return ret

@property
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]

def has_variable_shape(self) -> bool:
for b in self.bufs:
if not all_int(b.st.views[-1].shape): return True
if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
return False

def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
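For context on "clean up linearizer copy": the copy() above now rebuilds the object field by field, sharing the immutable AST-derived members and shallow-copying only the mutable lists, instead of deepcopy-ing the whole kernel. A hedged usage sketch of how a search loop might rely on that; the Opt values are illustrative and the `si` / `Linearizer` names reuse the sketch near the top of this page:

    from tinygrad.codegen.kernel import Opt, OptOps

    base = Linearizer(si.ast)                    # fresh kernel, as in the earlier sketch
    trial = base.copy()                          # cheap shallow copy, no deepcopy
    trial.apply_opt(Opt(OptOps.UPCAST, 0, 4))    # mutate only the copy (assumes axis 0 has size divisible by 4)
    assert base.applied_opts == []               # the original kernel is untouched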
@@ -170,3 +210,383 @@ class Kernel:
|
||||
print(prefix, f"{i:3d} {str(self.bufs[i]):47s}", st.views)
|
||||
print(self.colored_shape())
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
new_sts = []
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st = st.permute(tuple(axis))
|
||||
new_sts.append(st)
|
||||
self.sts = new_sts
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self):
|
||||
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
|
||||
self.upcasted += 1
|
||||
|
||||
# axis : the axis to pull from
|
||||
# amount : the amount to take
|
||||
# top : if you want to pull that amount from the top
|
||||
# insert_before : place to insert the new stuff
|
||||
def shift_to(self, axis, amount, top=False, insert_before=None):
|
||||
if insert_before is None: insert_before = self.shape_len
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_before: insert_before += 1
|
||||
self.reshape_and_permute(
|
||||
lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
|
||||
[i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
|
||||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
# TODO: this should be factored in to multi shape stride
|
||||
if self.shape_len == 0: return False
|
||||
all_ones = [s==1 for s in self.full_shape]
|
||||
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
|
||||
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
|
||||
return any(all_ones)
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if isinstance(self.bufs[0].dtype, ImageDType):
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[int, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# TODO: does this always preserve the reduce dimension, NO
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for j in range(len(shapes)):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != self.first_reduce
|
||||
for j in range(len(shapes)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
|
||||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** GPU simplifiers ********************
|
||||
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
|
||||
new_shape,dims = list(x), len(x)
|
||||
for i in range(dims):
|
||||
next_idx = (i + 1) % dims
|
||||
while new_shape[i] > max_size[i]:
|
||||
new_shape[i] = new_shape[i] // 2
|
||||
if (new_shape[next_idx] <= max_size[next_idx]):
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
else:
|
||||
next_idx = (next_idx + 1) % dims
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
return tuple(new_shape)
|
||||
|
||||
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
|
||||
# Check the global allocation limit, current the global_size will be flipped during codegen
|
||||
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
|
||||
global_dims = self.first_reduce-self.local_dims
|
||||
if global_dims > 0:
|
||||
if global_max:
|
||||
tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
|
||||
if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
|
||||
assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
|
||||
for i in range(global_dims-1):
|
||||
if self.full_shape[i] > global_max[i]:
|
||||
order = list(range(len(self.full_shape)))
|
||||
order[i], order[global_dims-1] = order[global_dims-1], order[i]
|
||||
self.reshape_and_permute(None, order)
|
||||
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
|
||||
|
||||
def alias_buffer(self, i, pattern):
|
||||
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
|
||||
|
||||
bst = 1
|
||||
real_strides = self.sts[i].real_strides()
|
||||
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
|
||||
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
|
||||
for j,p in enumerate(pattern):
|
||||
if priority == p and real_strides[j] != 0:
|
||||
stride[j] = bst
|
||||
bst *= shp[j]
|
||||
|
||||
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
|
||||
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
|
||||
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
|
||||
|
||||
# ******************** high level optimizers ********************
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None):
|
||||
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
|
||||
for tc in tensor_cores[self.opts.device]:
|
||||
if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
|
||||
if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
|
||||
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
|
||||
|
||||
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
|
||||
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
|
||||
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
|
||||
buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
|
||||
|
||||
if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
|
||||
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
|
||||
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
|
||||
s0_exists, s1_exists = True, True
|
||||
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
|
||||
def fix(needed, ax):
|
||||
nonlocal s0, s1, s0_exists, s1_exists
|
||||
if not needed: return
|
||||
if s0_exists and ax == s0:
|
||||
if s1_exists and s0 < s1: s1 -= 1
|
||||
s0_exists = False
|
||||
elif s1_exists and ax == s1:
|
||||
if s0_exists and s1 < s0: s0 -= 1
|
||||
s1_exists = False
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
|
||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
|
||||
if extra_opts is not None:
|
||||
for opt in extra_opts:
|
||||
self.apply_opt(opt)
|
||||
else:
|
||||
# hand-coded TC opts
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
if self.tensor_core and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
self.alias_buffer(buf0, alias_pattern)
|
||||
self.alias_buffer(buf1, alias_pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_opt(self, opt:Opt):
|
||||
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
|
||||
self.applied_opts.append(opt)
|
||||
if opt.axis is not None:
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
|
||||
else:
|
||||
axis = -1
|
||||
if opt.amt is not None:
|
||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||
else:
|
||||
amt = -1
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
assert not(self.tensor_core), "can't local with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.UNROLL: # purple
|
||||
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
|
||||
assert amt <= 32, "don't unroll more than 32"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
assert amt <= 8, "don't upcast more than 8"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCASTMID: # white
|
||||
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
assert axes[0] == axis, "wrong axis"
|
||||
assert amt == 4, "don't upcast mid anything but 4"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.NOLOCALS:
|
||||
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
|
||||
assert not self.dont_use_locals, "already not using locals"
|
||||
self.dont_use_locals = True
|
||||
return self.simplify_ones()
|
||||
|
||||
def required_optimizations(self, early_only=False):
|
||||
for buf_index,buf in enumerate(self.bufs):
|
||||
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
|
||||
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
|
||||
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
|
||||
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
|
||||
if unit_stride_axes_mul_4[0] < self.first_reduce:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
|
||||
|
||||
def hand_coded_optimizations(self):
|
||||
# if there's images in the earlybufs, we have to make an axis the 4 loading one
|
||||
self.required_optimizations(early_only=True)
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
|
||||
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
|
||||
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
|
||||
buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)
|
||||
buf0_strides = self.sts[buf0].real_strides()
|
||||
buf1_strides = self.sts[buf1].real_strides()
|
||||
def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
|
||||
if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
|
||||
for global_idx in range(self.global_dims):
|
||||
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
||||
if MV_THREADS_PER_ROW > 1:
|
||||
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1:
|
||||
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return
|
||||
|
||||
if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
|
||||
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
|
||||
self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
|
||||
# are we upcasting in mid reduce? (only for images)
|
||||
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
if self.sts[0].shape[axes[0]]%4 == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
|
||||
|
||||
# now do everything required
|
||||
self.required_optimizations()
|
||||
|
||||
# no more opt if we are grouping
|
||||
if self.group_for_reduce: return
|
||||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
|
||||
# to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
|
||||
# expression and run test/test_ops.py with IMAGE=2
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast: List[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in range(self.first_reduce):
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
# for now skip upcasting here if there is a symbolic axis
|
||||
if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
|
||||
prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
|
||||
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else:
|
||||
break
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
|
||||
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if self.full_unupcasted_shape[-1]%splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
|
||||
break
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
# TODO: this is breaking the tests
|
||||
for splits in [4]:
|
||||
if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
|
||||
|
||||
# **** local groups ****
|
||||
|
||||
if self.opts.has_local:
|
||||
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
|
||||
self.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
|
||||
to_local: List[Tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
|
||||
if local_sz is not None: to_local.append((axis, local_sz))
|
||||
deleted_shape = 0
|
||||
for axis, local_sz in sorted(to_local[:3]):
|
||||
axis = axis - deleted_shape
|
||||
will_delete_shape = local_sz == self.full_shape[axis]
|
||||
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
|
||||
if will_delete_shape: deleted_shape += 1
|
||||
|
||||
@@ -9,8 +9,7 @@ from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
|
||||
from tinygrad.codegen.optimizer import OptimizedKernel
|
||||
from tinygrad.codegen.kernel import LocalBuffer
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.features.image import to_image_idx
|
||||
|
||||
@@ -45,7 +44,7 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
|
||||
local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
|
||||
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
|
||||
|
||||
class Linearizer(OptimizedKernel):
|
||||
class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
||||
@@ -62,7 +61,8 @@ class Linearizer(OptimizedKernel):
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
|
||||
const = self.bufs[i].val if isinstance(self.bufs[i], ConstBuffer) else acc
|
||||
buf = self.bufs[i]
|
||||
const = buf.val if isinstance(buf, ConstBuffer) else acc
|
||||
|
||||
def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
|
||||
|
||||
@@ -84,10 +84,10 @@ class Linearizer(OptimizedKernel):
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
ret = []
|
||||
invalid_value = 0 if dtypes.is_int(self.bufs[i].dtype) else 0.0
|
||||
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (self.bufs[i].idx if isinstance(self.bufs[i], MemBuffer) else self.bufs[i].name)}{idx.render()}{valid.render()}"
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
assert valid.min == 1
|
||||
@@ -100,8 +100,8 @@ class Linearizer(OptimizedKernel):
|
||||
else:
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
if isinstance(self.bufs[i].dtype, ImageDType):
|
||||
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
@@ -115,6 +115,7 @@ class Linearizer(OptimizedKernel):
|
||||
return ret
|
||||
|
||||
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
|
||||
buf = self.bufs[i]
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
|
||||
@@ -140,8 +141,8 @@ class Linearizer(OptimizedKernel):
|
||||
|
||||
for idx, var in store_offset.items():
|
||||
idx, valid = self.sts[i].expr_idxs(idx)
|
||||
if isinstance(self.bufs[i].dtype, ImageDType):
|
||||
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
@@ -149,6 +150,12 @@ class Linearizer(OptimizedKernel):
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
def linearize(self):
|
||||
# no new opts and we already ran? skip relinearizing
|
||||
if self.applied_opts == self.applied_opts_cache: return self
|
||||
|
||||
# save backups
|
||||
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
|
||||
|
||||
# global uop cache
|
||||
self.saved_exprs: Dict[Tuple, UOp] = dict()
|
||||
|
||||
@@ -207,6 +214,9 @@ class Linearizer(OptimizedKernel):
|
||||
loop_uop = self.loop_uops[x.expr]
|
||||
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
|
||||
|
||||
# set global/local size
|
||||
self.global_size: Optional[List[int]] = None
|
||||
self.local_size: Optional[List[int]] = None
|
||||
if self.dont_use_locals:
|
||||
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
@@ -338,10 +348,10 @@ class Linearizer(OptimizedKernel):
|
||||
render_loop(end_local_idxs)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
|
||||
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
|
||||
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
|
||||
|
||||
# end the late reduce loop
|
||||
end_loop(end_local_idxs)
|
||||
@@ -372,6 +382,11 @@ class Linearizer(OptimizedKernel):
|
||||
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
||||
self.uops = nu
|
||||
|
||||
# restore backups
|
||||
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
|
||||
|
||||
# set cache and return
|
||||
self.applied_opts_cache = self.applied_opts[:]
|
||||
return self
|
||||
|
||||
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
|
||||
|
||||
@@ -1,417 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, List, cast, Optional
|
||||
from dataclasses import dataclass
|
||||
import itertools, math, os
|
||||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.codegen.kernel import Kernel, LocalBuffer, LinearizerOptions, tensor_cores
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: Optional[int] = None
|
||||
amt: Optional[int] = None
|
||||
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
|
||||
|
||||
class OptimizedKernel(Kernel):
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
|
||||
super().__init__(ast, opts)
|
||||
|
||||
# move all reduce axes to the end
|
||||
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
|
||||
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
|
||||
self.reshape_and_permute(None, permute)
|
||||
|
||||
# group simplifies
|
||||
self.simplify_ones()
|
||||
self.simplify_merge_adjacent()
|
||||
|
||||
self.applied_opts: List[Opt] = []
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
new_sts = []
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st = st.permute(tuple(axis))
|
||||
new_sts.append(st)
|
||||
self.sts = new_sts
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self):
|
||||
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
|
||||
self.upcasted += 1
|
||||
|
||||
# axis : the axis to pull from
|
||||
# amount : the amount to take
|
||||
# top : if you want to pull that amount from the top
|
||||
# insert_before : place to insert the new stuff
|
||||
def shift_to(self, axis, amount, top=False, insert_before=None):
|
||||
if insert_before is None: insert_before = self.shape_len
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_before: insert_before += 1
|
||||
self.reshape_and_permute(
|
||||
lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
|
||||
[i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
|
||||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
# TODO: this should be factored in to multi shape stride
|
||||
if self.shape_len == 0: return False
|
||||
all_ones = [s==1 for s in self.full_shape]
|
||||
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
|
||||
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
|
||||
return any(all_ones)
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if self.bufs[0].dtype.name.startswith('image'):
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[int, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# TODO: does this always preserve the reduce dimension, NO
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for j in range(len(shapes)):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != self.first_reduce
|
||||
for j in range(len(shapes)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
|
||||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** GPU simplifiers ********************
|
||||
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
|
||||
new_shape,dims = list(x), len(x)
|
||||
for i in range(dims):
|
||||
next_idx = (i + 1) % dims
|
||||
while new_shape[i] > max_size[i]:
|
||||
new_shape[i] = new_shape[i] // 2
|
||||
if (new_shape[next_idx] <= max_size[next_idx]):
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
else:
|
||||
next_idx = (next_idx + 1) % dims
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
return tuple(new_shape)
|
||||
|
||||
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
|
||||
# Check the global allocation limit, current the global_size will be flipped during codegen
|
||||
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
|
||||
global_dims = self.first_reduce-self.local_dims
|
||||
if global_dims > 0:
|
||||
if global_max:
|
||||
tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
|
||||
if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
|
||||
assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
|
||||
for i in range(global_dims-1):
|
||||
if self.full_shape[i] > global_max[i]:
|
||||
order = list(range(len(self.full_shape)))
|
||||
order[i], order[global_dims-1] = order[global_dims-1], order[i]
|
||||
self.reshape_and_permute(None, order)
|
||||
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
|
||||
|
||||
def alias_buffer(self, i, pattern):
|
||||
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
|
||||
|
||||
bst = 1
|
||||
real_strides = self.sts[i].real_strides()
|
||||
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
|
||||
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
|
||||
for j,p in enumerate(pattern):
|
||||
if priority == p and real_strides[j] != 0:
|
||||
stride[j] = bst
|
||||
bst *= shp[j]
|
||||
|
||||
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
|
||||
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
|
||||
self.local_alias[i] = self.bufs[-1]
|
||||
|
||||
# ******************** high level optimizers ********************
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None):
|
||||
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
|
||||
for tc in tensor_cores[self.opts.device]:
|
||||
if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
|
||||
if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
|
||||
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
|
||||
|
||||
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
|
||||
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
|
||||
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
|
||||
buf0, buf1 = self.bufs.index(cast(LazyOp, mul_op.src[0].arg)), self.bufs.index(cast(LazyOp, mul_op.src[1].arg))
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
|
||||
|
||||
if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
|
||||
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
|
||||
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
|
||||
s0_exists, s1_exists = True, True
|
||||
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
|
||||
def fix(needed, ax):
|
||||
nonlocal s0, s1, s0_exists, s1_exists
|
||||
if not needed: return
|
||||
if s0_exists and ax == s0:
|
||||
if s1_exists and s0 < s1: s1 -= 1
|
||||
s0_exists = False
|
||||
elif s1_exists and ax == s1:
|
||||
if s0_exists and s1 < s0: s0 -= 1
|
||||
s1_exists = False
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
|
||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
|
||||
if extra_opts is not None:
|
||||
for opt in extra_opts:
|
||||
self.apply_opt(opt)
|
||||
else:
|
||||
# hand-coded TC opts
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
if self.tensor_core and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
self.alias_buffer(buf0, alias_pattern)
|
||||
self.alias_buffer(buf1, alias_pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
  def apply_opt(self, opt:Opt):
    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
    self.applied_opts.append(opt)
    if opt.axis is not None:
      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
    else:
      axis = -1
    if opt.amt is not None:
      amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
      assert self.full_shape[axis] % amt == 0, "no longer valid shift"
      assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
    else:
      amt = -1
    if opt.op == OptOps.LOCAL: # cyan
      assert axis < self.first_reduce, "can't local a reduce"
      assert not(self.tensor_core), "can't local with tensor cores"
      self.shift_to(axis, amt, insert_before=self.first_reduce)
      self.local_dims += 1
    elif opt.op == OptOps.LASTLOCAL: # cyan
      assert axis < self.first_reduce, "can't local a reduce"
      self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
      self.local_dims += 1
    elif opt.op == OptOps.GROUP: # green
      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
      assert not(self.tensor_core), "can't group with tensor cores"
      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
      self.group_for_reduce.append(amt)
    elif opt.op == OptOps.GROUPTOP: # green
      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
      assert not(self.tensor_core), "can't group with tensor cores"
      self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
      self.group_for_reduce.append(amt)
    elif opt.op == OptOps.UNROLL: # purple
      assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
      assert amt <= 32, "don't unroll more than 32"
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op == OptOps.UPCAST: # yellow
      assert axis < self.first_reduce, "upcast is for non-reduce"
      assert amt <= 8, "don't upcast more than 8"
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op == OptOps.UPCASTMID: # white
      assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
      axes = self.sts[0].unit_stride_axes()
      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
      assert axes[0] == axis, "wrong axis"
      assert amt == 4, "don't upcast mid anything but 4"
      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
      self.group_for_reduce.append(amt)
    elif opt.op == OptOps.NOLOCALS:
      assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
      assert not self.dont_use_locals, "already not using locals"
      self.dont_use_locals = True
    return self.simplify_ones()
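
  # Illustrative sketch (not part of this patch): an Opt bundles the op, the axis it acts
  # on, and an amount; apply_opt records it in applied_opts, reshapes the kernel via
  # shift_to, and the asserts above reject invalid combinations. With `k` standing for any
  # Kernel/Linearizer instance (hypothetical name):
  #   k.apply_opt(Opt(OptOps.UPCAST, 0, 4))  # upcast 4 elements of non-reduce axis 0
  #   k.apply_opt(Opt(OptOps.LOCAL, 1, 16))  # turn 16 of axis 1 into a local dimension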

  def required_optimizations(self, early_only=False):
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
      if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

  def hand_coded_optimizations(self):
    # if there's images in the earlybufs, we have to make an axis the 4 loading one
    self.required_optimizations(early_only=True)

    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
    MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
    if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
        self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
        self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
      buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
      buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)
      buf0_strides = self.sts[buf0].real_strides()
      buf1_strides = self.sts[buf1].real_strides()
      def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
      if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
        for global_idx in range(self.global_dims):
          if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
            if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
            if MV_THREADS_PER_ROW > 1:
              self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
            if MV_BLOCKSIZE > 1:
              self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
            if MV_ROWS_PER_THREAD > 1:
              self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
            return

    if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
      # are we grouping? (requires local shape support)
      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
        # TODO: use 1024 if it's allowed in a smarter way
        for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
            self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
            break

      # are we upcasting in mid reduce? (only for images)
      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
        axes = self.sts[0].unit_stride_axes()
        assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
        if self.sts[0].shape[axes[0]]%4 == 0:
          self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))

    # now do everything required
    self.required_optimizations()

    # no more opt if we are grouping
    if self.group_for_reduce: return

    # **** below this line need to be optional and benchmarked ****

    # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
    # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
    # expression and run test/test_ops.py with IMAGE=2
    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
    # this can be made much smarter
    to_upcast: List[int] = []
    # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
    for axis in range(self.first_reduce):
      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
      # for now skip upcasting here if there is a symbolic axis
      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
          prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
        to_upcast.append(axis)
    for axis in to_upcast[::-1]:
      self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

    # potentially do more upcasts of non reduce axes based on a heuristic
    upcasted_axis = set()
    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
      xb_choices = []
      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
        # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
      if xb_choices:
        xb_choices = sorted(xb_choices)
        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
        self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
        upcasted_axis.add(xb_choices[0][2])
      else:
        break

    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
      if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
        # if it's small, upcast a second reduce dimension too
        if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      else:
        for splits in [4]:
          if self.full_unupcasted_shape[-1]%splits == 0:
            self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
            break

    # if nothing at all is upcasted and it's easy to, do an upcast
    # TODO: this is breaking the tests
    for splits in [4]:
      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
        self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))

    # **** local groups ****

    if self.opts.has_local:
      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
        self.apply_opt(Opt(OptOps.NOLOCALS))
      else:
        # prioritize making expand axes local
        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
        to_local: List[Tuple[int, int]] = []
        for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
          local_size = prod(sz for _, sz in to_local)
          local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
          if local_sz is not None: to_local.append((axis, local_sz))
        deleted_shape = 0
        for axis, local_sz in sorted(to_local[:3]):
          axis = axis - deleted_shape
          will_delete_shape = local_sz == self.full_shape[axis]
          self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
          if will_delete_shape: deleted_shape += 1
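
# Sketch (illustrative, not part of this patch) of how these heuristics are typically
# driven. It assumes `lin` is an existing Linearizer built for some AST, and that the
# helpers live in tinygrad.features.search as the hunks below suggest.
from tinygrad.features.search import bufs_from_lin, time_linearizer

rawbufs = bufs_from_lin(lin)          # scratch buffers matching the kernel's buffers
lin.hand_coded_optimizations()        # apply the heuristics above (UPCAST/LOCAL/GROUPTOP/...)
tm = time_linearizer(lin, rawbufs)    # seconds per run; note there is no should_copy argument anymore
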
@@ -7,7 +7,7 @@ from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
from tinygrad.tensor import Tensor

from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.codegen.kernel import Opt, OptOps
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
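# Illustrative note (not part of this patch): flatten collapses these nested lists into
# one flat candidate list, so e.g. actions[0] == Opt(op=OptOps.UPCAST, axis=0, amt=0);
# amt=0 means "use the whole axis" (see apply_opt above).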

@@ -20,10 +20,9 @@ actions += [
]

# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False, clear_l2=False) -> float:
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
  if should_copy and not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
  if should_copy: lin = lin.copy() # TODO: remove the need for this
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
  var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
  try:
    lin.linearize()
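# Illustrative note (not part of this patch): with CACHELEVEL >= 2 the timing is memoized
# on disk, keyed only by the AST, the applied opts, allow_test_size and max_global_size,
# so re-timing the same schedule is roughly a dictionary lookup:
#   key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": True, "max_global_size": 65536}
#   cached = diskcache_get("time_linearizer", key)   # a list of recorded times, or None on a miss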

@@ -75,7 +74,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:

# get dictionary of all possible actions
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
  acted_lins = {0:lin.copy()} if include_0 else {}
  acted_lins = {0:lin} if include_0 else {}
  for i,a in enumerate(actions):
    if a.axis is not None and a.axis >= lin.shape_len: continue
    if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
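# Illustrative sketch (not part of this patch): enumerating the one-step neighbors of a
# schedule; key 0 is now the input linearizer itself rather than a copy (`lin` assumed to
# exist as above).
#   for i, neighbor in get_linearizer_actions(lin).items():
#     print(i, neighbor.applied_opts)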

@@ -104,7 +103,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea

  # NOTE: real uops use a weird compare method that's only valid inside a linearizer
  def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
  seen_uops = {tuplize_uops(lin.copy().linearize().uops): tuple(lin.applied_opts)}
  seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}

  while 1:
    acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])

@@ -112,7 +111,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
    # dedup with uops (TODO: double linearize not needed)
    acted_lins_dedup = []
    for lin in acted_lins:
      tuops = tuplize_uops(lin.copy().linearize().uops)
      tuops = tuplize_uops(lin.linearize().uops)
      if tuops in seen_uops:
        #print(seen_uops[tuops], lin.applied_opts)
        continue
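
# Sketch (illustrative, not part of this patch) of the uops-level dedup idea used above:
# two candidates with different applied opts can still lower to identical uops, so the
# linearized form is what gets hashed. `candidates` is a hypothetical list of Linearizers.
seen, unique = set(), []
for cand in candidates:
  t = tuplize_uops(cand.linearize().uops)
  if t in seen: continue
  seen.add(t)
  unique.append(cand)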

@@ -169,7 +169,7 @@ def cache_compiled(func):
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)

VERSION = 4
VERSION = 5
_db_connection = None
def db_connection():
  global _db_connection