add renderer class (#4524)

* add renderer class

* tests pass

* fix pylint

* fix tensor cores
This commit is contained in:
George Hotz
2024-05-10 21:40:02 -07:00
committed by GitHub
parent b00b6b16f0
commit 347a3acb37
31 changed files with 536 additions and 527 deletions

View File

@@ -0,0 +1,49 @@
from typing import Optional, List, Tuple, Dict
import functools
from dataclasses import dataclass
from tinygrad.helpers import to_function_name
from tinygrad.codegen.uops import UOpGraph
from tinygrad.shape.symbolic import sym_infer, sint, Variable
@dataclass(frozen=True)
class Program:
name:str
src:str
dname:str
global_size:Optional[List[int]]=None
local_size:Optional[List[int]]=None
uops:Optional[UOpGraph]=None
op_estimate:sint=0
mem_estimate:sint=0
@functools.cached_property
def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
@functools.cached_property
def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
@functools.cached_property
def outcount(self) -> int: return sum(x[1] for x in self.globals)
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
def launch_dims(self, var_vals:Dict[Variable, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
class Renderer:
device: str = ""
suffix: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
has_local: bool = True
has_shared: bool = True
has_tensor_cores: bool = False
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None
shared_max: int = 32768
def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -1,15 +1,16 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
import functools, struct, copy
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct, copy
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
from tinygrad.renderer import Renderer
def render_val(x, dtype):
if dtypes.is_float(dtype):
if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
@@ -33,191 +34,16 @@ def ptr_ar(root, uops):
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
class AssemblyLanguage(NamedTuple):
kernel_prefix: str = ""
barrier: str = ""
load_global: bool = False
label_prefix: str = ""
gid: List[str] = []
gdim: List[str] = []
lid: List[str] = []
const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
asm_for_op: Dict[Op, Callable[...,str]] = {}
types: Dict[DType, str] = INVERSE_DTYPES_DICT
supports_half: List[Op] = []
class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
def mem_type(self, dtype) -> str: raise NotImplementedError()
def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str:
# editing the uops breaks beam search
uops = copy.deepcopy(_uops)
kernel:List[str] = []
bufs = []
matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
# all uops must be a register
matcher.rewrite_graph(uops)
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
def kk(*s: str): kernel.append("\n".join(s))
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, Union[List[str], str]] = {}
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
nonlocal c, r
prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
c[prefix] += 1
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
return f"%{prefix}{c[prefix]-1}"
c_label: DefaultDict[str, int] = defaultdict(int)
r_label: Dict[UOp, str] = {}
def ssa_label(prefix:str, u:UOp):
nonlocal c_label, r_label
c_label[prefix] += 1
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
return r_label[u]
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in lang.const_requires_mov:
kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
return out
return lang.render_const(x, dtype)
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype:
if u: r[u] = a
return a
kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
return ret
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
elif uop is UOps.ENDLOOP:
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
elif uop is UOps.ENDIF:
kk(f"{r_label[vin[0]]}:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
if vin[2].dtype.count > 1:
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
else:
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
elif uop is UOps.ALU:
assert vin[0].dtype is not None
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
# pass in the other dtype here
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
else:
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
elif uop is UOps.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
elif uop is UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
r[u] = "%" + args[1]
kernel = [f".reg .u32 %{args[1]};"] + kernel
elif uop is UOps.CONST:
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
else: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
elif uop is UOps.LOAD:
assert vin[1].dtype is not None
if dtype.count > 1:
r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
elif uop is UOps.PHI:
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", dtype))
r[u] = f"%{nm}"
if lang.load_global:
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
wmma = []
for vv in vin[:2]:
for i in range(0, len(r[vv]), 2):
wmma.append(ssa("wmma", dtype="b32"))
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
r[u] = r[vin[2]]
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
else: raise NotImplementedError(f"no code for {uop}")
return lang.render_kernel(kernel, function_name, bufs, c.items())
class PTXLanguage(AssemblyLanguage):
# language options
kernel_prefix = """.version VERSION
.target TARGET
.address_size 64
@@ -229,7 +55,7 @@ class PTXLanguage(AssemblyLanguage):
gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
asm_for_op = {
asm_for_op: Dict[Op, Callable] = {
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
@@ -245,13 +71,14 @@ class PTXLanguage(AssemblyLanguage):
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
TernaryOps.WHERE]
# HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
const_requires_mov = [dtypes.half, dtypes.bool]
const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
val = render_val(x, dtype)
@@ -270,7 +97,7 @@ class PTXLanguage(AssemblyLanguage):
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
assert dtype is not dtypes.bool
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
@@ -291,4 +118,160 @@ class PTXLanguage(AssemblyLanguage):
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
"\n}")
PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
def render(self, name:str, _uops:UOpGraph) -> str:
# editing the uops breaks beam search
uops = copy.deepcopy(_uops)
kernel:List[str] = []
bufs = []
matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
for op in self.asm_for_op.keys() if op not in self.supports_half],
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
# all uops must be a register
matcher.rewrite_graph(uops)
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
def kk(*s: str): kernel.append("\n".join(s))
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, Union[List[str], str]] = {}
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
nonlocal c, r
prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
c[prefix] += 1
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
return f"%{prefix}{c[prefix]-1}"
c_label: DefaultDict[str, int] = defaultdict(int)
r_label: Dict[UOp, str] = {}
def ssa_label(prefix:str, u:UOp):
nonlocal c_label, r_label
c_label[prefix] += 1
r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
return r_label[u]
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in self.const_requires_mov:
kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
return out
return self.render_const(x, dtype)
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype:
if u: r[u] = a
return a
kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
return ret
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDLOOP:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
elif uop is UOps.ENDIF:
kk(f"{r_label[vin[0]]}:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
if vin[2].dtype.count > 1:
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
f"st{u.arg}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
else:
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
elif uop is UOps.ALU:
assert vin[0].dtype is not None
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
# pass in the other dtype here
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
else:
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
elif uop is UOps.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
elif uop is UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
r[u] = "%" + args[1]
kernel = [f".reg .u32 %{args[1]};"] + kernel
elif uop is UOps.CONST:
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
else: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
elif uop is UOps.LOAD:
assert vin[1].dtype is not None
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
elif uop is UOps.PHI:
kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", dtype))
r[u] = f"%{nm}"
if self.load_global:
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
wmma = []
for vv in vin[:2]:
for i in range(0, len(r[vv]), 2):
wmma.append(ssa("wmma", dtype="b32"))
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
r[u] = r[vin[2]]
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
else: raise NotImplementedError(f"no code for {uop}")
return self.render_kernel(kernel, name, bufs, c.items())

View File

@@ -1,13 +1,14 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
import math
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
import os, math
from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph
from tinygrad.renderer import Renderer
class CStyleLanguage(NamedTuple):
class CStyleLanguage(Renderer):
kernel_prefix: str = ""
buffer_prefix: str = ""
buffer_suffix: str = ""
@@ -17,8 +18,6 @@ class CStyleLanguage(NamedTuple):
arg_int_prefix: str = "const int"
barrier: str = ""
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
global_max: List[int] = []
local_max: List[int] = []
extra_args: List[str] = []
float4: Optional[str] = None
uses_vload: bool = False
@@ -88,100 +87,107 @@ class CStyleLanguage(NamedTuple):
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str:
kernel = []
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
depth = 1
def kk(s): kernel.append(" "*depth+s)
def render(self, name:str, uops:UOpGraph) -> str:
kernel = []
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
c[prefix] += 1
return ret
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
c[prefix] += 1
return ret
child_count = Counter(v for ru in uops for v in ru.vin)
child_count = Counter(v for ru in uops for v in ru.vin)
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
depth += 1
elif uop is UOps.BARRIER: kk(lang.barrier)
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
else: operands = [r[v] for v in vin]
val = lang.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa('precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = lang.render_cast([precast], dtype, bitcast=True)
else:
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
elif uop is UOps.BARRIER: kk(self.barrier)
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
else: operands = [r[v] for v in vin]
val = self.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa('precast')
kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = self.render_cast([precast], dtype, bitcast=True)
else:
val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
if child_count[u] <= 1: r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(self.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
return lang.render_kernel(function_name, kernel, bufs, uops)
return self.render_kernel(name, kernel, bufs, uops)
class ClangLanguage(CStyleLanguage):
class ClangRenderer(CStyleLanguage):
device = "CLANG"
supports_float4 = False
has_local = False
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
def ClangRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(ClangLanguage(), name, uops)
class OpenCLLanguage(CStyleLanguage):
class OpenCLRenderer(CStyleLanguage):
device = "GPU"
# language options
kernel_prefix = "__kernel "
buffer_prefix = "__global "
smem_align = "__attribute__ ((aligned (16))) "
@@ -197,9 +203,13 @@ class OpenCLLanguage(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def OpenCLRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(OpenCLLanguage(), name, uops)
class MetalLanguage(CStyleLanguage):
class MetalRenderer(CStyleLanguage):
device = "METAL"
has_tensor_cores=os.uname().machine == "arm64"
shared_max=32768
# language options
kernel_prefix = "kernel "
buffer_prefix = "device "
smem_prefix = "threadgroup "
@@ -227,7 +237,6 @@ class MetalLanguage(CStyleLanguage):
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def MetalRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(MetalLanguage(), name, uops)
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
@@ -240,7 +249,15 @@ def _make_cuda_dtype(base_type, name, cnt):
vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
class CUDALanguage(CStyleLanguage):
class CUDARenderer(CStyleLanguage):
device = "CUDA"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
# language options
kernel_prefix = "extern \"C\" __global__ "
smem_prefix = "__shared__ "
smem_prefix_for_cast = False
@@ -271,7 +288,6 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
return c;}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
def CUDARenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(CUDALanguage(), name, uops)
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
@@ -295,7 +311,12 @@ def _make_hip_dtype(base_type, name, cnt):
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class HIPLanguage(CStyleLanguage):
class HIPRenderer(CStyleLanguage):
device = "HSA"
has_tensor_cores = True
shared_max = 65536
# language options
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
@@ -357,5 +378,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
def HIPRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(HIPLanguage(), name, uops)

View File

@@ -4,6 +4,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOpGraph
from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -65,88 +66,95 @@ def cast(bb, val, input_type, output_type, bitcast=False):
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
class LLVMRenderer(Renderer):
device = "LLVM"
supports_float4=False
has_local=False
has_shared=False
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
def render(self, name:str, uops:UOpGraph) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
for bufname,dtype in buf_to_dtype.items():
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
if len(vin) > 3:
with bb[-1].if_then(lvars[vin[3]]): store_op()
else: store_op()
elif uop is UOps.ENDLOOP:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
for bufname,dtype in buf_to_dtype.items():
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2].block)
phis.append((rp, lvars[rp]))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
if len(vin) > 2:
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
if len(vin) > 3:
with bb[-1].if_then(lvars[vin[3]]):
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
else:
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
lvars[u] = val
elif uop is UOps.PHI:
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
elif uop is UOps.ENDLOOP:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
bb[-1].ret_void()
return str(module)
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2].block)
phis.append((rp, lvars[rp]))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
if len(vin) > 2:
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
else:
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
lvars[u] = val
elif uop is UOps.PHI:
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
bb[-1].ret_void()
return str(module)