diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py
index 5836a5e327..57c97515ae 100644
--- a/test/extra/test_utils.py
+++ b/test/extra/test_utils.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 import io
 import unittest
+from tinygrad.helpers import getenv
 from extra.utils import fetch, fake_torch_load_zipped
 from PIL import Image
 
-class TestUtils(unittest.TestCase):
-  @unittest.skip("hangs sometimes")
+@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
+class TestFetch(unittest.TestCase):
   def test_fetch_bad_http(self):
     self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
     self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
@@ -19,6 +20,7 @@ class TestUtils(unittest.TestCase):
     pimg = Image.open(io.BytesIO(img))
     assert pimg.size == (705, 1024)
 
+class TestUtils(unittest.TestCase):
   def test_fake_torch_load_zipped(self):
     import torch
     import numpy as np
@@ -54,6 +56,5 @@ class TestUtils(unittest.TestCase):
     assert a.dtype == b.dtype
     assert np.array_equal(a, b)
 
-
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 2eb5b618f4..dc4766b71a 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -4,9 +4,9 @@ from tinygrad.helpers import getenv
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor, dtypes
 
-# for GPU, cl_khr_fp16 isn't supported
+# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
 # for LLVM, it segfaults because it can't link to the casting function
-@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["GPU", "LLVM"], "float16 broken in some CI backends")
+@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
 class TestDtype(unittest.TestCase):
   def test_half_to_np(self):
     a = Tensor([1,2,3,4], dtype=dtypes.float16)
diff --git a/test/test_ops.py b/test/test_ops.py
index 782681f655..a75558c25a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -128,7 +128,7 @@ class TestOps(unittest.TestCase):
   def test_sqrt(self):
     helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
   def test_relu(self):
-    helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu)
+    helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
   def test_leakyrelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
   def test_abs(self):
@@ -328,11 +328,21 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
 
+  def test_simple_conv2d_m4(self):
+    helper_test_op([(1,16,18,18), (16,16,3,3)],
+      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
+      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
+
   def test_simple_conv2d_1x1(self):
     helper_test_op([(1,4,9,9), (4,4,1,1)],
      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
 
+  def test_simple_conv2d_1x1_m4(self):
+    helper_test_op([(1,16,32,32), (16,16,1,1)],
+      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
+      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
+
   def test_nested_conv2d(self):
     helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)],
       lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(),
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index 5129fd1a99..06d4847276 100644
--- a/test/unit/test_shapetracker.py
+++
b/test/unit/test_shapetracker.py @@ -3,7 +3,7 @@ import unittest import numpy as np from tinygrad.helpers import prod, all_same from tinygrad.shape.shapetracker import ShapeTracker, View, ZeroView, merge_views, get_contraction -from tinygrad.codegen.gpu import to_image_idx +from tinygrad.codegen.cstyle import to_image_idx def shapetracker_getitem(st, val): locals = {"idx": val, "valid": 1} diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 268c44d347..32ea338b61 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -24,6 +24,19 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0") self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0") + def test_ge_divides(self): + expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512 + self.helper_test_variable(expr, 0, 1, "(((idx*4)+FLOAT4_INDEX)<512)") + self.helper_test_variable(expr//4, 0, 1, "(idx<128)") + + def test_ge_divides_and(self): + expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) + self.helper_test_variable(expr//4, 0, 1, "((idx1<128) and (idx2<128))") + expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512]) + self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))") + def test_div_becomes_num(self): assert isinstance(Variable("a", 2, 3)//2, NumNode) diff --git a/tinygrad/codegen/ast.py b/tinygrad/codegen/ast.py deleted file mode 100644 index 601798f007..0000000000 --- a/tinygrad/codegen/ast.py +++ /dev/null @@ -1,167 +0,0 @@ -import itertools -from enum import Enum, auto -from typing import List, Tuple -from tinygrad.helpers import prod, dedup, all_same, colored, DType -from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner, get_lazyop_info, FlopCounter -from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape - -def get_first_reduce(shapes): - for i in range(len(shapes[0])): - if not all_same([x[i] for x in shapes]): return i - return len(shapes[0]) # off the end - -# this will be removed soon anyway -class Types(Enum): FLOAT = auto(); FLOAT4 = auto() # noqa: E702 -class Token: - def __init__(self, tok:str, typ:Types, ptr:bool=False): - assert isinstance(tok, str) - self.tok, self.typ, self.ptr = tok, typ, ptr - self.axis: List[Tuple[int, int, bool]] = [] - def array(self, length, stride, reduce): self.axis.append((length, stride, reduce)) - def size(self): return prod([x[0] for x in self.axis]) - def offsets(self): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.axis[::-1]])] if len(self.axis) else [0] - def can_float4(self): return any(a[0:2] == (4,1) for a in self.axis) - # TODO: this is sort of a hack, it gets the accumulator indices - def acc_offsets(self): - if len(self.axis) == 0: return [0] - acc_strides = [x*(1-self.axis[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.axis[::-1])))] - return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.axis[::-1])])] - def decltype(self, dtype:DType): return (dtype.name if self.typ == Types.FLOAT else f'{dtype.name}4') + ('*' if self.ptr else str()) - def __repr__(self): return f"<{self.typ}{'*' if self.ptr else str()} 
{self.tok}{f'[{self.axis}]' if len(self.axis) else str()}>" - -# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops -class ASTKernel: - def __init__(self, ast:LazyOp, output_buffer=None): - # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf - if ast.op == MovementOps.RESHAPE: ast = ast.src[0] - - self.bufs = [output_buffer] + dedup(get_buffers(ast)) - self.ast = ast - - # key for lookup in cache (can change, str might not be right) - # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels. - # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?) - self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}" - - def process(self) -> None: - if hasattr(self, "sts"): return # already processed - - # fetch lazyop info - self.info: FlopCounter = get_lazyop_info(self.ast) - - reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps] - assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" - self.reduceop = reduceops[0] if reduceops else None - self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else [] - - self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))] - self.group_for_reduce: List[int] = [] - - # check valid AST kernel - assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape" - assert all_same([x.shape for x in self.bufs[1:] if x not in self.earlybufs]), "all latebufs must have the same shape" - assert all_same([len(x.shape) for x in self.bufs[1:]]), "all bufs must have the same shape size" - - # get full shape buf index (earlybufs if there are any, otherwise output) - self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0 - - # process - self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel - for st in self.sts: st.simplify() - - # make the output buffer shape correct in here - self.sts[0].reshape(self.info.shape) - - # move all reduce axes to the end - reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) - permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) - self.reshape_and_permute(None, permute) - - # simplify - self.simplify_ones() - self.simplify_merge_adjacent() - - def printbufs(self, prefix="", print_shapetrackers=False): - print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}") - if print_shapetrackers: - for st in self.sts: print(st) - for i in range(len(self.sts)): - print(prefix, self.bufs[i].dtype if self.bufs[i] is not None else None, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), self.bufs[i].realized if self.bufs[i] is not None else "FAKE") - - def codegen(self) -> ASTRunner: raise NotImplementedError("need a codegen") - - @property - def shape_len(self) -> int: return len(self.sts[0].shape) - - @property - def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape - - @property - def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] - - def 
colorshape(self, pad=50) -> str: - axis = [(f"{rs:4d}", (("green" if i in self.upcast_in_mid_reduce_axes else "cyan") if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.full_shape)] - axis += [(f"{s:4d}", 'magenta' if reduce else 'yellow') for s, _, reduce in self.buftokens[self.full_buf_index].axis[::-1]] - return ' '.join([colored(*x) for x in axis])+(" "*(pad-len(' '.join([x[0] for x in axis])))) - - def simplify_ones(self): - # remove places where the shape is all ones - # TODO: this should be factored in to multi shape stride - all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)] - # keep at least 1 one - if all(all_ones): all_ones[-1] = False - self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) - # find first mismatch, don't reduce this - self.first_reduce = get_first_reduce([x.shape for x in self.sts]) - - def simplify_merge_adjacent(self): - shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts] - - # merge dimensions if we can, multi get_shape_strides - # TODO: does this always preserve the reduce dimension, NO - # TODO: move this into shapetracker, with tests! - rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] - for i in range(1, len(shapes[0])): - can_merge = [] - for j in range(len(shapes)): - # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) - # more can merge than this - mergeable = all(can_merge) and i != self.first_reduce - for j in range(len(shapes)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) - else: rets[j].append((shapes[j][i], strides[j][i])) - - for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x)) - self.first_reduce = get_first_reduce([x.shape for x in self.sts]) - - # this should be aware of the three parts to the shape - # * the input/output dimensions - # * the reduce dimensions - # * the size outputted by each kernel - def reshape_and_permute(self, new_shape_fxn, axis): - for st in self.sts: - if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st.permute(tuple(axis)) - - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_before : place to insert the new stuff - def shift_to(self, axis, amount, top=False, insert_before=None): - if insert_before is None: insert_before = self.shape_len - move_axis = axis if top else axis+1 - if move_axis < insert_before: insert_before += 1 - self.reshape_and_permute( - lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), - [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) - - # drops the final dimension - def upcast(self): - upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1] - assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}" - for st,buftoken in zip(self.sts, self.buftokens): - # add last axis to the buftoken (if it's not a 1) - if st.shape[-1] == upcasted[0]: buftoken.array(st.shape[-1], st.views[-1].strides[-1], len(upcasted) != len(self.sts)) 
- # remove the last axis (unless it's the only dimension, then make it a 1) - st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py new file mode 100644 index 0000000000..4fa2d35d77 --- /dev/null +++ b/tinygrad/codegen/cstyle.py @@ -0,0 +1,210 @@ +from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set +import math, collections +from tinygrad.codegen.linearizer import Linearizer, UOps +from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps +from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes +from tinygrad.runtime.lib import RawConst +from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode + +# div is different in cl than python +render_cl = render_python.copy() +render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})" +render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" + +NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass + +class CStyleLanguage(NamedTuple): + kernel_prefix: str = "" + buffer_prefix: str = "" + buffer_suffix: str = "" + smem_prefix: str = "" + barrier: str = "" + gid: List[str] = [] + lid: List[str] = [] + extra_args: List[str] = [] + float4: Optional[str] = None + half_prekernel: Optional[str] = None + uses_vload: bool = False + +def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]: + idy = (idxy//(4*base_shape[1])) + if validhacks and valid.min == 0: + idx = (idxy//4) + (idy*-base_shape[1]) + # find the ones in idx that didn't factorize and remove them (TODO: this is not universal) + if isinstance(idx, SumNode): + unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1]) + assert len(unfactored) <= 1 + idx = Variable.sum(idx_nodes) + unfactored = (Variable.sum(unfactored) // base_shape[1]) + idy += unfactored + # ugh really...handtuned garbage + if idx.min >= (base_shape[1]*3)//4: + idx -= base_shape[1] + idy += 1 + else: + idx = (idxy//4)%base_shape[1] + if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) + return idx, idy + +class CStyleCodegen(Linearizer): + lang: ClassVar[CStyleLanguage] = CStyleLanguage() + supports_constant_folding: bool = True + supports_float4: bool = True + + # for renaming + kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int) + kernel_name_cache: Final[Dict[str, str]] = {} + + code_for_op: Final[Dict[Op, Callable]] = { + UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})", + UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})", + BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", + BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", + BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})", + BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})" + } + + def group_float4(self, grp:List[str]) -> str: + if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0] + else: return 
f"{self.lang.float4}({','.join(g for g in grp)})" + + def codegen(self): + self.process() + + # sometimes, there's more dimensions than len(self.lang.gid). + # compact all the dimensions into the first + # NOTE: this might make multiview shapetrackers + # NOTE: you ABSOLUTELY must do this before upcasting. the strides on the upcast are wrong if you don't + # TODO: this exposes bugs in the optimizers assuming the strides are on a single view + """ + if len(self.lang.gid) and self.first_reduce > len(self.lang.gid): + num_to_merge = (self.first_reduce - len(self.lang.gid))+1 + self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None) + if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions") + """ + + self.hand_coded_optimizations() + self.linearize() + + prekernel: Set[str] = set() + kernel = [] + global_size = [] + local_size = [] + pend_close = None + + depth = 0 + def kk(s): kernel.append(" "*depth+s) + + for uop,newvar,args in self.uops: + if uop == UOps.LOOP: + root = None + for i,var in enumerate(args[0]): + if isinstance(var, NumNode): + if args[1] == "global" and self.lang.gid: global_size.append(1) + if args[1] == "local" and self.lang.lid: local_size.append(1) + # one number, not an index + kk("{") + else: + if args[1] == "global" and self.lang.gid: + if len(args[0]) >= 4 and len(args[0])-i > 2: + # sometimes, there's more dimensions. compact all the dimensions into the last CL dimension + # TODO: these compactions should be searchable (they sort of are with reshapes and permutes) + if i == 0: + kk(f"{{ int {var.expr} = {self.lang.gid[-1]}; /* {var.max+1} */") + root = var.expr + global_size.append(var.max+1) + else: + kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};") + global_size[-1] *= var.max+1 + else: + kk(f"{{ int {var.expr} = {self.lang.gid[len(args[0])-1-i]}; /* {var.max+1} */") + global_size.append(var.max+1) + elif args[1] == "local" and self.lang.lid: + assert len(args[0]) <= len(self.lang.lid) + kk(f"{{ int {var.expr} = {self.lang.lid[len(args[0])-1-i]}; /* {var.max+1} */") + local_size.append(var.max+1) + else: + kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{") + depth += 1 + if uop == UOps.ENDLOOP: + if args[1] == "local" and len(self.lang.lid): + # TODO: this is a bit of a hack. the local loop isn't real on the GPU + kk(self.lang.barrier) + kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{") + pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */" + else: + if args[1] == "global" and pend_close: + depth -= 1 + kk(pend_close) + pend_close = None + depth -= 1 + kk("}"*len(args[0]) + f" /* {args[1]} */") + if uop == UOps.CONST: + if args[0] == -math.inf: + kk(f"float {newvar} = -INFINITY;") + else: + kk(f"float {newvar} = {args[0]}f;") + if uop == UOps.ALU: + if newvar is None: + kk(f"{args[2]} = {self.code_for_op[args[0]](*args[1])};") + else: + kk(f"float {newvar} = {self.code_for_op[args[0]](*args[1])};") + # TODO: refactor the next 14 lines + if uop == UOps.LOAD: + # TODO: merge with CONST? + if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].realized, RawConst): + # nan? inf? 
+ val = f"{self.bufs[args[0]].realized._buf}f" + else: + if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16: + val = f"vload_half({args[1].render(render_cl)}, {self.registers[args[0]].name})" + else: + val = f"{self.registers[args[0]].name}[{args[1].render(render_cl)}]" + # NOTE: if min and max are both 0, it should be a CONST in the Linearizer + if args[2].min == 1: kk(f"float {newvar} = {val};") + else: kk(f"float {newvar} = ({args[2].render(render_cl)}) ? ({val}) : 0.0f;") + if uop == UOps.LOAD4: + if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType): + prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") + idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2]) + val = f"read_imagef({self.registers[args[0]].name}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))" + else: + val = f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}]" + # NOTE: if min and max are both 0, it should be a CONST in the Linearizer + if args[2].min == 1: kk(f"float4 {newvar} = {val};") + else: kk(f"float4 {newvar} = ({args[2].render(render_cl)}) ? ({val}) : {self.group_float4(['0.0f']*4)};") + if uop == UOps.STORE: + assert args[2].min == 1, "store must be valid" + if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16: + kk(f"vstore_half({args[3]}, {args[1].render(render_cl)}, {self.registers[args[0]].name});") + else: + kk(f"{self.registers[args[0]].name}[{args[1].render(render_cl)}] = {args[3]};") + if uop == UOps.STORE4: + assert args[2].min == 1, "store must be valid" + if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType): + idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2]) + kk(f"write_imagef({self.registers[args[0]].name}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {self.group_float4(args[3])});") + else: + kk(f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}] = {self.group_float4(args[3])};") + if uop == UOps.DEFINE_LOCAL: + kk(self.lang.smem_prefix + f"float {args[0]}[{args[1]}];") + + buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+x.dtype.name+"*"+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)] + prg = ''.join([f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + + [', '.join([f'{"const" if i > 0 else ""} {t} data{i}' for i,t in buftypes] + self.lang.extra_args)] + + [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"]) + + # if we have local_sizes, we have to correct the global_size + for i,s in enumerate(local_size): global_size[i] *= s + + # painfully name the function something unique + function_name = self.function_name + if prg in CStyleCodegen.kernel_name_cache: function_name = CStyleCodegen.kernel_name_cache[prg] + else: + CStyleCodegen.kernel_cnt[function_name] += 1 + if CStyleCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(CStyleCodegen.kernel_cnt[function_name]-1)}" + CStyleCodegen.kernel_name_cache[prg] = function_name + + return ASTRunner(function_name, 
prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), + global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None, + op_estimate=self.info.flops, mem_estimate=self.mem_estimate) diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py deleted file mode 100644 index a2f3f5109d..0000000000 --- a/tinygrad/codegen/gpu.py +++ /dev/null @@ -1,357 +0,0 @@ -import math, itertools -from collections import defaultdict -from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple, ClassVar, DefaultDict -from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner -from tinygrad.codegen.ast import ASTKernel, Token, Types -from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, ModNode, Variable, render_python -from tinygrad.shape.shapetracker import ShapeTracker, View -from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup, dtypes -from tinygrad.runtime.lib import RawConst - -# div is different in cl than python -render_cl = render_python.copy() -render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})" -render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" - -VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this -NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass - -class GPULanguage(NamedTuple): - kernel_prefix: str = "" - buffer_prefix: str = "" - buffer_suffix: str = "" - smem_prefix: str = "" - barrier: str = "" - gid: List[str] = [] - lid: List[str] = [] - extra_args: List[str] = [] - float4: Optional[str] = None - half_prekernel: Optional[str] = None - -def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]: - idy = (idxy//(4*base_shape[1])) - if validhacks and valid.min == 0: - idx = (idxy//4) + (idy*-base_shape[1]) - # find the ones in idx that didn't factorize and remove them (TODO: this is not universal) - if isinstance(idx, SumNode): - unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1]) - assert len(unfactored) <= 1 - idx = Variable.sum(idx_nodes) - unfactored = (Variable.sum(unfactored) // base_shape[1]) - idy += unfactored - # ugh really...handtuned garbage - if idx.min >= (base_shape[1]*3)//4: - idx -= base_shape[1] - idy += 1 - else: - idx = (idxy//4)%base_shape[1] - #print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) - return idx, idy - -class GPUCodegen(ASTKernel): - lang: ClassVar[GPULanguage] = GPULanguage() - supports_constant_folding: bool = True - - # for renaming - kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) - kernel_name_cache: Final[Dict[str, str]] = {} - - code_for_op: Final[Dict[Op, str]] = { - UnaryOps.NOOP: "(A)", UnaryOps.CAST: "(A)", - UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", - UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", - BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", - BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", - BinaryOps.MAX: "max(A,B)", ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)" - } - start_for_op: Final[Dict[Op, str]] = {ReduceOps.SUM: "0.0f", ReduceOps.MAX: "-INFINITY"} - - def group_float4(self, grp:List[Token]) -> Token: - if all(g.tok.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.tok.split(".")[0] for g in grp]): return 
Token(grp[0].tok.split(".")[0], Types.FLOAT4) - else: return Token(f"{self.lang.float4}({','.join(g.tok for g in grp)})", Types.FLOAT4) - - def store(self, buf_index:int, value:List[Token]) -> None: - assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}" - assert len(self.sts[buf_index].views) == 1, "store has more than one view" - - # all stores can merge, since they have one view and are valid - should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image')) - - to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)} - did_store = set() - for o,v in to_store.items(): - if o in did_store: continue - idxy, valid = self.sts[buf_index].expr_idxs(o) - assert valid.min == 1, "store must always be valid" - if should_upcast: - for j in range(4): did_store.add(o+j) - v = self.group_float4([to_store[o+j] for j in range(4)]) - if self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image'): - assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4" - idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid) - self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index].dtype.shape} */\n") - elif v.typ == Types.FLOAT4: - self.kernel.append(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}] = {v.tok};\n") - else: - self.kernel.append(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).render(render_cl)}] = {v.tok};\n") - - def load(self, buf_index:int, idx_override:Optional[str]=None) -> List[Token]: - # constant folding - const = None - if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst): - val = self.bufs[buf_index].realized._buf - assert not math.isnan(val) - const = Token(f"({val}f)", Types.FLOAT) - - def check_no_mul(test, var): - if test == var: return True - if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay - if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay - return False - - is_image = self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image') - should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image')) - tokens = [] - test_idy = [] - for o in self.buftokens[buf_index].offsets(): - key = f"val{mnum(buf_index)}_{mnum(o)}" - if (buf_index, o) not in self.loaded_keys: - idxy, valid = self.sts[buf_index].expr_idxs(o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, o) - if should_upcast: - float4_index = Variable("FLOAT4_INDEX", 0, 3) - idxy_test, valid_test = self.sts[buf_index].expr_idxs(float4_index+o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, float4_index+o) - can_merge = check_no_mul(idxy_test, float4_index) - # NOTE: valid_test.render() can contain a FLOAT4_INDEX that can't affect the result: example <(((idx0<0,511>*4)+FLOAT4_INDEX<0,3>)<1024)> - can_merge = can_merge and 
"FLOAT4_INDEX" not in (idxy_test//4).render() and ("FLOAT4_INDEX" not in valid_test.render() or True) # float4_index must not be in after divide or in valid (TODO: don't check render) - if const is not None: - ldr = const - elif self.bufs[buf_index] is not None and is_image: - assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]} should_upcast:{should_upcast} can_merge:{can_merge}" - idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid, VALIDHACKS) - ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index].dtype.shape} */", Types.FLOAT4) - test_idy.append(idy.render(render_cl)) - elif should_upcast and can_merge: - ldr = Token(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4) - else: - ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT) - invalid = self.group_float4([Token("0.0f", Types.FLOAT)]*4) if ldr.typ == Types.FLOAT4 else Token("0.0f", Types.FLOAT) - ldr = ldr if valid.min == 1 or (VALIDHACKS and is_image) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid) - if const is not None: - self.loaded_keys[(buf_index,o)] = ldr - else: - # NOTE: we always do compute in float32 - self.kernel.append(f"{ldr.decltype(dtypes.float32)} {key} = {ldr.tok};\n") - if should_upcast and can_merge: - for j in range(4): - self.loaded_keys[(buf_index,o+j)] = Token(key+f'.{"xyzw"[j]}', Types.FLOAT) - else: - self.loaded_keys[(buf_index,o)] = Token(key, Types.FLOAT) - tokens.append(self.loaded_keys[(buf_index,o)]) - assert not VALIDHACKS or all_same(test_idy), f"idy changed! 
{test_idy}" - return tokens - - def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]: - if not isinstance(x, LazyOp): return self.load(self.bufs.index(x), "mid" if x is None else None) # hack for local - if isinstance(x.op, ReduceOps) and not do_reduce: return acc - values: List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src] - code = GPUCodegen.code_for_op[x.op] # TODO: replace this with a function - if len(values) == 2: - assert len(values[0]) == len(values[1]) and values[0][0].typ == values[1][0].typ, f"values mismatch {values}" - return [Token(code.replace("A", a.tok).replace("B", b.tok), a.typ) for a,b in zip(values[0], values[1])] - else: - return [Token(code.replace("A", a.tok), a.typ) for a in values[0]] - - def required_optimizations(self, early_only=False): - for buf_index,buf in enumerate(self.bufs): - upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes] - if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.name.startswith('image') and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))): - axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1] - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}" - assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}" - self.shift_to(axes[0], 4) - self.upcast() - assert self.buftokens[buf_index].can_float4() - - def hand_coded_optimizations(self): - # if there's images in the earlybufs, we have to make an axis the 4 loading one - self.required_optimizations(early_only=True) - - # simplify (sets first_reduce) - self.simplify_ones() - - # are we grouping? (requires local shape support) - if len(self.lang.lid) and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: - # TODO: use 1024 if it's allowed in a smarter way - for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): - if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]): - self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce) - self.group_for_reduce.append(sz) - break - - # are we upcasting in mid reduce? 
(only for images) - if self.bufs[0].dtype.name.startswith('image') and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: - axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1] - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - if self.sts[0].shape[axes[0]]%4 == 0: - self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis - self.group_for_reduce.append(4) - - # now do everything required - self.required_optimizations() - - # simplify (sets first_reduce) - self.simplify_ones() - - # use more opencl indexing if the output buffer is an image and we have room - if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3: - base_shape = self.bufs[0].dtype.shape - if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0: - if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape) - self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None) - self.simplify_ones() - - # no more opt if we are grouping - if self.group_for_reduce: return - - # **** below this line need to be optional and benchmarked **** - - # potentially do more upcasts of non reduce axes based on a heuristic - while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: - xb_choices = [] - for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce - # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken - if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.buftokens[buf_index].axis) for buf_index in range(len(self.sts))): - xb_choices.append((sum(st.strides[axis]>0 for st in self.sts), sum(st.strides[axis] for st in self.sts), axis, upcast_amount)) - if len(xb_choices): - xb_choices = sorted(xb_choices) - if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") - self.shift_to(xb_choices[0][2], amount=xb_choices[0][3]) - self.upcast() - self.simplify_ones() - else: - break - - # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS - if self.first_reduce < self.shape_len and self.full_shape[-1] <= 5 and (max([x.size() for i,x in enumerate(self.buftokens) if self.bufs[i] in self.earlybufs]) <= 4 or not any(r for _,_,r in self.buftokens[self.full_buf_index].axis)): - self.upcast() - - def get_accumulators(self, name="acc") -> List[Token]: - assert self.reduceop is not None, "no accumulators if you aren't reducing" - should_upcast = self.lang.float4 and self.buftokens[0].can_float4() - accumulators = [Token(f"{name}{i//4}.{'xyzw'[i%4]}" if should_upcast else f"{name}{i}", self.buftokens[0].typ) for i in self.buftokens[0].offsets()] - if should_upcast: - self.kernel += [f"float4 {tok} = {self.group_float4([Token(GPUCodegen.start_for_op[self.reduceop.op], Types.FLOAT)]*4).tok};\n" for tok in dedup([x.tok.split('.')[0] for x in accumulators])] - else: - self.kernel += [f"float {x.tok} = {GPUCodegen.start_for_op[self.reduceop.op]};\n" for x in accumulators] - return accumulators - - # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. 
KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD - # group_for_reduce will have to be better first - def codegen(self) -> ASTRunner: - self.process() - if DEBUG >= 4: self.printbufs("old:", DEBUG>=5) - - self.hand_coded_optimizations() - - # fancy colored shape printer - if DEBUG >= 3: print(self.colorshape(), end="") - - # add a local buffer for multistage reduce - if len(self.group_for_reduce): - self.bufs.append(None) - # TODO: the strides of this can be controlled - st = ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.buftokens[0].axis])) - buftoken = Token("temp", Types.FLOAT, ptr=True) - # manual upcast of the local - for _,_,r in self.buftokens[0].axis[::-1]: - buftoken.array(st.shape[-1], st.views[-1].strides[-1], r) - st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) - self.sts.append(st) - self.buftokens.append(buftoken) - - self.output_shape: Tuple[int, ...] = self.sts[0].shape[:self.first_reduce] + tuple(self.group_for_reduce) - assert self.full_shape[:len(self.output_shape)] == self.output_shape, f"output shape mismatch : {self.full_shape[:len(self.output_shape)]} != {self.output_shape}" - if DEBUG >= 4: - print("output shape", self.output_shape) - self.printbufs("new:", DEBUG>=5) - - self.loaded_keys: Dict[Tuple[int,int], Token] = {} - self.prekernel: Set[str] = set() - self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else [] - - if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs if x is not None): self.prekernel.add(self.lang.half_prekernel+"\n") - - if len(self.lang.gid) == 0: - self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))] - else: - # output_shape[-1] is get_global_id(0) - self.kernel += [f"int idx{len(self.output_shape)-1-i} = {self.lang.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(len(self.lang.gid), len(self.output_shape))) if self.output_shape[-1-i] != 1] - if len(self.output_shape) > len(self.lang.gid): - # sometimes, there's more dimensions. 
compact all the dimensions into the first one - # TODO: these compactions should be searchable (they sort of are with reshapes and permutes) - final_dimension = len(self.output_shape)-len(self.lang.gid) - for i in range(final_dimension-1, -1, -1): - self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"] - self.output_shape = (prod(self.output_shape[0:final_dimension+1]), ) + self.output_shape[final_dimension+1:] - if DEBUG >= 4: print(f"replaced output shape with {self.output_shape}") - - # early ast - accumulators: List[Token] = [] - if self.reduceop is not None: - accumulators = self.get_accumulators() - self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)] - self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [accumulators[off] for off in self.buftokens[self.full_buf_index].acc_offsets()], do_reduce=True)] - self.kernel += ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce))) - - # second stage reduce - if self.group_for_reduce: - self.kernel.append(self.lang.smem_prefix + f"float {self.buftokens[-1].tok}[{self.sts[-1].size()*self.buftokens[-1].size()}];\n") - self.store(-1, accumulators) # TODO: this is assuming the local size = global size. should use lidxs - self.kernel.append(self.lang.barrier+"\n") - - # this is used to identify the thread doing the reducing (lidx == 0) and is repeated from store - # must happen before the upcast - lidx, lvalid = self.sts[-1].expr_idxs() - assert lvalid.min == 1, "local buffer must always be valid" - - # if any group_for_reduce items aren't reduces, upcast them here - for j in self.upcast_in_mid_reduce_axes: - self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j]) - self.upcast() - #if DEBUG >= 4: print("upcast", self.colorshape()) # NOTE: colorshape is wrong here, you have to remove it from group_for_reduce before calling - - self.kernel.append(f"if ({lidx.render(render_cl)} == 0) {{\n") # lidx.max works here too - - # second stage reduce with a new set of accumulators. TODO: do we need acc_offsets here? 
- accumulators = self.get_accumulators("output") - self.kernel.append(f"for (int mid = 0; mid < {self.sts[-1].size()}; mid++) {{\n") - self.kernel += [f"{x.tok};\n" for x in self.ast_parse(LazyOp(self.reduceop.op, (None,), self.sts[0].shape), accumulators, do_reduce=True)] - self.kernel.append("}\n") - - # late ast - self.store(0, self.ast_parse(self.ast, accumulators)) - if self.group_for_reduce: self.kernel.append("}") - if len(self.lang.gid) == 0: self.kernel += ["}"] * len(self.output_shape) - self.kernel.append("\n}") - - # concat kernel into prg - buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)] - prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + - [', '.join([f'{t} data{i}' for i,t in buftypes] + self.lang.extra_args)] + - [") {\n"] + self.kernel) - - # kernel function definition - function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape]) - - # painfully name the function - if prg in GPUCodegen.kernel_name_cache: function_name = GPUCodegen.kernel_name_cache[prg] - else: - GPUCodegen.kernel_cnt[function_name] += 1 - if GPUCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(GPUCodegen.kernel_cnt[function_name]-1)}" - GPUCodegen.kernel_name_cache[prg] = function_name - - return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), - list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1], - (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, - op_estimate=self.info.flops, - mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py new file mode 100644 index 0000000000..488593364a --- /dev/null +++ b/tinygrad/codegen/linearizer.py @@ -0,0 +1,411 @@ +from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict +import itertools, math +from collections import defaultdict +from enum import Enum, auto + +from tinygrad.helpers import dedup, colored, all_same, ImageDType, DEBUG, prod, dtypes, mnum +from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps +from tinygrad.lazy import LazyBuffer +from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps +from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape +from tinygrad.shape.symbolic import Variable, SumNode, ModNode + +class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); LOAD4 = auto(); STORE4 = auto() # noqa: E702 + +def get_first_reduce(shapes): + for i in range(len(shapes[0])): + if not all_same([x[i] for x in shapes]): return i + return len(shapes[0]) # off the end + +def check_no_mul(test, var): + if test == var: return True + if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay + if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay + return False + +class Register: + def __init__(self, name:str): + self.name = name + self.axis: 
List[Tuple[int, int, bool]] = [] + def array(self, length, stride, reduce): self.axis.append((length, stride, reduce)) + def size(self): return prod([x[0] for x in self.axis]) + def offsets(self): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.axis[::-1]])] if len(self.axis) else [0] + def can_float4(self): return any(a[0:2] == (4,1) for a in self.axis) + # TODO: this is sort of a hack, it gets the accumulator indices + def acc_offsets(self): + if len(self.axis) == 0: return [0] + acc_strides = [x*(1-self.axis[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.axis[::-1])))] + return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.axis[::-1])])] + def __repr__(self): return f"<{self.name}{f'{self.axis}'}>" + +class Linearizer: + supports_float4: bool = False + + def __init__(self, ast:LazyOp, output_buffer:LazyBuffer): + # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf + self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast + + # get the output buffers + self.bufs = [output_buffer] + dedup(get_buffers(ast)) + + # key for lookup in cache (can change, str might not be right) + # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels. + # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?) + self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}" + + def process(self) -> None: + if hasattr(self, "sts"): return # already processed + + # fetch lazyop info + self.info: FlopCounter = get_lazyop_info(self.ast) + self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None) + + # there's only allowed to be one reduceop + reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps] + assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" + self.reduceop = reduceops[0] if reduceops else None + + # get earlybufs, before the one reduce op + self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else [] + + # create new shapetrackers inside this kernel, we will permute them + self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] + for st in self.sts: st.simplify() + + # make the output buffer shape correct in here + self.sts[0].reshape(self.info.shape) + self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0 + + # move all reduce axes to the end + reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) + permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) + self.reshape_and_permute(None, permute) + + # group simplifies + self.simplify_ones() + self.simplify_merge_adjacent() + + # is this generic? 
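
For orientation, a small usage sketch (not part of the patch), assuming the new tinygrad/codegen/linearizer.py from this diff is importable: Register takes over the role of the old Token as the per-buffer record of upcasted axes. offsets() enumerates the element offsets a register covers, while acc_offsets() zeroes out reduce axes so all lanes of a reduce share one accumulator slot.

from tinygrad.codegen.linearizer import Register

r = Register("data0")
r.array(4, 1, False)                      # non-reduce axis of length 4, stride 1 -> float4-able
assert r.can_float4()
assert r.offsets() == [0, 1, 2, 3]        # element offsets covered by this register
assert r.acc_offsets() == [0, 1, 2, 3]    # non-reduce axis: each lane keeps its own accumulator

racc = Register("acc")
racc.array(4, 1, True)                    # same axis, but marked as a reduce axis
assert racc.acc_offsets() == [0, 0, 0, 0] # reduce axis: all four lanes map to one accumulator
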
+ self.registers = [Register(f"data{i}") for i in range(len(self.bufs))] + self.group_for_reduce: List[int] = [] + + def linearize(self): + # uops + self.uops: List[Tuple[UOps, Optional[str], Any]] = [] + + # add a local buffer for multistage reduce + if len(self.group_for_reduce): + self.bufs.append(None) + # TODO: the strides of this can be controlled + st = ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.registers[0].axis])) + buftoken = Register("temp") + # manual upcast of the local + for _,_,is_reduce in self.registers[0].axis[::-1]: + buftoken.array(st.shape[-1], st.views[-1].strides[-1], is_reduce) + st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) + self.sts.append(st) + self.registers.append(buftoken) + self.uop(UOps.DEFINE_LOCAL, (self.registers[-1].name, self.sts[-1].size()*self.registers[-1].size())) + + # TODO: add upcasting to float4 here + def global_buf(i, idxs, store=None): + should_upcast = self.supports_float4 and self.registers[i].can_float4() and (self.bufs[i] is None or self.bufs[i].dtype != dtypes.float16 or isinstance(self.bufs[i].dtype, ImageDType)) + cache: Dict[int, str] = {} + store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.registers[i].offsets())} # NOTE: for stores, these should be unique + def op(offset): + if offset in cache: return cache[offset] + will_merge = False + if should_upcast and offset%4 == 0: + float4_index = Variable("FLOAT4_INDEX", 0, 3) + idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs) + if DEBUG >= 4: print(f"attempting to fuse buf {i} :", check_no_mul(idxy_test, float4_index), idxy_test//4, valid_test//4) + # float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required? + will_merge = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render() + if store is not None: + if offset in store_offset: + if will_merge: + offsets = [] + for j in range(0, 4): + offsets.append(store[store_offset[offset+j]]) + del store_offset[offset+j] + self.uop(UOps.STORE4, (i, *self.sts[i].expr_idxs(offset, idxs), offsets)) + else: + self.uop(UOps.STORE, (i, *self.sts[i].expr_idxs(offset, idxs), store[store_offset[offset]])) + del store_offset[offset] + else: + reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, (i, *self.sts[i].expr_idxs(offset, idxs)), self.registers[i].name+"_"+mnum(offset)) + if will_merge: + for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j] + else: + cache[offset] = reg + return cache[offset] + return [op(o) for o in self.registers[i].offsets()] + + # parse AST + loaded_buffers = {} + acc = [] + + # ssa + _ssa:DefaultDict[str,int] = defaultdict(int) + def ssa(name): + _ssa[name] += 1 + return f"{name}{_ssa[name]-1}" + + # global loop + global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1 if i < self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] + self.uop(UOps.LOOP, (global_idxs, "global")) + + # local loop + if self.group_for_reduce: + # NOTE: this is assuming the global size = the local size in these dims. 
in general, this doesn't have to be true + local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] + self.uop(UOps.LOOP, (local_idxs, "local")) + gl_idxs = [x*(y.max+1)+y for x,y in zip(global_idxs, local_idxs)] + else: + # without local idxs, it's just the global idxs + gl_idxs = global_idxs + + # reduce op + if self.reduceop is not None: + # define accumulator + acc = [self.uop(UOps.CONST, ({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)],), ssa('acc')) for _ in self.registers[0].offsets()] + + # reduce loop + reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)] + self.uop(UOps.LOOP, (reduce_idxs, "reduce")) + + # load earlybufs + loaded_buffers.update({b:global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0}) + + # run early AST (with reduce) + self.ast_parse(self.reduceop, [acc[off] for off in self.registers[self.full_buf_index].acc_offsets()], loaded_buffers, ssa, do_reduce=True) + + # end the reduce loop + self.uop(UOps.ENDLOOP, (reduce_idxs, "reduce")) + + # end the local loop, do the local reduce + if self.group_for_reduce: + global_buf(-1, local_idxs, acc) # store accumulators + self.uop(UOps.ENDLOOP, (local_idxs, "local")) # this is a barrier on GPUs + + # if any group_for_reduce items aren't reduces, upcast them here + for j in self.upcast_in_mid_reduce_axes: + self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j]) + self.upcast() + self.group_for_reduce.pop() + + # NOTE: this structure is the same as the reduce op above + + # define late accumulator + acc = [self.uop(UOps.CONST, ({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)],), ssa('lacc')) for _ in self.registers[-1].offsets()] + + # late reduce loop + end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] + self.uop(UOps.LOOP, (end_local_idxs, "late_reduce")) + + # load localbufs + loaded_buffers["LOCAL_BUFFER"] = global_buf(-1, end_local_idxs) + + # there's no AST here (and there's no shape for the reduce LazyOp) + self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.registers[-1].acc_offsets()], loaded_buffers, ssa, do_reduce=True) + + # end the late reduce loop + self.uop(UOps.ENDLOOP, (end_local_idxs, "late_reduce")) + + # load latebufs + loaded_buffers.update({b:global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b is not None}) + + # run late AST + val = self.ast_parse(self.ast, acc, loaded_buffers, ssa) + + # store + global_buf(0, global_idxs, val) + + # end the global loop + self.uop(UOps.ENDLOOP, (global_idxs, "global")) + + # kernel function definition + self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape]) + + # print + if DEBUG >= 3: + self.printbufs() + for x in self.uops: + print(x) + + def uop(self, uop:UOps, arg:Any, name:Optional[str]=None): + self.uops.append((uop, name, arg)) + return name + + def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[str]: + if not isinstance(x, LazyOp): return loaded_buffers[x] + if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU 
op + if x.op in ReduceOps and not do_reduce: return acc + # MULACC fusion. TODO: this is copied from Interpreted + if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL: + x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg) + values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src] + if isinstance(x.op, (ReduceOps, FusedOps)): + return [self.uop(UOps.ALU, ({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op], val, val[0]), None) for val in zip(acc, *values)] + else: + return [self.uop(UOps.ALU, (x.op, val), ssa('alu')) for val in zip(*values)] + + @property + def first_reduce(self) -> int: return get_first_reduce([x.shape for i,x in enumerate(self.sts) if self.bufs[i] is not None]) + + @property + def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape + + @property + def shape_len(self) -> int: return len(self.sts[0].shape) + + @property + def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] + + def colorshape(self, pad=50) -> str: + axis = [(f"{rs:4d}", (("green" if i in self.upcast_in_mid_reduce_axes else "cyan") if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.full_shape)] + axis += [(f"{s:4d}", 'magenta' if reduce else 'yellow') for s, _, reduce in self.registers[self.full_buf_index].axis[::-1]] + return ' '.join([colored(*x) for x in axis])+(" "*(pad-len(' '.join([x[0] for x in axis])))) + + def printbufs(self, prefix=""): + for i in range(len(self.sts)): + print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s} {str(self.registers[i]):38s}", self.sts[i].views) + print(self.colorshape()) + + # ******************** base simplifiers ******************** + + def simplify_merge_adjacent(self): + if self.shape_len == 0: return + shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts] + + # merge dimensions if we can, multi get_shape_strides + # TODO: does this always preserve the reduce dimension, NO + # TODO: move this into shapetracker, with tests! + rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] + for i in range(1, len(shapes[0])): + can_merge = [] + for j in range(len(shapes)): + # TODO: added the always mergeability of 1s, is this right? 
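The MULACC rewrite in ast_parse above turns a SUM whose only source is a MUL into one fused multiply-accumulate node before code generation. A toy version over a hypothetical (op, src) tuple AST, with the node type invented for the example:

from collections import namedtuple

Op = namedtuple("Op", ["op", "src"])

def fuse_mulacc(node):
    # SUM(MUL(a, b)) -> MULACC(a, b): one fused op per accumulated element
    if node.op == "SUM" and isinstance(node.src[0], Op) and node.src[0].op == "MUL":
        return Op("MULACC", node.src[0].src)
    return node

assert fuse_mulacc(Op("SUM", (Op("MUL", ("a", "b")),))) == Op("MULACC", ("a", "b"))
assert fuse_mulacc(Op("SUM", ("a",))) == Op("SUM", ("a",))   # plain sums are untouched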
if so, add to shapetracker in the 1 case + can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) + # more can merge than this + mergeable = all(can_merge) and i != self.first_reduce + for j in range(len(shapes)): + if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) + else: rets[j].append((shapes[j][i], strides[j][i])) + + # do the reshapes + for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x)) + + def simplify_ones(self): + # remove places where the shape is all ones + # TODO: this should be factored in to multi shape stride + all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)] + # keep at least 1 one + if all(all_ones): all_ones[-1] = False + self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) + + # apply reshape and permute to all shapetrackers + def reshape_and_permute(self, new_shape_fxn, axis): + for st in self.sts: + if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape))) + if axis is not None: st.permute(tuple(axis)) + + # ******************** complex simplifiers ******************** + + # axis : the axis to pull from + # amount : the amount to take + # top : if you want to pull that amount from the top + # insert_before : place to insert the new stuff + def shift_to(self, axis, amount, top=False, insert_before=None): + if insert_before is None: insert_before = self.shape_len + move_axis = axis if top else axis+1 + if move_axis < insert_before: insert_before += 1 + self.reshape_and_permute( + lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), + [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) + + # drops the final dimension + def upcast(self): + upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1] + assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}" + for st,buftoken in zip(self.sts, self.registers): + # add last axis to the buftoken (if it's not a 1) + if st.shape[-1] == upcasted[0]: buftoken.array(st.shape[-1], st.views[-1].strides[-1], len(upcasted) != len(self.sts)) + # remove the last axis (unless it's the only dimension, then make it a 1) + st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset) + + # ******************** GPU simplifiers ******************** + + def required_optimizations(self, early_only=False): + for buf_index,buf in enumerate(self.bufs): + upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes] + if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType) and not (self.registers[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))): + axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1] + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}" + assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}" + self.shift_to(axes[0], 4) + self.upcast() + assert self.registers[buf_index].can_float4() + + def hand_coded_optimizations(self): + # if there's images in the earlybufs, we 
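The can_merge condition above collapses two adjacent view dimensions when the stride accumulated so far equals the inner shape times the inner stride (the pair is contiguous), or when both are broadcast with stride 0. The same rule applied to a single list of (shape, stride) pairs:

def merge_adjacent(dims):
    out = [dims[0]]
    for shape, stride in dims[1:]:
        prev_shape, prev_stride = out[-1]
        contiguous = stride != 0 and prev_stride == shape * stride
        both_broadcast = stride == 0 and prev_stride == 0
        if contiguous or both_broadcast:
            out[-1] = (prev_shape * shape, stride)   # fold into the outer dimension
        else:
            out.append((shape, stride))
    return out

assert merge_adjacent([(4, 5), (5, 1)]) == [(20, 1)]          # contiguous (4, 5) tensor
assert merge_adjacent([(5, 1), (4, 5)]) == [(5, 1), (4, 5)]   # permuted view keeps both axes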
have to make an axis the 4 loading one + self.required_optimizations(early_only=True) + + # simplify (sets first_reduce) + self.simplify_ones() + + # are we grouping? (requires local shape support) + if not self.registers[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: + # TODO: use 1024 if it's allowed in a smarter way + for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): + if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]): + self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce) + self.group_for_reduce.append(sz) + break + + # are we upcasting in mid reduce? (only for images) + if self.bufs[0].dtype.name.startswith('image') and not self.registers[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: + axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1] + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + if self.sts[0].shape[axes[0]]%4 == 0: + self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis + self.group_for_reduce.append(4) + + # now do everything required + self.required_optimizations() + + # simplify (sets first_reduce) + self.simplify_ones() + + # use more opencl indexing if the output buffer is an image and we have room + if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3: + base_shape = self.bufs[0].dtype.shape + if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0: + if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape) + self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None) + self.simplify_ones() + + # no more opt if we are grouping + if self.group_for_reduce: return + + # **** below this line need to be optional and benchmarked **** + + # potentially do more upcasts of non reduce axes based on a heuristic + while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: + xb_choices = [] + for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce + # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken + if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.registers[buf_index].axis) for buf_index in range(len(self.sts))): + xb_choices.append((sum(st.strides[axis]>0 for st in self.sts), sum(st.strides[axis] for st in self.sts), axis, upcast_amount)) + if len(xb_choices): + xb_choices = sorted(xb_choices) + if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") + self.shift_to(xb_choices[0][2], amount=xb_choices[0][3]) + self.upcast() + self.simplify_ones() + else: + break + + # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. 
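shift_to, which the grouping and mid-reduce upcasts above lean on, splits one axis of size s into (s//amount, amount), or (amount, s//amount) when pulling from the top, and then permutes the new axis to its destination. A sketch of just the splitting step on a plain shape tuple (the real method also special-cases size-1 axes instead of asserting):

def split_axis(shape, axis, amount, top=False):
    s = shape[axis]
    assert s % amount == 0, "axis must divide evenly"
    split = [amount, s // amount] if top else [s // amount, amount]
    return list(shape[:axis]) + split + list(shape[axis + 1:])

assert split_axis((8, 64, 3), 1, 4) == [8, 16, 4, 3]             # peel a 4 off the bottom
assert split_axis((8, 64, 3), 1, 16, top=True) == [8, 16, 4, 3]  # pull 16 from the top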
NOTE: careful, this has broken VALIDHACKS + if self.first_reduce < self.shape_len and self.full_shape[-1] <= 5 and (max([x.size() for i,x in enumerate(self.registers) if self.bufs[i] in self.earlybufs]) <= 4 or not any(r for _,_,r in self.registers[self.full_buf_index].axis)): + self.upcast() diff --git a/tinygrad/codegen/llvm.py b/tinygrad/codegen/llvm.py deleted file mode 100644 index 62d1c6bd4f..0000000000 --- a/tinygrad/codegen/llvm.py +++ /dev/null @@ -1,215 +0,0 @@ -import functools, math -from typing import ClassVar, List -from llvmlite import ir # type: ignore -from tinygrad.codegen.ast import ASTKernel -from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, ASTRunner -from tinygrad.helpers import DEBUG, prod, dtypes - -from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode -def int_const(x): return ir.Constant(ir.IntType(64), x) -render_llvm = { - Variable: lambda self,ops,ctx: self.expr, - NumNode: lambda self,ops,ctx: int_const(self.b), - MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)), - DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)), - ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)), - GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)), - LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)) -} - -class LLVMCodegen(ASTKernel): - op_lookup: ClassVar = { - UnaryOps.NOOP: lambda builder,x: x, - UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), - UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), - BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), - BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), - BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), - BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), - BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)), - BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()), - BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)) - } - start_for_op: ClassVar = { - ReduceOps.SUM: ir.Constant(ir.FloatType(), 0), - ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf) - } - - def codegen(self): - self.process() - if DEBUG >= 3: self.printbufs("old:", DEBUG>=4) - - # this stuff can't be hand coded - kernel_output_axis: List[int] = [] - """ - CACHE_DIM = 32 - if len(k.shapes[0]) == 2: - # cache tiling, makes permute fast - k.reshape_and_permute( - lambda shape: (shape[0]//CACHE_DIM, CACHE_DIM, shape[1]//CACHE_DIM, CACHE_DIM), - (0,2,1,3)) - elif len(k.shapes[0]) == 3: - if k.reduceop: - if k.strides[1][-1] == 1 and k.strides[2][-1] == 1: - DY, DX = 8, 8 - elif k.strides[1][-1] in [1,0] and k.strides[1][-2] in [1,0]: - DY, DX = 4, 16 - else: 
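Upcasting a small reduce axis, as the size <= 5 heuristic above does, is ordinary loop unrolling: the inner trip-count-n loop disappears and its body is repeated n times against the accumulator. A concrete before/after sketch for n = 3:

def reduce_rolled(rows):
    out = []
    for row in rows:
        acc = 0.0
        for x in row:                 # inner reduce loop, small trip count
            acc += x
        out.append(acc)
    return out

def reduce_unrolled3(rows):
    out = []
    for row in rows:                  # inner loop gone: the body is repeated 3 times
        out.append(row[0] + row[1] + row[2])
    return out

rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
assert reduce_rolled(rows) == reduce_unrolled3(rows) == [6.0, 15.0]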
- DY, DX = 16, 4 - # matmul: YyXxK -> YXKyx - k.reshape_and_permute( - lambda shape: (shape[0]//DY, DY, shape[1]//DX, DX, shape[2]), - (0,2,4,1,3)) - kernel_output_axis = [-2, -1] - else: - CACHE_L2_DIM = 256 - k.reshape_and_permute( - lambda shape: (shape[0], shape[1]//CACHE_L2_DIM, CACHE_L2_DIM, shape[2]), - (1,0,2,3)) - kernel_output_axis = [-1] - elif len(k.shapes[0]) == 7: - # conv: split chans and X - DY, DX = 4, 16 - k.reshape_and_permute( - lambda shape: (shape[0], shape[1]//DY, DY, shape[2], shape[3]//DX, DX, shape[4], shape[5], shape[6]), - (0,1,3,4,6,7,8,2,5)) - kernel_output_axis = [-2, -1] - """ - - # the 4x4 need to go all the way at the end, even after reduce - output_shape = self.sts[0].shape - full_shape_options = [x.shape for x in self.sts if x.shape != output_shape] - full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0] - - full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)] - kernel_output_dim = prod([self.sts[0].shape[a] for a in kernel_output_axis]) - kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim) - - def get_idxs(builder, idx, buf_index): - idx_offsets = [0] - """ - for axis in kernel_output_axis: - new_idx_offsets = [] - for s in range(k.shapes[buf_index][axis]): - for i in idx_offsets: - new_idx_offsets.append(i + s * k.strides[buf_index][axis]) - idx_offsets = new_idx_offsets - """ - return [builder.add(idx, int_const(i)) for i in idx_offsets] - - # *** llvm specific below this line *** - - # create llvm function - module = ir.Module(name=__file__) - func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType()}[buf.dtype] for buf in self.bufs] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') - - # force llvmlite to allow us to add function attribute then add the attribute - func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"'])) - func.attributes.add('"no-nans-fp-math"="true"') - - # construct the structure of the loops - loop_entry, loop_exit = [ir.IRBuilder(func.append_basic_block(name="entry"))], [] - for i,_ in enumerate(full_shape): loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}"))) - for i,_ in enumerate(full_shape): loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}"))) - loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit"))) - loop_exit = loop_exit[::-1] - - # add the buffer indexing - idx_level = [[int_const(st.offset)] for st in self.sts] - for i in range(len(full_shape)): - for j in range(len(self.bufs)): - # stride - si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}") - si.add_incoming(idx_level[j][-1], loop_entry[i]._block) - si_ps = loop_exit[i+1].add(si, int_const(self.sts[j].views[-1].strides[i])) - si.add_incoming(si_ps, loop_exit[i+1]._block) - idx_level[j].append(si) - - # the ast parser - def ast_parse(builder, x, level, reduce_result=None): - if not isinstance(x, LazyOp): - m = kernel_output_type(ir.Undefined) - buf_index = self.bufs.index(x) - for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)): - # first view is already implictly handled - idx, valid = self.sts[buf_index]._expr_idx(Variable(idx, 0, prod(self.sts[buf_index].shape))) - idx = idx.render(render_llvm, builder) - if valid.min == 0: - valid = valid.render(render_llvm, builder) - # this always does the load, 
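The phi/add pairs above keep one flat index per buffer per loop level: each iteration of a dimension advances that buffer's index by the buffer's stride for that dimension. A nested-loop sketch of the same bookkeeping for a single buffer, with illustrative helper names:

def iterate_flat_indices(shape, strides, offset=0):
    out = []
    def walk(level, idx):
        if level == len(shape):
            out.append(idx)
            return
        for _ in range(shape[level]):
            walk(level + 1, idx)
            idx += strides[level]     # the si_ps = si + stride step, carried to the next iteration
    walk(0, offset)
    return out

assert iterate_flat_indices((2, 3), (3, 1)) == [0, 1, 2, 3, 4, 5]   # contiguous (2, 3) view
assert iterate_flat_indices((2, 3), (1, 0)) == [0, 0, 0, 1, 1, 1]   # broadcast last axis (stride 0)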
so we have it load *0 if the arg won't be used - # TODO: would control flow be faster? - aug_idx = builder.select(valid, idx, int_const(0)) - element = builder.select(valid, builder.load(builder.gep(func.args[buf_index], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[buf_index], 0)) - else: - element = builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True)) - # upcast - if func_dtypes[buf_index] != ir.FloatType(): element = builder.fpext(element, ir.FloatType()) - m = element if kernel_output_dim == 1 else builder.insert_element(m, element, int_const(i)) - return m - if isinstance(x.op, ReduceOps): - if reduce_result is None: - raise RuntimeError("no reduce") - return reduce_result - values = [ast_parse(builder, v, level, reduce_result) for v in x.src] - - m = kernel_output_type(ir.Undefined) - if kernel_output_dim == 1: - return LLVMCodegen.op_lookup[x.op](builder, *values) - else: - # TODO: this only has to be done for certain ops - for i in range(kernel_output_dim): - value = [builder.extract_element(v, int_const(i)) for v in values] - element = LLVMCodegen.op_lookup[x.op](builder, *value) - m = builder.insert_element(m, element, int_const(i)) - return m - - # add the ast + final store - store_loop = output_shape.index(1) if 1 in output_shape else -1 - - # do the early ast - reduce_result = None - if self.reduceop: - reduce_input = ast_parse(loop_exit[-1], self.reduceop.src[0], -1) - phis = [LLVMCodegen.start_for_op[self.reduceop.op]] # type: ignore - if kernel_output_dim > 1: phis = [kernel_output_type(phis * kernel_output_dim)] - for i in range(store_loop+1, len(loop_entry)): - val = loop_entry[i].phi(kernel_output_type, f"reduce_phi_{i}") - val.add_incoming(phis[-1], loop_entry[i-1]._block) - phis.append(val) - - if self.reduceop.op == ReduceOps.SUM: - reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',)) - elif self.reduceop.op == ReduceOps.MAX: - reduce_result = loop_exit[-1].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',)) - - for i,phi in enumerate(phis[1:]): - phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block) - - # do the late ast - result = ast_parse(loop_exit[store_loop], self.ast, store_loop, reduce_result=reduce_result) - - # store result - builder = loop_exit[store_loop] - for i, idx in enumerate(get_idxs(builder, idx_level[0][store_loop], 0)): - element = result if kernel_output_dim == 1 else builder.extract_element(result, int_const(i)) - if func_dtypes[0] != ir.FloatType(): element = builder.fptrunc(element, func_dtypes[0]) - builder.store(element, builder.gep(func.args[0], [idx], inbounds=True)) - - # add the looping - for i,s in enumerate(full_shape): - loop_entry[i].branch(loop_entry[i+1]._block) - idx = loop_entry[i+1].phi(ir.IntType(64), name=f"loopvar_{i}") - idx.add_incoming(int_const(0), loop_entry[i]._block) - idx_p1 = loop_exit[i+1].add(idx, int_const(1)) - idx.add_incoming(idx_p1, loop_exit[i+1]._block) - loop_exit[i+1].cbranch(loop_exit[i+1].icmp_unsigned("==", idx_p1, int_const(s)), loop_exit[i]._block, loop_entry[i+1]._block) - - loop_entry[-1].branch(loop_exit[-1]._block) - loop_exit[0].ret_void() - - # TODO: mem_estimate is copied from GPU - return ASTRunner('exec', str(module), op_estimate=self.info.flops, - mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)) diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py new file mode 100644 index 
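Out-of-range reads are handled above without branches: the address is clamped to 0 so the load itself is always safe, then the loaded value is replaced by 0 whenever the validity expression is false (the new llvmir.py below keeps the same scheme). The same idea with plain Python values:

def masked_load(buf, idx, valid):
    safe_idx = idx if valid else 0        # aug_idx: keep the address in bounds
    value = buf[safe_idx]                 # the load always executes
    return value if valid else 0.0        # masked lanes read as zero padding

buf = [10.0, 11.0, 12.0]
assert masked_load(buf, 2, True) == 12.0
assert masked_load(buf, 99, False) == 0.0   # an invalid index never touches memory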
0000000000..512b96036d --- /dev/null +++ b/tinygrad/codegen/llvmir.py @@ -0,0 +1,103 @@ +from typing import Final, Dict, Callable, Any, List, Optional +import functools +from llvmlite import ir # type: ignore +from tinygrad.codegen.linearizer import Linearizer, UOps +from tinygrad.helpers import dtypes +from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps + +from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode +def int_const(x): return ir.Constant(ir.IntType(64), x) +render_llvm = { + NumNode: lambda self,ops,ctx: int_const(self.b), + MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)), + DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)), + ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)), + GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)), + LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)), + SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), + AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)) +} + +class LLVMIRCodegen(Linearizer): + code_for_op: Final[Dict[Op, Callable]] = { + UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), + UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), + BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), + BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), + BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), + BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), + BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)), + BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()), + BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), + FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)), + } + def codegen(self): + self.process() + # no optimize, this doesn't support local + self.linearize() + + # create llvm function + module = ir.Module(name=__file__) + func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType()}[buf.dtype] for buf in self.bufs] + func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') + + # force llvmlite to allow us to add function attribute then add the attribute + func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"'])) + func.attributes.add('"no-nans-fp-math"="true"') + + bb = [ir.IRBuilder(func.append_basic_block("entry"))] + loop_blocks = [] + reduce_phis: List = [] + # TODO: newvar probably shouldn't be optional + lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type + render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr] + + for uop,newvar,args in self.uops: + if uop == UOps.CONST: + lvars[newvar] = ir.Constant(ir.FloatType(), args[0]) + 
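render_llvm above (like render_python at the end of this diff) renders a symbolic expression tree by looking up a lambda keyed on the node's type and letting each lambda recurse into its children. A self-contained miniature of that dispatch pattern, with two node classes invented for the example:

class Num:                       # a constant leaf
    def __init__(self, b): self.b = b

class Sum:                       # an n-ary addition
    def __init__(self, nodes): self.nodes = nodes

def render(node, ops):
    return ops[type(node)](node, ops)

render_str = {
    Num: lambda self, ops: str(self.b),
    Sum: lambda self, ops: "(" + "+".join(render(n, ops) for n in self.nodes) + ")",
}

assert render(Sum([Num(1), Num(2), Num(3)]), render_str) == "(1+2+3)"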
reduce_phis.append(newvar) + if uop == UOps.LOOP: + for var in args[0]: + if isinstance(var, NumNode): continue + bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}"))) + bb[-2].branch(bb[-1]._block) + + phis = [] + for rp in reduce_phis: + incoming = lvars[rp] + lvars[rp] = bb[-1].phi(ir.FloatType()) + lvars[rp].add_incoming(incoming, bb[-2]._block) + phis.append((rp, lvars[rp])) + loop_blocks.append((bb[-1], phis)) + + lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr) + lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block) + if uop == UOps.ENDLOOP: + for var in args[0][::-1]: + if isinstance(var, NumNode): continue + block, phis = loop_blocks.pop() + idx_p1 = bb[-1].add(lvars[var.expr], int_const(1)) + lvars[var.expr].add_incoming(idx_p1, bb[-1]._block) + for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) + bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}"))) + bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block) + if uop == UOps.LOAD: + idx, valid = args[1].render(render_llvm, bb[-1]), args[2].render(render_llvm, bb[-1]) + if args[2].min == 0: + aug_idx = bb[-1].select(valid, idx, int_const(0)) + val= bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args[0]], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) + else: + val = bb[-1].load(bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) + if func_dtypes[args[0]] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType()) + lvars[newvar] = val + if uop == UOps.STORE: + assert args[2].min == 1, "store must be valid" + idx = args[1].render(render_llvm, bb[-1]) + element = lvars[args[3]] + if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0]) + bb[-1].store(element, bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) + if uop == UOps.ALU: + lvars[newvar if newvar is not None else args[2]] = self.code_for_op[args[0]](bb[-1], *[lvars[x] for x in args[1]]) + + bb[-1].ret_void() + return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=self.mem_estimate) diff --git a/tinygrad/nn/image.py b/tinygrad/nn/image.py index 34aa35108c..698e48bc2e 100644 --- a/tinygrad/nn/image.py +++ b/tinygrad/nn/image.py @@ -3,7 +3,7 @@ from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes from tinygrad.lazy import get_single_root FLOAT16 = getenv("FLOAT16", 0) -base_image_type = (100, 2, "image_half", np.float16) if FLOAT16 else (100, 4, "image_float", np.float32) +base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "imagef", np.float32) def image_dot(self, w): # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4b0da1554c..d7e485bd3b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -124,16 +124,6 @@ class Compiled: # all movementops do nothing in a Compiled buffer! 
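The loop above walks the flat list of (uop, newvar, args) records that linearize() built and turns each record into IR. A toy interpreter over a stream of the same shape, evaluating a single multiply-accumulate; the string op names and buffers here are made up for the example:

import math

def run(uops, bufs):
    env = {}
    for uop, newvar, args in uops:
        if uop == "CONST":
            env[newvar] = args[0]
        elif uop == "LOAD":
            buf, idx = args
            env[newvar] = bufs[buf][idx]
        elif uop == "ALU" and args[0] == "MULACC":
            acc, x, y = (env[name] for name in args[1])
            env[newvar] = acc + x * y          # fused multiply-accumulate
    return env

uops = [
    ("CONST", "acc0", (0.0,)),
    ("LOAD",  "val0", ("A", 1)),
    ("LOAD",  "val1", ("B", 1)),
    ("ALU",   "acc0", ("MULACC", ("acc0", "val0", "val1"))),
]
env = run(uops, {"A": [1.0, 3.0], "B": [2.0, 5.0]})
assert math.isclose(env["acc0"], 15.0)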
if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized - k = self.codegen(ast, output) - - # this is the default now - if getenv("ENABLE_METHOD_CACHE", 1): - if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime) - elif DEBUG >= 4: print(f"method cache hit : {k.key}") - prg = self.method_cache[k.key] - else: - prg = k.codegen().build(self.runtime) - # check if we can reuse the output buffer # if it's aliased, don't use it # NOTE: this is pretty wrong actually, who knows where else this buffer is used? @@ -149,5 +139,16 @@ class Compiled: if output.realized is None: output.realized = self.buffer(prod(output.shape), output.dtype) + # compilation time + k = self.codegen(ast, output) + + # this is the default now + if getenv("ENABLE_METHOD_CACHE", 1): + if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime) + elif DEBUG >= 4: print(f"method cache hit : {k.key}") + prg = self.method_cache[k.key] + else: + prg = k.codegen().build(self.runtime) + prg.exec(k.bufs) return output.realized diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 4a324eed2d..702ba56517 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,7 +1,7 @@ import os, time, ctypes, hashlib, subprocess, platform from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawMallocBuffer -from tinygrad.codegen.gpu import GPUCodegen, GPULanguage +from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage class ClangProgram: def __init__(self, name:str, prg:str): @@ -20,7 +20,8 @@ class ClangProgram: self.fxn(*[x._buf for x in args]) if wait: return time.monotonic()-st -class ClangCodegen(GPUCodegen): - lang = GPULanguage(buffer_suffix=" restrict") +class ClangCodegen(CStyleCodegen): + lang = CStyleLanguage(buffer_suffix=" restrict") + supports_float4: bool = False ClangBuffer = Compiled(RawMallocBuffer, ClangCodegen, ClangProgram) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 5763909362..1bb5f0ea60 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -6,7 +6,7 @@ from pycuda.compiler import compile as cuda_compile # type: ignore from tinygrad.helpers import DEBUG from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut -from tinygrad.codegen.gpu import GPUCodegen, GPULanguage +from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage class RawCUDABuffer(RawBufferCopyInOut): def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) @@ -38,8 +38,8 @@ class CUDAProgram: end.synchronize() return start.time_till(end)*1e-3 -class CUDACodegen(GPUCodegen): - lang = GPULanguage( +class CUDACodegen(CStyleCodegen): + lang = CStyleLanguage( kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", half_prekernel = "#include ", gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)], diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 86998eb307..f3e3c2cc6d 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -6,7 +6,7 @@ from typing import Optional, List from tinygrad.helpers import DEBUG, getenv, prod, ImageDType from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut -from tinygrad.codegen.gpu import GPUCodegen, 
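The block moved within ops.py is the method cache: a kernel is codegen'd and built once per k.key, and the compiled program is reused on later hits. A stripped-down sketch of a keyed compile cache like it (the key string and builder below are stand-ins):

method_cache = {}

def get_program(key, codegen_and_build):
    if key not in method_cache:
        method_cache[key] = codegen_and_build()   # compile only on a cache miss
    return method_cache[key]

built = []
build = lambda: built.append(1) or "compiled-prg"
assert get_program("r_64_64_64", build) == "compiled-prg"
assert get_program("r_64_64_64", build) == "compiled-prg"
assert len(built) == 1                            # the second call reused the cache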
GPULanguage +from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage OSX = platform.system() == "Darwin" OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something @@ -68,11 +68,11 @@ class CLProgram: return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9 return None -class CLCodegen(GPUCodegen): - lang = GPULanguage( +class CLCodegen(CStyleCodegen): + lang = CStyleLanguage( kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable", barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)", - gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)]) + gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True) GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index a3e739301c..3208f8624a 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -3,7 +3,7 @@ from typing import ClassVar from tinygrad.ops import Compiled from tinygrad.helpers import getenv, DEBUG from ctypes import CFUNCTYPE -from tinygrad.codegen.llvm import LLVMCodegen +from tinygrad.codegen.llvmir import LLVMIRCodegen from tinygrad.runtime.lib import RawMallocBuffer import llvmlite.binding as llvm # type: ignore @@ -62,4 +62,4 @@ class LLVMProgram: cfunc(*[x._buf for x in bufs]) if wait: return time.monotonic()-st -LLVMBuffer = Compiled(RawMallocBuffer, LLVMCodegen, LLVMProgram) +LLVMBuffer = Compiled(RawMallocBuffer, LLVMIRCodegen, LLVMProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index aac962af54..350bed338c 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -2,9 +2,8 @@ import os, subprocess, pathlib import Metal, Cocoa, libdispatch # type: ignore from typing import List, Any -from tinygrad.codegen.gpu import GPUCodegen, GPULanguage +from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage from tinygrad.helpers import prod, getenv, DEBUG, DType -#from tinygrad.ops import CompiledBuffer, Specialized from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferMapped @@ -74,8 +73,8 @@ class MetalProgram: else: METAL.mtl_buffers_in_flight.append(command_buffer) -class MetalCodegen(GPUCodegen): - lang = GPULanguage( +class MetalCodegen(CStyleCodegen): + lang = CStyleLanguage( kernel_prefix = "#include \nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)], diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index daa369ca8b..446a371e3b 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -42,7 +42,7 @@ class View: # generate an expression if you have a variable or expression for each index def expr_idxs(self, idxs, offset:Union[Node, int]=0): - return Variable.sum([Variable.num(self.offset)+offset] + [Variable(idx, 0, sh-1)*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0]) + return Variable.sum([Variable.num(self.offset)+offset] + [(idx if isinstance(idx, Variable) else Variable(idx, 0, sh-1))*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh 
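The expr_idxs change in shapetracker.py now accepts either ready-made Variables or bare index names, but the formula is unchanged: the view offset plus idx*stride summed over every axis whose shape is not 1 and whose stride is not 0. The same arithmetic with concrete integers:

def flat_index(idxs, shape, strides, offset=0):
    # offset + sum(idx * stride), skipping size-1 and broadcast (stride 0) axes
    return offset + sum(i * st for i, sh, st in zip(idxs, shape, strides)
                        if sh != 1 and st != 0)

assert flat_index((1, 0, 2), (2, 1, 3), (3, 0, 1)) == 5            # middle axis contributes nothing
assert flat_index((0, 0, 0), (2, 1, 3), (3, 0, 1), offset=7) == 7  # only the offset remains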
!= 1 and st != 0]) class ZeroView: def __init__(self, old_shape:Tuple[int, ...], arg): diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index ef30186c3e..4a0a14e817 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -12,7 +12,7 @@ class Node: max: int def render(self, ops=None, ctx=None) -> str: if ops is None: ops = render_python - assert isinstance(self, NumNode) or self.min != self.max + assert isinstance(self, (Variable, NumNode)) or self.min != self.max return ops[type(self)](self, ops, ctx) @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") @@ -38,6 +38,12 @@ class Node: assert b != 0 if b < 0: return (self//-b)*-1 if b == 1: return self + + # this is a hack to make div work with boolean nodes. TODO: make generic + if isinstance(self, GeNode): return (self.a//b) >= (self.b//b) + if isinstance(self, LtNode): return (self.a//b) < (self.b//b) + if isinstance(self, AndNode): return Variable.ands([x//b for x in self.nodes]) + if isinstance(self, ModNode) and self.b % b == 0: return (self.a//b) % (self.b//b) # put the div inside mod if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b) @@ -181,7 +187,7 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]): return create_node(ret) render_python: Dict[Type, Callable] = { - Variable: lambda self,ops,ctx: f"{self.expr}<{self.min},{self.max}>" if ctx == "DEBUG" else f"{self.expr}", + Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}", NumNode: lambda self,ops,ctx: f"{self.b}", MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})", DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
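The floor-division rules added to symbolic.py above push //b inside GeNode, LtNode and AndNode, which is what lets an index test like (idx*4 + lane) < bound collapse to idx < bound//4 once the float4 lane variable has been divided away. The code comment itself calls this a hack rather than a general identity; a brute-force check with illustrative numbers suggests why it is safe in the float4 case (lane in [0, 3], multiplier 4, bound a multiple of 4):

# ((idx*4 + lane) < 256) // 4  ~>  idx < 64, checked exhaustively over these ranges
assert all(((idx * 4 + lane) < 256) == (idx < 64)
           for idx in range(128) for lane in range(4))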