from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math
import re

triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "fp64", dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}

def next_power_of_2(x):
  return 1 << (x - 1).bit_length()
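# e.g. next_power_of_2(5) == 8 and next_power_of_2(8) == 8; Triton tensor shapes must be powers of two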
def render_valid(valid):
  return '(' * (len(valid)-1) + ') and '.join(valid) if len(valid) else 'True'
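# e.g. render_valid(["gidx0<16", "lidx1<3"]) -> "(gidx0<16) and lidx1<3"; an empty list renders "True"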
#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx
def get_max(var):
  if isinstance(var, int): return var
  return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]
#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
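# e.g. the PTX line "mov.u32 %r1, {%r2};" becomes "mov.u32 %r1, %r2;" (gpuocelot chokes on braces around a single scalar)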
def render_const(args,dtype:DType):
  return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
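# e.g. render_const(2.0, dtypes.float32) -> "2.0" and render_const(2.0, dtypes.int32) -> "2";
# inf/nan go through the tl.where(1, x, 0) wrapper so they materialize as values inside the kernel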
def render_cast(x:str, dtype:DType, bitcast=False):
  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"
def define_scalar(local_size, dtype, args):
  if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  return render_const(args,dtype)
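# e.g. define_scalar([3], dtypes.float32, 0.0) -> "tl.full((4,),0.0, dtype=tl.float32)": accumulators are
# materialized at the power-of-two padded local shape, while define_scalar([], ...) is just a constant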
def uops_to_triton(function_name:str, uops:List[UOp]):
  local_size: List[int] = []
  depth = 1
  signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore
  # ssa names each uop's result in the generated source, numbering per prefix (val0, alu0, alu1, ...)
  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t"):
    nonlocal c, r
    c[prefix] += 1
    r[u]=f"{prefix}{c[prefix]-1}"
    return r[u]
  # count how many uops consume each value: single-use results can be inlined instead of named
  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for ru in uops:
    for v in ru.vin:
      child_count[v] += 1
  def kk(s): kernel.append(" "*depth+s)
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP2: lambda x,dtype: f"tl.math.exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"tl.math.log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
    UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
    UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
    BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,dtype: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,dtype: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
    BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
    BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
    BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
    TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
    TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  }
  def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
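  # e.g. code_for_op[BinaryOps.MUL]("val0", "val1", dtypes.float32) -> "(val0*val1)", and
  # int_div("idx0", "0") renders a nan/inf select rather than dividing by zero at codegen time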
  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.LOOP:
      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
      depth += 1
    elif uop == UOps.END: depth -= 1
    elif uop == UOps.ALU:
      assert dtype is not None
      val = code_for_op[args](*[r[x] for x in vin], dtype)
      # inline single-use and integer expressions; only multi-use float values get a named temporary
      if child_count[u] <= 1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
      else: kk(f"{ssa(u, 'alu')} = ({val})")
    elif uop == UOps.LOAD:
      assert dtype is not None
      # 2 inputs is a plain load; a 3rd input is a gate that selects between the load and 0.0
      if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + {fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
      else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)}, mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
    elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
    elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
    elif uop == UOps.PHI:
      kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
      r[u] = r[vin[0]]
    elif uop == UOps.STORE:
      assert not isinstance(dtype, ImageDType), "unimplemented: image store"
      kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)})")
    elif uop == UOps.DEFINE_GLOBAL:
      # args is just the buffer name; the buffer's dtype is carried on the uop itself
      bufs.append(args)
      signatures.append(("*" if isinstance(dtype, PtrDType) else "") + signature_dtypes[dtype])
      r[u] = args
    elif uop == UOps.SPECIAL:
      dims.append(args[1])
      valid.append(f"{args[1]}<{get_max(args[2])}")
      # "g*" dims come from the launch grid via tl.program_id; "l*" dims become an in-kernel tl.arange
      if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
      elif args[1].startswith("l"):
        kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
        local_size.append(args[2])
      r[u] = args[1]
|
|
elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
|
|
else: raise NotImplementedError(f"unimplemented: {uop}")
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
  # append an index like [:, None] to each tl.arange line so every local dim is broadcast to the full rank
  for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  prg += "\n".join(kernel)
  acc_local_size = 1
  for x in local_size: acc_local_size *= next_power_of_2(x)
  # collapse the power-of-two padded local dims into a single flat dimension for the launch
  local_size = [acc_local_size] + [1] * (len(local_size) - 1)
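  # e.g. local_size [3, 5] pads to [4, 8] and flattens to [32, 1]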
  if DEBUG >= 4: print(prg)
  # register the generated source under the "<triton>" filename so triton's inspect-based JIT can read it
  getlines = linecache.getlines
  linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122
  compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  # clamp to the .maxntid declared in the compiled PTX, which may be smaller than what was requested
  max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
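# A minimal sketch of how a runtime might consume this renderer ("r_16_16" and uops are hypothetical;
# uops would come from a Linearizer):
#   src, meta = uops_to_triton("r_16_16", uops)
#   # src is PTX text; meta carries {"shared": ..., "local_size": [x, y, z]} for the kernel launch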