From 7103b716c4d3099120c02f11870eccdff9737f2e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 1 Nov 2023 15:20:01 -0700 Subject: [PATCH] merge kernel and optimizer (#2200) * merge kernel and optimizer * linearize is reentrant * move global/local size * clean up linearizer copy * remove unneeded lin copies * stop linearizing twice * oops, that should be None --- examples/handcode_resnet50_opt.py | 4 +- extra/optimization/extract_policynet.py | 2 +- extra/optimization/extract_sa_pairs.py | 2 +- extra/optimization/get_action_space.py | 7 + extra/optimization/pretrain_valuenet.py | 2 +- extra/optimization/test_beam_search.py | 12 +- test/test_linearizer.py | 2 +- test/test_search.py | 2 +- tinygrad/codegen/kernel.py | 452 +++++++++++++++++++++++- tinygrad/codegen/linearizer.py | 39 +- tinygrad/codegen/optimizer.py | 417 ---------------------- tinygrad/features/search.py | 13 +- tinygrad/helpers.py | 2 +- 13 files changed, 495 insertions(+), 461 deletions(-) delete mode 100644 tinygrad/codegen/optimizer.py diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py index 73eef224d1..b2123517a8 100644 --- a/examples/handcode_resnet50_opt.py +++ b/examples/handcode_resnet50_opt.py @@ -59,9 +59,9 @@ if __name__ == "__main__": # benchmark the programs choices = [] for lin in lins: - tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False) + tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm - choices.append((tm, gflops, lin)) + choices.append((tm, gflops, lin.linearize())) # print all kernels if DEBUG >= 1: print(f" kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS") diff --git a/extra/optimization/extract_policynet.py b/extra/optimization/extract_policynet.py index 774137d08f..149aacf2a1 100644 --- a/extra/optimization/extract_policynet.py +++ b/extra/optimization/extract_policynet.py @@ -17,7 +17,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable inf, nan = float('inf'), float('nan') -from tinygrad.codegen.optimizer import Opt, OptOps +from tinygrad.codegen.kernel import Opt, OptOps INNER = 256 class PolicyNet: diff --git a/extra/optimization/extract_sa_pairs.py b/extra/optimization/extract_sa_pairs.py index 08d8846cf1..5a806e0a0e 100644 --- a/extra/optimization/extract_sa_pairs.py +++ b/extra/optimization/extract_sa_pairs.py @@ -10,7 +10,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable inf, nan = float('inf'), float('nan') -from tinygrad.codegen.optimizer import Opt, OptOps +from tinygrad.codegen.kernel import Opt, OptOps # more stuff from tinygrad.codegen.linearizer import Linearizer diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py index fb6f50c358..4939594545 100644 --- a/extra/optimization/get_action_space.py +++ b/extra/optimization/get_action_space.py @@ -1,3 +1,4 @@ +import random from tqdm import tqdm from extra.optimization.helpers import load_worlds, ast_str_to_lin from tinygrad.features.search import actions @@ -17,11 +18,17 @@ def test_rebuild(lin): if __name__ == "__main__": ast_strs = load_worlds(False, False, False) + random.shuffle(ast_strs) + ast_strs = ast_strs[:2000] for ast_str in tqdm(ast_strs): lin = ast_str_to_lin(ast_str) #if not lin.apply_tensor_cores(): lin.hand_coded_optimizations() test_rebuild(lin) + # confirm linearize can be called twice + uops1 = lin.linearize().uops + uops2 = lin.linearize().uops + assert tuple(uops1) == tuple(uops2), f"uops mismatch {lin.colored_shape()}" print(len(tactions), len(actions)) print(sorted(list(tactions))) diff --git a/extra/optimization/pretrain_valuenet.py b/extra/optimization/pretrain_valuenet.py index 102e22b0cd..dd850def98 100644 --- a/extra/optimization/pretrain_valuenet.py +++ b/extra/optimization/pretrain_valuenet.py @@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable inf, nan = float('inf'), float('nan') -from tinygrad.codegen.optimizer import Opt, OptOps +from tinygrad.codegen.kernel import Opt, OptOps from extra.optimization.helpers import lin_to_feats, MAX_DIMS diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index b8d66e51c8..298add9a95 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -1,9 +1,10 @@ import unittest import numpy as np -from tinygrad.helpers import BEAM +from tinygrad.helpers import BEAM, Timing from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor +from tinygrad.nn import Conv2d class TestBeamSearch(unittest.TestCase): def setUp(self): @@ -22,3 +23,12 @@ class TestBeamSearch(unittest.TestCase): a.assign(a+1) actual = a.numpy() np.testing.assert_allclose(actual, desired) + + def test_conv_beam(self): + c = Conv2d(3, 16, (3,3)) + x = Tensor.rand(1,3,32,32) + with Timing(): + c(x).realize() + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index cb7bd5c336..bd37e2af75 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -2,7 +2,7 @@ import numpy as np import unittest, os from tinygrad.codegen.kernel import tensor_cores -from tinygrad.codegen.optimizer import Opt, OptOps +from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.linearizer import Linearizer, UOps from tinygrad.ops import Compiled, Device from tinygrad.tensor import Tensor diff --git a/test/test_search.py b/test/test_search.py index 2e1ce67bf7..10b94b9551 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -12,5 +12,5 @@ class TestTimeLinearizer(unittest.TestCase): def test_reasonable_time(self): si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0] rawbufs = [Device[Device.DEFAULT].buffer(si.out.st.size(), si.out.dtype)] + [Device[Device.DEFAULT].buffer(x.st.size(), x.dtype) for x in si.inputs] - tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10, should_copy=False) + tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10) assert tm > 0 and tm != float('inf') diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 281a32aa5c..dcfb79850f 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,13 +1,24 @@ from __future__ import annotations -from typing import NamedTuple, Optional, List, Tuple, cast, Dict -from copy import deepcopy -import itertools -from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps, Device, Compiled -from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen -from tinygrad.shape.shapetracker import ShapeTracker +import os, math, itertools +from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union +from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled +from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG +from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import sint -from tinygrad.shape.view import strides_for_shape +from tinygrad.shape.view import View, strides_for_shape from dataclasses import dataclass +from enum import Enum, auto + +class OptOps(Enum): + UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702 + def __lt__(self, x:OptOps): return self.value < x.value + +@dataclass(frozen=True, order=True) +class Opt: + op: OptOps + axis: Optional[int] = None + amt: Optional[int] = None + def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" @dataclass(frozen=True) class TensorCore: @@ -22,7 +33,6 @@ class TensorCore: arch: Optional[str] = None def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>" -# TODO(TC): doesn't belong here!!! tensor_cores: Dict[str, List[TensorCore]] = { "METAL": [ TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), @@ -66,16 +76,25 @@ class Kernel: self.reduceop = reduceops[0] if reduceops else None # create new shapetrackers inside this kernel, we will permute them - self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps]) - self.sts: List[ShapeTracker] = [x.st for x in self.bufs] + self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps]) - self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs) + # extract things from the buffers + self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)) # get earlybufs, before the one reduce op self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else [] self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 - # parameters + # create the (permuted) shapetrackers + self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)] + + # move all reduce axes to the end + reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) + permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) + self.reshape_and_permute(None, permute) + + # parameters for optimization + self.applied_opts: List[Opt] = [] self.group_for_reduce: List[int] = [] self.upcasted: int = 0 self.local_dims: int = 0 @@ -83,18 +102,39 @@ class Kernel: self.tensor_core: Optional[TensorCore] = None self.dont_use_locals: bool = False - self.global_size: Optional[List[int]] = None - self.local_size: Optional[List[int]] = None + # group simplifies + self.simplify_ones() + self.simplify_merge_adjacent() + + # cache + self.applied_opts_cache: Optional[List[Opt]] = None def copy(self): - return deepcopy(self) + ret = type(self).__new__(type(self)) + + # base linearizer params + ret.opts, ret.ast = self.opts, self.ast + + # things downstream of the AST + # NOTE: we copy bufs for local buffers and sts for optimizations + ret.info, ret.reduceop, ret.bufs, ret.mem_estimate, ret.earlybufs, ret.full_buf_index, ret.sts = \ + self.info, self.reduceop, self.bufs[:], self.mem_estimate, self.earlybufs, self.full_buf_index, self.sts[:] + + # parameters for optimizations + ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \ + self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals + + # uncached since linearize didn't run + ret.applied_opts_cache = None + + return ret @property def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)] def has_variable_shape(self) -> bool: for b in self.bufs: - if not all_int(b.st.views[-1].shape): return True + if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True return False def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] @@ -170,3 +210,383 @@ class Kernel: print(prefix, f"{i:3d} {str(self.bufs[i]):47s}", st.views) print(self.colored_shape()) + # ******************** base simplifiers ******************** + + # apply reshape and permute to all shapetrackers + def reshape_and_permute(self, new_shape_fxn, axis): + new_sts = [] + for st in self.sts: + if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape))) + if axis is not None: st = st.permute(tuple(axis)) + new_sts.append(st) + self.sts = new_sts + + # drops the final dimension + def upcast(self): + assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1" + self.upcasted += 1 + + # axis : the axis to pull from + # amount : the amount to take + # top : if you want to pull that amount from the top + # insert_before : place to insert the new stuff + def shift_to(self, axis, amount, top=False, insert_before=None): + if insert_before is None: insert_before = self.shape_len + move_axis = axis if top else axis+1 + if move_axis < insert_before: insert_before += 1 + self.reshape_and_permute( + lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), + [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) + + # ******************** complex simplifiers ******************** + + def simplify_ones(self) -> bool: + # remove places where the shape is all ones + # TODO: this should be factored in to multi shape stride + if self.shape_len == 0: return False + all_ones = [s==1 for s in self.full_shape] + self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce]) + self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) + self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) + return any(all_ones) + + def simplify_merge_adjacent(self): + if self.shape_len == 0: return + shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] + + # if it's an image, insert fake strides such that this fusion doesn't happen across image axes + if isinstance(self.bufs[0].dtype, ImageDType): + base_shape = self.bufs[0].dtype.shape + if shape_idx_groups := get_contraction(self.output_shape, base_shape): + special_strides: Tuple[int, ...] = tuple() + for i,g in enumerate(shape_idx_groups): + shape_piece = tuple(self.output_shape[x] for x in g) + assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" + special_strides += strides_for_shape(shape_piece) + # adding the fake image shape + shapes.append(self.output_shape) + strides.append(special_strides) + + # merge dimensions if we can, multi get_shape_strides + # TODO: does this always preserve the reduce dimension, NO + # TODO: move this into shapetracker, with tests! + rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] + for i in range(1, len(shapes[0])): + can_merge = [] + for j in range(len(shapes)): + # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case + can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) + # more can merge than this + mergeable = all(can_merge) and i != self.first_reduce + for j in range(len(shapes)): + if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) + else: rets[j].append((shapes[j][i], strides[j][i])) + + # do the reshapes + for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) + + # ******************** GPU simplifiers ******************** + def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]: + new_shape,dims = list(x), len(x) + for i in range(dims): + next_idx = (i + 1) % dims + while new_shape[i] > max_size[i]: + new_shape[i] = new_shape[i] // 2 + if (new_shape[next_idx] <= max_size[next_idx]): + new_shape[next_idx] = new_shape[next_idx] * 2 + else: + next_idx = (next_idx + 1) % dims + new_shape[next_idx] = new_shape[next_idx] * 2 + return tuple(new_shape) + + def limit_dims_to_max(self, global_max: List[int], local_max: List[int]): + # Check the global allocation limit, current the global_size will be flipped during codegen + # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write + global_dims = self.first_reduce-self.local_dims + if global_dims > 0: + if global_max: + tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else []) + if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None) + assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" + for i in range(global_dims-1): + if self.full_shape[i] > global_max[i]: + order = list(range(len(self.full_shape))) + order[i], order[global_dims-1] = order[global_dims-1], order[i] + self.reshape_and_permute(None, order) + if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit") + + def alias_buffer(self, i, pattern): + assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}" + + bst = 1 + real_strides = self.sts[i].real_strides() + shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern) + for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored + for j,p in enumerate(pattern): + if priority == p and real_strides[j] != 0: + stride[j] = bst + bst *= shp[j] + + self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),))) + self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size())) + if DEBUG >= 4: print("aliasing buffer", self.sts[i]) + self.local_alias[i] = cast(LocalBuffer, self.bufs[-1]) + + # ******************** high level optimizers ******************** + + def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None): + if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores: + for tc in tensor_cores[self.opts.device]: + if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue + has_cast = tc.dtype_in != tc.dtype_out + + if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue + mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] + + if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue + if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue + if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue + buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg)) + buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() + axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] + axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] + + if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue + + if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) + + s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way + s0_exists, s1_exists = True, True + assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0 + def fix(needed, ax): + nonlocal s0, s1, s0_exists, s1_exists + if not needed: return + if s0_exists and ax == s0: + if s1_exists and s0 < s1: s1 -= 1 + s0_exists = False + elif s1_exists and ax == s1: + if s0_exists and s1 < s0: s0 -= 1 + s1_exists = False + + # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern + self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2])) + self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads]))) + for (tc_dim, tc_amt) in tc.threads: + fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1) + + # assert tensor core and prevent extra_opts from altering the key shape structure + if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA + + if extra_opts is not None: + for opt in extra_opts: + self.apply_opt(opt) + else: + # hand-coded TC opts + if s1_exists: + s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0] + if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1) + if s0_exists: + s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0] + if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0) + if self.tensor_core and s0_exists: + for upc in [4,2]: + if self.full_shape[s0] % upc == 0: + self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc)) + break + + # alias buffer + alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) + self.alias_buffer(buf0, alias_pattern) + self.alias_buffer(buf1, alias_pattern) + return True + return False + + def apply_opt(self, opt:Opt): + assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" + self.applied_opts.append(opt) + if opt.axis is not None: + axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) + else: + axis = -1 + if opt.amt is not None: + amt = opt.amt if opt.amt != 0 else self.full_shape[axis] + assert self.full_shape[axis] % amt == 0, "no longer valid shift" + assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless" + else: + amt = -1 + if opt.op == OptOps.LOCAL: # cyan + assert axis < self.first_reduce, "can't local a reduce" + assert not(self.tensor_core), "can't local with tensor cores" + self.shift_to(axis, amt, insert_before=self.first_reduce) + self.local_dims += 1 + elif opt.op == OptOps.LASTLOCAL: # cyan + assert axis < self.first_reduce, "can't local a reduce" + self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims) + self.local_dims += 1 + elif opt.op == OptOps.GROUP: # green + assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" + assert not(self.tensor_core), "can't group with tensor cores" + self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.GROUPTOP: # green + assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" + assert not(self.tensor_core), "can't group with tensor cores" + self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.UNROLL: # purple + assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted" + assert amt <= 32, "don't unroll more than 32" + self.shift_to(axis, amt, insert_before=None) + self.upcast() + elif opt.op == OptOps.UPCAST: # yellow + assert axis < self.first_reduce, "upcast is for non-reduce" + assert amt <= 8, "don't upcast more than 8" + self.shift_to(axis, amt, insert_before=None) + self.upcast() + elif opt.op == OptOps.UPCASTMID: # white + assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" + axes = self.sts[0].unit_stride_axes() + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + assert axes[0] == axis, "wrong axis" + assert amt == 4, "don't upcast mid anything but 4" + self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.NOLOCALS: + assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals" + assert not self.dont_use_locals, "already not using locals" + self.dont_use_locals = True + return self.simplify_ones() + + def required_optimizations(self, early_only=False): + for buf_index,buf in enumerate(self.bufs): + unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] + if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType: + assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" + if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: + if unit_stride_axes_mul_4[0] < self.first_reduce: + self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) + else: + self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4)) + + def hand_coded_optimizations(self): + # if there's images in the earlybufs, we have to make an axis the 4 loading one + self.required_optimizations(early_only=True) + + # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat + MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) + if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ + self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \ + isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \ + self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM: + buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg) + buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg) + buf0_strides = self.sts[buf0].real_strides() + buf1_strides = self.sts[buf1].real_strides() + def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st)) + if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): + for global_idx in range(self.global_dims): + if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: + if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") + if MV_THREADS_PER_ROW > 1: + self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) + if MV_BLOCKSIZE > 1: + self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) + if MV_ROWS_PER_THREAD > 1: + self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) + return + + if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]): + # are we grouping? (requires local shape support) + if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: + # TODO: use 1024 if it's allowed in a smarter way + for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): + if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts): + self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) + break + + # are we upcasting in mid reduce? (only for images) + if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: + axes = self.sts[0].unit_stride_axes() + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + if self.sts[0].shape[axes[0]]%4 == 0: + self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4)) + + # now do everything required + self.required_optimizations() + + # no more opt if we are grouping + if self.group_for_reduce: return + + # **** below this line need to be optional and benchmarked **** + + # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx) + # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below + # expression and run test/test_ops.py with IMAGE=2 + # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack) + # this can be made much smarter + to_upcast: List[int] = [] + # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) + for axis in range(self.first_reduce): + # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent + # for now skip upcasting here if there is a symbolic axis + if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \ + prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7: + if DEBUG >= 4: print(f"upcasting masked axis : {axis}") + to_upcast.append(axis) + for axis in to_upcast[::-1]: + self.apply_opt(Opt(OptOps.UPCAST, axis, 0)) + + # potentially do more upcasts of non reduce axes based on a heuristic + upcasted_axis = set() + while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: + xb_choices = [] + for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce + # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already + if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): + xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) + if xb_choices: + xb_choices = sorted(xb_choices) + if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") + self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3])) + upcasted_axis.add(xb_choices[0][2]) + else: + break + + # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS + if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): + if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) + # if it's small, upcast a second reduce dimension too + if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int): + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) + else: + for splits in [4]: + if self.full_unupcasted_shape[-1]%splits == 0: + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits)) + break + + # if nothing at all is upcasted and it's easy to, do an upcast + # TODO: this is breaking the tests + for splits in [4]: + if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0: + self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits)) + + # **** local groups **** + + if self.opts.has_local: + if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce: + self.apply_opt(Opt(OptOps.NOLOCALS)) + else: + # prioritize making expand axes local + local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] + to_local: List[Tuple[int, int]] = [] + for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): + local_size = prod(sz for _, sz in to_local) + local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) + if local_sz is not None: to_local.append((axis, local_sz)) + deleted_shape = 0 + for axis, local_sz in sorted(to_local[:3]): + axis = axis - deleted_shape + will_delete_shape = local_sz == self.full_shape[axis] + self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) + if will_delete_shape: deleted_shape += 1 diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index ada5eb4176..8067bb2d2c 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -9,8 +9,7 @@ from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename -from tinygrad.codegen.optimizer import OptimizedKernel -from tinygrad.codegen.kernel import LocalBuffer +from tinygrad.codegen.kernel import LocalBuffer, Kernel from tinygrad.lazy import vars_from_ast from tinygrad.features.image import to_image_idx @@ -45,7 +44,7 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0): local_idxs = local_idxs[0:maxdim-1] + nli[::-1] return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)] -class Linearizer(OptimizedKernel): +class Linearizer(Kernel): def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32): render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx)) return self.uop(UOps.ALU, dtype, (a, render_b), op) @@ -62,7 +61,8 @@ class Linearizer(OptimizedKernel): AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]: - const = self.bufs[i].val if isinstance(self.bufs[i], ConstBuffer) else acc + buf = self.bufs[i] + const = buf.val if isinstance(buf, ConstBuffer) else acc def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max) @@ -84,10 +84,10 @@ class Linearizer(OptimizedKernel): e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars) ret = [] - invalid_value = 0 if dtypes.is_int(self.bufs[i].dtype) else 0.0 + invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0 for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)): this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid) - key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (self.bufs[i].idx if isinstance(self.bufs[i], MemBuffer) else self.bufs[i].name)}{idx.render()}{valid.render()}" + key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" if key not in self.load_cache: if acc is not None: assert valid.min == 1 @@ -100,8 +100,8 @@ class Linearizer(OptimizedKernel): else: buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" - if isinstance(self.bufs[i].dtype, ImageDType): - idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid) + if isinstance(buf.dtype, ImageDType): + idx, valid = to_image_idx(buf.dtype.shape, idx, valid) rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self))) else: rendered_idx = idx.render(self.render_ops, self) @@ -115,6 +115,7 @@ class Linearizer(OptimizedKernel): return ret def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None: + buf = self.bufs[i] buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" @@ -140,8 +141,8 @@ class Linearizer(OptimizedKernel): for idx, var in store_offset.items(): idx, valid = self.sts[i].expr_idxs(idx) - if isinstance(self.bufs[i].dtype, ImageDType): - idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid) + if isinstance(buf.dtype, ImageDType): + idx, valid = to_image_idx(buf.dtype.shape, idx, valid) rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx)) else: rendered_idx = idx.render(self.render_ops, self) @@ -149,6 +150,12 @@ class Linearizer(OptimizedKernel): kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) def linearize(self): + # no new opts and we already ran? skip relinearizing + if self.applied_opts == self.applied_opts_cache: return self + + # save backups + sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted + # global uop cache self.saved_exprs: Dict[Tuple, UOp] = dict() @@ -207,6 +214,9 @@ class Linearizer(OptimizedKernel): loop_uop = self.loop_uops[x.expr] if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,)) + # set global/local size + self.global_size: Optional[List[int]] = None + self.local_size: Optional[List[int]] = None if self.dont_use_locals: self.global_size = [x.max+1 for x in loop_global_idxs][::-1] self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) @@ -338,10 +348,10 @@ class Linearizer(OptimizedKernel): render_loop(end_local_idxs) # load localbufs - loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) + loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore + self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore # end the late reduce loop end_loop(end_local_idxs) @@ -372,6 +382,11 @@ class Linearizer(OptimizedKernel): if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}") self.uops = nu + # restore backups + self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup + + # set cache and return + self.applied_opts_cache = self.applied_opts[:] return self def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp: diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py deleted file mode 100644 index 84ce4e6f2a..0000000000 --- a/tinygrad/codegen/optimizer.py +++ /dev/null @@ -1,417 +0,0 @@ -from __future__ import annotations -from typing import Tuple, List, cast, Optional -from dataclasses import dataclass -import itertools, math, os -from tinygrad.helpers import DEBUG, prod, getenv, ImageDType -from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps -from tinygrad.codegen.kernel import Kernel, LocalBuffer, LinearizerOptions, tensor_cores -from tinygrad.shape.shapetracker import ShapeTracker, get_contraction -from tinygrad.shape.view import View, strides_for_shape -from enum import Enum, auto - -class OptOps(Enum): - UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702 - def __lt__(self, x:OptOps): return self.value < x.value - -@dataclass(frozen=True, order=True) -class Opt: - op: OptOps - axis: Optional[int] = None - amt: Optional[int] = None - def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" - -class OptimizedKernel(Kernel): - def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None): - super().__init__(ast, opts) - - # move all reduce axes to the end - reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) - permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) - self.reshape_and_permute(None, permute) - - # group simplifies - self.simplify_ones() - self.simplify_merge_adjacent() - - self.applied_opts: List[Opt] = [] - - # ******************** base simplifiers ******************** - - # apply reshape and permute to all shapetrackers - def reshape_and_permute(self, new_shape_fxn, axis): - new_sts = [] - for st in self.sts: - if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st = st.permute(tuple(axis)) - new_sts.append(st) - self.sts = new_sts - - # drops the final dimension - def upcast(self): - assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1" - self.upcasted += 1 - - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_before : place to insert the new stuff - def shift_to(self, axis, amount, top=False, insert_before=None): - if insert_before is None: insert_before = self.shape_len - move_axis = axis if top else axis+1 - if move_axis < insert_before: insert_before += 1 - self.reshape_and_permute( - lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), - [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) - - # ******************** complex simplifiers ******************** - - def simplify_ones(self) -> bool: - # remove places where the shape is all ones - # TODO: this should be factored in to multi shape stride - if self.shape_len == 0: return False - all_ones = [s==1 for s in self.full_shape] - self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce]) - self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) - self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) - return any(all_ones) - - def simplify_merge_adjacent(self): - if self.shape_len == 0: return - shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] - - # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - if self.bufs[0].dtype.name.startswith('image'): - base_shape = self.bufs[0].dtype.shape - if shape_idx_groups := get_contraction(self.output_shape, base_shape): - special_strides: Tuple[int, ...] = tuple() - for i,g in enumerate(shape_idx_groups): - shape_piece = tuple(self.output_shape[x] for x in g) - assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" - special_strides += strides_for_shape(shape_piece) - # adding the fake image shape - shapes.append(self.output_shape) - strides.append(special_strides) - - # merge dimensions if we can, multi get_shape_strides - # TODO: does this always preserve the reduce dimension, NO - # TODO: move this into shapetracker, with tests! - rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] - for i in range(1, len(shapes[0])): - can_merge = [] - for j in range(len(shapes)): - # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) - # more can merge than this - mergeable = all(can_merge) and i != self.first_reduce - for j in range(len(shapes)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) - else: rets[j].append((shapes[j][i], strides[j][i])) - - # do the reshapes - for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) - - # ******************** GPU simplifiers ******************** - def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]: - new_shape,dims = list(x), len(x) - for i in range(dims): - next_idx = (i + 1) % dims - while new_shape[i] > max_size[i]: - new_shape[i] = new_shape[i] // 2 - if (new_shape[next_idx] <= max_size[next_idx]): - new_shape[next_idx] = new_shape[next_idx] * 2 - else: - next_idx = (next_idx + 1) % dims - new_shape[next_idx] = new_shape[next_idx] * 2 - return tuple(new_shape) - - def limit_dims_to_max(self, global_max: List[int], local_max: List[int]): - # Check the global allocation limit, current the global_size will be flipped during codegen - # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write - global_dims = self.first_reduce-self.local_dims - if global_dims > 0: - if global_max: - tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else []) - if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None) - assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" - for i in range(global_dims-1): - if self.full_shape[i] > global_max[i]: - order = list(range(len(self.full_shape))) - order[i], order[global_dims-1] = order[global_dims-1], order[i] - self.reshape_and_permute(None, order) - if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit") - - def alias_buffer(self, i, pattern): - assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}" - - bst = 1 - real_strides = self.sts[i].real_strides() - shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern) - for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored - for j,p in enumerate(pattern): - if priority == p and real_strides[j] != 0: - stride[j] = bst - bst *= shp[j] - - self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),))) - self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size())) - if DEBUG >= 4: print("aliasing buffer", self.sts[i]) - self.local_alias[i] = self.bufs[-1] - - # ******************** high level optimizers ******************** - - def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None): - if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores: - for tc in tensor_cores[self.opts.device]: - if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue - has_cast = tc.dtype_in != tc.dtype_out - - if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue - mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] - - if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue - if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue - if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue - buf0, buf1 = self.bufs.index(cast(LazyOp, mul_op.src[0].arg)), self.bufs.index(cast(LazyOp, mul_op.src[1].arg)) - buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] - - if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue - - if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) - - s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way - s0_exists, s1_exists = True, True - assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0 - def fix(needed, ax): - nonlocal s0, s1, s0_exists, s1_exists - if not needed: return - if s0_exists and ax == s0: - if s1_exists and s0 < s1: s1 -= 1 - s0_exists = False - elif s1_exists and ax == s1: - if s0_exists and s1 < s0: s0 -= 1 - s1_exists = False - - # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern - self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2])) - self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads]))) - for (tc_dim, tc_amt) in tc.threads: - fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1) - - # assert tensor core and prevent extra_opts from altering the key shape structure - if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA - - if extra_opts is not None: - for opt in extra_opts: - self.apply_opt(opt) - else: - # hand-coded TC opts - if s1_exists: - s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0] - if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1) - if s0_exists: - s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0] - if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0) - if self.tensor_core and s0_exists: - for upc in [4,2]: - if self.full_shape[s0] % upc == 0: - self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc)) - break - - # alias buffer - alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) - self.alias_buffer(buf0, alias_pattern) - self.alias_buffer(buf1, alias_pattern) - return True - return False - - def apply_opt(self, opt:Opt): - assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" - self.applied_opts.append(opt) - if opt.axis is not None: - axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) - else: - axis = -1 - if opt.amt is not None: - amt = opt.amt if opt.amt != 0 else self.full_shape[axis] - assert self.full_shape[axis] % amt == 0, "no longer valid shift" - assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless" - else: - amt = -1 - if opt.op == OptOps.LOCAL: # cyan - assert axis < self.first_reduce, "can't local a reduce" - assert not(self.tensor_core), "can't local with tensor cores" - self.shift_to(axis, amt, insert_before=self.first_reduce) - self.local_dims += 1 - elif opt.op == OptOps.LASTLOCAL: # cyan - assert axis < self.first_reduce, "can't local a reduce" - self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims) - self.local_dims += 1 - elif opt.op == OptOps.GROUP: # green - assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" - assert not(self.tensor_core), "can't group with tensor cores" - self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.GROUPTOP: # green - assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" - assert not(self.tensor_core), "can't group with tensor cores" - self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.UNROLL: # purple - assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted" - assert amt <= 32, "don't unroll more than 32" - self.shift_to(axis, amt, insert_before=None) - self.upcast() - elif opt.op == OptOps.UPCAST: # yellow - assert axis < self.first_reduce, "upcast is for non-reduce" - assert amt <= 8, "don't upcast more than 8" - self.shift_to(axis, amt, insert_before=None) - self.upcast() - elif opt.op == OptOps.UPCASTMID: # white - assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" - axes = self.sts[0].unit_stride_axes() - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - assert axes[0] == axis, "wrong axis" - assert amt == 4, "don't upcast mid anything but 4" - self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.NOLOCALS: - assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals" - assert not self.dont_use_locals, "already not using locals" - self.dont_use_locals = True - return self.simplify_ones() - - def required_optimizations(self, early_only=False): - for buf_index,buf in enumerate(self.bufs): - unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] - if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType: - assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" - if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: - if unit_stride_axes_mul_4[0] < self.first_reduce: - self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) - else: - self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4)) - - def hand_coded_optimizations(self): - # if there's images in the earlybufs, we have to make an axis the 4 loading one - self.required_optimizations(early_only=True) - - # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat - MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) - if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ - self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \ - isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \ - self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM: - buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg) - buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg) - buf0_strides = self.sts[buf0].real_strides() - buf1_strides = self.sts[buf1].real_strides() - def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st)) - if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): - for global_idx in range(self.global_dims): - if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: - if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") - if MV_THREADS_PER_ROW > 1: - self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) - if MV_BLOCKSIZE > 1: - self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) - if MV_ROWS_PER_THREAD > 1: - self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) - return - - if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]): - # are we grouping? (requires local shape support) - if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: - # TODO: use 1024 if it's allowed in a smarter way - for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): - if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts): - self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) - break - - # are we upcasting in mid reduce? (only for images) - if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: - axes = self.sts[0].unit_stride_axes() - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - if self.sts[0].shape[axes[0]]%4 == 0: - self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4)) - - # now do everything required - self.required_optimizations() - - # no more opt if we are grouping - if self.group_for_reduce: return - - # **** below this line need to be optional and benchmarked **** - - # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx) - # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below - # expression and run test/test_ops.py with IMAGE=2 - # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack) - # this can be made much smarter - to_upcast: List[int] = [] - # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) - for axis in range(self.first_reduce): - # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent - # for now skip upcasting here if there is a symbolic axis - if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \ - prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7: - if DEBUG >= 4: print(f"upcasting masked axis : {axis}") - to_upcast.append(axis) - for axis in to_upcast[::-1]: - self.apply_opt(Opt(OptOps.UPCAST, axis, 0)) - - # potentially do more upcasts of non reduce axes based on a heuristic - upcasted_axis = set() - while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: - xb_choices = [] - for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce - # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already - if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): - xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) - if xb_choices: - xb_choices = sorted(xb_choices) - if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") - self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3])) - upcasted_axis.add(xb_choices[0][2]) - else: - break - - # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS - if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): - if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) - # if it's small, upcast a second reduce dimension too - if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int): - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) - else: - for splits in [4]: - if self.full_unupcasted_shape[-1]%splits == 0: - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits)) - break - - # if nothing at all is upcasted and it's easy to, do an upcast - # TODO: this is breaking the tests - for splits in [4]: - if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0: - self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits)) - - # **** local groups **** - - if self.opts.has_local: - if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce: - self.apply_opt(Opt(OptOps.NOLOCALS)) - else: - # prioritize making expand axes local - local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] - to_local: List[Tuple[int, int]] = [] - for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): - local_size = prod(sz for _, sz in to_local) - local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) - if local_sz is not None: to_local.append((axis, local_sz)) - deleted_shape = 0 - for axis, local_sz in sorted(to_local[:3]): - axis = axis - deleted_shape - will_delete_shape = local_sz == self.full_shape[axis] - self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) - if will_delete_shape: deleted_shape += 1 diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 59d9551661..72767a7fdb 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -7,7 +7,7 @@ from tinygrad.runtime.lib import RawBuffer from collections import defaultdict from tinygrad.tensor import Tensor -from tinygrad.codegen.optimizer import Opt, OptOps +from tinygrad.codegen.kernel import Opt, OptOps actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)]) actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)]) actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)]) @@ -20,10 +20,9 @@ actions += [ ] # returns time in seconds -def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False, clear_l2=False) -> float: +def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size} - if should_copy and not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) - if should_copy: lin = lin.copy() # TODO: remove the need for this + if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) var_vals = {k:k.min for k in vars_from_ast(lin.ast)} try: lin.linearize() @@ -75,7 +74,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]: # get dictionary of all possible actions def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]: - acted_lins = {0:lin.copy()} if include_0 else {} + acted_lins = {0:lin} if include_0 else {} for i,a in enumerate(actions): if a.axis is not None and a.axis >= lin.shape_len: continue if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue @@ -104,7 +103,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea # NOTE: real uops use a weird compare method that's only valid inside a linearizer def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops]) - seen_uops = {tuplize_uops(lin.copy().linearize().uops): tuple(lin.applied_opts)} + seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)} while 1: acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) @@ -112,7 +111,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea # dedup with uops (TODO: double linearize not needed) acted_lins_dedup = [] for lin in acted_lins: - tuops = tuplize_uops(lin.copy().linearize().uops) + tuops = tuplize_uops(lin.linearize().uops) if tuops in seen_uops: #print(seen_uops[tuops], lin.applied_opts) continue diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 0af36ff3e7..5938ec919a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -169,7 +169,7 @@ def cache_compiled(func): CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache") CACHELEVEL = getenv("CACHELEVEL", 2) -VERSION = 4 +VERSION = 5 _db_connection = None def db_connection(): global _db_connection