mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add renderer class (#4524)
* add renderer class * tests pass * fix pylint * fix tensor cores
This commit is contained in:
49
tinygrad/renderer/__init__.py
Normal file
49
tinygrad/renderer/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Optional, List, Tuple, Dict
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import to_function_name
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.shape.symbolic import sym_infer, sint, Variable
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Program:
|
||||
name:str
|
||||
src:str
|
||||
dname:str
|
||||
global_size:Optional[List[int]]=None
|
||||
local_size:Optional[List[int]]=None
|
||||
uops:Optional[UOpGraph]=None
|
||||
op_estimate:sint=0
|
||||
mem_estimate:sint=0
|
||||
|
||||
@functools.cached_property
|
||||
def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
|
||||
|
||||
@functools.cached_property
|
||||
def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
|
||||
|
||||
@functools.cached_property
|
||||
def outcount(self) -> int: return sum(x[1] for x in self.globals)
|
||||
|
||||
@functools.cached_property
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
def launch_dims(self, var_vals:Dict[Variable, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
return global_size, local_size
|
||||
|
||||
class Renderer:
|
||||
device: str = ""
|
||||
suffix: str = ""
|
||||
# TODO: make this generic with a list of supported types
|
||||
supports_float4: bool = True
|
||||
has_local: bool = True
|
||||
has_shared: bool = True
|
||||
has_tensor_cores: bool = False
|
||||
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
|
||||
global_max: Optional[List[int]] = None
|
||||
local_max: Optional[List[int]] = None
|
||||
shared_max: int = 32768
|
||||
|
||||
def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
|
||||
@@ -1,15 +1,16 @@
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
|
||||
import functools, struct, copy
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct, copy
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def render_val(x, dtype):
|
||||
if dtypes.is_float(dtype):
|
||||
if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
|
||||
if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
|
||||
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
@@ -33,191 +34,16 @@ def ptr_ar(root, uops):
|
||||
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
class AssemblyLanguage(NamedTuple):
|
||||
kernel_prefix: str = ""
|
||||
barrier: str = ""
|
||||
load_global: bool = False
|
||||
label_prefix: str = ""
|
||||
gid: List[str] = []
|
||||
gdim: List[str] = []
|
||||
lid: List[str] = []
|
||||
const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
|
||||
asm_for_op: Dict[Op, Callable[...,str]] = {}
|
||||
types: Dict[DType, str] = INVERSE_DTYPES_DICT
|
||||
supports_half: List[Op] = []
|
||||
class PTXRenderer(Renderer):
|
||||
device = "CUDA"
|
||||
suffix = "PTX"
|
||||
global_max=[65535, 65535, 2147483647]
|
||||
local_max=[64, 1024, 1024]
|
||||
shared_max=49152
|
||||
has_tensor_cores = False
|
||||
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
|
||||
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
|
||||
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
|
||||
|
||||
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
|
||||
def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
|
||||
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
|
||||
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
|
||||
|
||||
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
|
||||
def mem_type(self, dtype) -> str: raise NotImplementedError()
|
||||
|
||||
def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str:
|
||||
# editing the uops breaks beam search
|
||||
uops = copy.deepcopy(_uops)
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
|
||||
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
|
||||
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
|
||||
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
|
||||
])
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
# all uops must be a register
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.optimize_loops()
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, Union[List[str], str]] = {}
|
||||
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
|
||||
nonlocal c, r
|
||||
prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
|
||||
c[prefix] += 1
|
||||
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
return f"%{prefix}{c[prefix]-1}"
|
||||
|
||||
c_label: DefaultDict[str, int] = defaultdict(int)
|
||||
r_label: Dict[UOp, str] = {}
|
||||
def ssa_label(prefix:str, u:UOp):
|
||||
nonlocal c_label, r_label
|
||||
c_label[prefix] += 1
|
||||
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
return r_label[u]
|
||||
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in lang.const_requires_mov:
|
||||
kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
|
||||
return out
|
||||
return lang.render_const(x, dtype)
|
||||
|
||||
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
|
||||
if atype == dtype:
|
||||
if u: r[u] = a
|
||||
return a
|
||||
kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
|
||||
return ret
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
if vin[2].dtype.count > 1:
|
||||
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
|
||||
f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
|
||||
else:
|
||||
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
# pass in the other dtype here
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
|
||||
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
|
||||
r[u] = "%" + args[1]
|
||||
kernel = [f".reg .u32 %{args[1]};"] + kernel
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
else:
|
||||
kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", dtype))
|
||||
r[u] = f"%{nm}"
|
||||
if lang.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = r[vin[2]]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return lang.render_kernel(kernel, function_name, bufs, c.items())
|
||||
|
||||
class PTXLanguage(AssemblyLanguage):
|
||||
# language options
|
||||
kernel_prefix = """.version VERSION
|
||||
.target TARGET
|
||||
.address_size 64
|
||||
@@ -229,7 +55,7 @@ class PTXLanguage(AssemblyLanguage):
|
||||
gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
|
||||
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
|
||||
asm_for_op = {
|
||||
asm_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
@@ -245,13 +71,14 @@ class PTXLanguage(AssemblyLanguage):
|
||||
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
|
||||
TernaryOps.WHERE]
|
||||
# HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
|
||||
types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
|
||||
types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
|
||||
|
||||
const_requires_mov = [dtypes.half, dtypes.bool]
|
||||
const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
|
||||
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
|
||||
val = render_val(x, dtype)
|
||||
@@ -270,7 +97,7 @@ class PTXLanguage(AssemblyLanguage):
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
|
||||
assert dtype is not dtypes.bool
|
||||
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
|
||||
else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
|
||||
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
|
||||
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
|
||||
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
|
||||
@@ -291,4 +118,160 @@ class PTXLanguage(AssemblyLanguage):
|
||||
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
|
||||
"\n}")
|
||||
|
||||
PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
|
||||
def render(self, name:str, _uops:UOpGraph) -> str:
|
||||
# editing the uops breaks beam search
|
||||
uops = copy.deepcopy(_uops)
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
|
||||
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
|
||||
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in self.asm_for_op.keys() if op not in self.supports_half],
|
||||
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
|
||||
])
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
# all uops must be a register
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.optimize_loops()
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, Union[List[str], str]] = {}
|
||||
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
|
||||
nonlocal c, r
|
||||
prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
|
||||
c[prefix] += 1
|
||||
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
return f"%{prefix}{c[prefix]-1}"
|
||||
|
||||
c_label: DefaultDict[str, int] = defaultdict(int)
|
||||
r_label: Dict[UOp, str] = {}
|
||||
def ssa_label(prefix:str, u:UOp):
|
||||
nonlocal c_label, r_label
|
||||
c_label[prefix] += 1
|
||||
r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
return r_label[u]
|
||||
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in self.const_requires_mov:
|
||||
kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
|
||||
return out
|
||||
return self.render_const(x, dtype)
|
||||
|
||||
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
|
||||
if atype == dtype:
|
||||
if u: r[u] = a
|
||||
return a
|
||||
kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
|
||||
return ret
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
if vin[2].dtype.count > 1:
|
||||
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
|
||||
f"st{u.arg}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
|
||||
else:
|
||||
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
# pass in the other dtype here
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
|
||||
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
|
||||
r[u] = "%" + args[1]
|
||||
kernel = [f".reg .u32 %{args[1]};"] + kernel
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
else:
|
||||
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", dtype))
|
||||
r[u] = f"%{nm}"
|
||||
if self.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = r[vin[2]]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import os, math
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import strip_parens, getenv, prod
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
class CStyleLanguage(Renderer):
|
||||
kernel_prefix: str = ""
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
@@ -17,8 +18,6 @@ class CStyleLanguage(NamedTuple):
|
||||
arg_int_prefix: str = "const int"
|
||||
barrier: str = ""
|
||||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
global_max: List[int] = []
|
||||
local_max: List[int] = []
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
uses_vload: bool = False
|
||||
@@ -88,100 +87,107 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
|
||||
def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
|
||||
def ssa(prefix:str, u:Optional[UOp]=None):
|
||||
nonlocal c, r
|
||||
ret = f"{prefix}{c[prefix]}"
|
||||
if u is not None: r[u] = ret
|
||||
c[prefix] += 1
|
||||
return ret
|
||||
def ssa(prefix:str, u:Optional[UOp]=None):
|
||||
nonlocal c, r
|
||||
ret = f"{prefix}{c[prefix]}"
|
||||
if u is not None: r[u] = ret
|
||||
c[prefix] += 1
|
||||
return ret
|
||||
|
||||
child_count = Counter(v for ru in uops for v in ru.vin)
|
||||
child_count = Counter(v for ru in uops for v in ru.vin)
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
# these four uops don't have output dtypes
|
||||
if uop is UOps.IF:
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.BARRIER: kk(lang.barrier)
|
||||
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
rendered_store = lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
# these four uops don't have output dtypes
|
||||
if uop is UOps.IF:
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
|
||||
else: operands = [r[v] for v in vin]
|
||||
val = lang.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.LOAD:
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa('precast')
|
||||
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
val = lang.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
elif uop is UOps.BARRIER: kk(self.barrier)
|
||||
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
|
||||
else: operands = [r[v] for v in vin]
|
||||
val = self.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa('precast')
|
||||
kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
val = self.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(self.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, uops)
|
||||
return self.render_kernel(name, kernel, bufs, uops)
|
||||
|
||||
class ClangLanguage(CStyleLanguage):
|
||||
class ClangRenderer(CStyleLanguage):
|
||||
device = "CLANG"
|
||||
supports_float4 = False
|
||||
has_local = False
|
||||
|
||||
# language options
|
||||
buffer_suffix = " restrict"
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
|
||||
def ClangRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(ClangLanguage(), name, uops)
|
||||
|
||||
class OpenCLLanguage(CStyleLanguage):
|
||||
class OpenCLRenderer(CStyleLanguage):
|
||||
device = "GPU"
|
||||
|
||||
# language options
|
||||
kernel_prefix = "__kernel "
|
||||
buffer_prefix = "__global "
|
||||
smem_align = "__attribute__ ((aligned (16))) "
|
||||
@@ -197,9 +203,13 @@ class OpenCLLanguage(CStyleLanguage):
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
def OpenCLRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(OpenCLLanguage(), name, uops)
|
||||
|
||||
class MetalLanguage(CStyleLanguage):
|
||||
class MetalRenderer(CStyleLanguage):
|
||||
device = "METAL"
|
||||
has_tensor_cores=os.uname().machine == "arm64"
|
||||
shared_max=32768
|
||||
|
||||
# language options
|
||||
kernel_prefix = "kernel "
|
||||
buffer_prefix = "device "
|
||||
smem_prefix = "threadgroup "
|
||||
@@ -227,7 +237,6 @@ class MetalLanguage(CStyleLanguage):
|
||||
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
|
||||
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
def MetalRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(MetalLanguage(), name, uops)
|
||||
|
||||
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
@@ -240,7 +249,15 @@ def _make_cuda_dtype(base_type, name, cnt):
|
||||
vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
|
||||
return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
|
||||
|
||||
class CUDALanguage(CStyleLanguage):
|
||||
class CUDARenderer(CStyleLanguage):
|
||||
device = "CUDA"
|
||||
global_max=[65535, 65535, 2147483647]
|
||||
local_max=[64, 1024, 1024]
|
||||
shared_max=49152
|
||||
has_tensor_cores = False
|
||||
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
|
||||
|
||||
# language options
|
||||
kernel_prefix = "extern \"C\" __global__ "
|
||||
smem_prefix = "__shared__ "
|
||||
smem_prefix_for_cast = False
|
||||
@@ -271,7 +288,6 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
|
||||
return c;}}""")
|
||||
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
def CUDARenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(CUDALanguage(), name, uops)
|
||||
|
||||
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
@@ -295,7 +311,12 @@ def _make_hip_dtype(base_type, name, cnt):
|
||||
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
|
||||
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
|
||||
|
||||
class HIPLanguage(CStyleLanguage):
|
||||
class HIPRenderer(CStyleLanguage):
|
||||
device = "HSA"
|
||||
has_tensor_cores = True
|
||||
shared_max = 65536
|
||||
|
||||
# language options
|
||||
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
|
||||
@@ -357,5 +378,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
|
||||
|
||||
def HIPRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(HIPLanguage(), name, uops)
|
||||
|
||||
@@ -4,6 +4,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
|
||||
@@ -65,88 +66,95 @@ def cast(bb, val, input_type, output_type, bitcast=False):
|
||||
|
||||
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
|
||||
|
||||
def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
class LLVMRenderer(Renderer):
|
||||
device = "LLVM"
|
||||
supports_float4=False
|
||||
has_local=False
|
||||
has_shared=False
|
||||
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
|
||||
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
func.attributes.add('"no-nans-fp-math"="true"')
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
# TODO: newvar probably shouldn't be optional
|
||||
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
||||
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
func.attributes.add('"no-nans-fp-math"="true"')
|
||||
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
# TODO: newvar probably shouldn't be optional
|
||||
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(lvars[vin[3]]): store_op()
|
||||
else: store_op()
|
||||
elif uop is UOps.ENDLOOP:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
|
||||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2].block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(vin) > 2:
|
||||
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(lvars[vin[3]]):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
lvars[u] = val
|
||||
elif uop is UOps.PHI:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
while backward.uop is UOps.PHI: backward = backward.vin[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
elif uop is UOps.ENDLOOP:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module)
|
||||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2].block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(vin) > 2:
|
||||
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
lvars[u] = val
|
||||
elif uop is UOps.PHI:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
while backward.uop is UOps.PHI: backward = backward.vin[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module)
|
||||
|
||||
Reference in New Issue
Block a user