From 41ae7cb508380fb186fe844c384345ca5bf14531 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com>
Date: Tue, 15 Aug 2023 19:47:48 +0200
Subject: [PATCH] Triton Buffer

---
 tinygrad/runtime/ops_triton.py | 84 ++++++++++++++++++++++++++++++++--
 1 file changed, 81 insertions(+), 3 deletions(-)

diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py
index 920c4eb191..46566c15e6 100644
--- a/tinygrad/runtime/ops_triton.py
+++ b/tinygrad/runtime/ops_triton.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 import hashlib
+import math
 from weakref import WeakValueDictionary
 from torch import float32
 import numpy as np
@@ -8,12 +9,16 @@
 import pycuda.driver as cuda # type: ignore
 import triton # type: ignore # noqa: F401
 import triton.language as tl # type: ignore # noqa: F401
+from triton.compiler import compile as triton_compile
 
-from typing import Union, Tuple, Optional, Dict
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters
+from typing import Any, Union, Tuple, Optional, Dict, List, Final, Callable
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters, Compiled
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.helpers import prod, DEBUG
+from tinygrad.helpers import prod, DEBUG, dtypes, ImageDType
 from tinygrad.runtime.ops_gpu import CLBuffer
+from tinygrad.runtime.lib import RawBuffer
+from tinygrad.codegen.linearizer import LinearizerOptions, UOp, UOps, LocalBuffer
+from tinygrad.shape.symbolic import NumNode
 
 
 class TritonProgram:
@@ -35,3 +40,76 @@ class TritonDeviceAllocation(CLBuffer):
     self.dtype = float32
 
   def data_ptr(self): return int(self._cl)
+
+def uops_to_triton(function_name:str, uops:List[UOp]):
+  kernel = []
+  global_size: List[int] = []
+  local_size: List[int] = []
+  depth = 1
+  bufs = []
+  def kk(s): kernel.append(" "*depth+s)
+
+  full_local_shape: Tuple[Any, ...] = ()
+  acc_local_shape = 1
+
+  gid = [f"tl.program_id({i})" for i in range(3)]
+  code_for_op: Final[Dict[Op, Callable]] = {
+    UnaryOps.EXP2: lambda x: f"tl.math.exp2({x})",
+    UnaryOps.LOG2: lambda x: f"tl.math.log2({x})", # TODO: is fast_log2f ok?
+    UnaryOps.SIN: lambda x: f"tl.sin({x})",
+    UnaryOps.SQRT: lambda x: f"tl.sqrt({x})",
+    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})",
+    BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
+    BinaryOps.MAX: lambda x,y: f"tl.maximum({x},{y})", # axis?
+    BinaryOps.CMPLT: lambda x,y: f"({x}<{y})",
+    ReduceOps.SUM: lambda x: f"tl.expand_dims(tl.sum({x}, axis={len(full_local_shape)-len(self.group_for_reduce)}), axis={len(full_local_shape)-len(self.group_for_reduce)})" if len(self.group_for_reduce) != len(full_local_shape) else f"tl.sum({x}, axis={len(full_local_shape)-len(self.group_for_reduce)})",
+  }
+
+  for uop,newvar,vin,args in uops:
+    if uop == UOps.LOOP:
+      for i,var in enumerate(args[0]):
+        if isinstance(var, NumNode): continue # python doesnt have block scope
+        else:
+          if args[1] == "global":
+            global_size.append(var.max+1)
+            kk(f"{var.expr} = {gid[i]} # {var.max+1}")
+          elif args[1] == "local":
+            full_local_shape = tuple([var.max+1 for var in args[0]])
+            assert var.min == 0, "local loop must start at 0"
+            kk(f"{var.expr} = tl.arange({0}, {var.max+1})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
+            acc_local_shape *= var.max+1
+          else:
+            kk(f"for {var.expr} in range({var.min}, {var.max+1}):")
+            depth += 1
+    elif uop == UOps.ENDLOOP:
+      if args[1] not in ["global", "local"] and len(args[0]):
+        depth -= 1
+        kk(f"# end {args[1]}")
+    elif uop == UOps.ALU:
+      assert newvar is not None
+      kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])}")
+    elif uop == UOps.LOAD:
+      assert newvar is not None
+      val = f"{args.name} + {args.idx.render()}" # defaults to render_python
+      triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype]
+      if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val}).to({triton_dtype})")
+      else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({val}, mask={args.valid.render()}), 0.0).to({triton_dtype})")
+    elif uop == UOps.STORE:
+      assert vin[0].dtype == dtypes.float, "unimplemented: float4 store"
+      assert not isinstance(args.memory_dtype, ImageDType), "unimplemented: image store"
+      assert args.valid.min == 1, "store must be valid"
+      kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()})")
+    elif uop == UOps.DEFINE_GLOBAL:
+      bufs.append(args)
+    elif uop == UOps.CAST: raise NotImplementedError("unimplemented: cast")
+    else:
+      raise NotImplementedError(f"unimplemented: {uop}")
+
+  prg = "import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n"
+  # TODO @triton.jit \n
+  prg += "def fxn("+','.join(f"data{i}" for i in range(len(bufs)))+"):\n"
+  prg += '\n'.join(kernel)
+  if DEBUG >= 4: print(prg)
+  return prg, global_size, local_size
+
+TritonBuffer = Compiled(TritonDeviceAllocation, LinearizerOptions(), uops_to_triton, TritonProgram)
\ No newline at end of file
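
Reviewer note, not part of the patch: a minimal sketch of how the generated source might be compiled and launched. uops_to_triton returns plain Python/Triton source for a function fxn (the @triton.jit decoration is left as a TODO in the generated preamble, and the imported triton_compile is not used yet), while TritonDeviceAllocation exposes data_ptr() and a torch dtype, which is what Triton's launcher inspects when binding pointer arguments. The names uops and buffers, the temporary-module step, and the grid padding below are assumptions.

# Sketch only, assumed usage -- not code from this patch.
import importlib.util, tempfile
import triton

src, global_size, _local_size = uops_to_triton("fxn", uops)  # uops: assumed output of tinygrad's Linearizer

# triton.jit inspects real source, so materialize the generated code as a module first.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
  f.write(src)
  path = f.name
spec = importlib.util.spec_from_file_location("triton_generated", path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

kernel = triton.jit(mod.fxn)                               # the decoration step the generated source leaves as a TODO
grid = tuple(global_size) + (1,) * (3 - len(global_size))  # Triton grids have at most 3 dimensions
kernel[grid](*buffers)                                     # buffers: one TritonDeviceAllocation per DEFINE_GLOBAL (assumed)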