# tinygrad/tinygrad/renderer/assembly.py
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
from tinygrad.codegen.uops import UOpGraph, PatternMatcher

def render_val(x, dtype):
  if dtypes.is_float(dtype):
    if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
    elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
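
# Worked examples of the PTX immediate syntax render_val emits (floats are rendered as the
# big-endian hex of their IEEE-754 bits, so the value round-trips exactly):
#   render_val(1.0, dtypes.float32) -> "0f3F800000"
#   render_val(1.0, dtypes.half)    -> "0x3C00"
#   render_val(1.0, dtypes.double)  -> "0d3FF0000000000000"
#   render_val(5,   dtypes.uint32)  -> "5U"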

class AssemblyLanguage(NamedTuple):
  kernel_prefix: str = ""
  barrier: str = ""
  load_global: bool = False
  label_prefix: str = ""
  gid: List[str] = []
  gdim: List[str] = []
  lid: List[str] = []
  const_requires_mov: List[DType] = []  # list of dtypes for which creating a const requires a move
  asm_for_op: Dict[Op, Callable[...,str]] = {}
  types: Dict[DType, str] = INVERSE_DTYPES_DICT
  supports_half: List[Op] = []

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
  def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
  def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
  def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
  def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
  def mem_type(self, dtype) -> str: raise NotImplementedError()
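
# A concrete backend (PTXLanguage below) overrides these tables and render_* methods;
# uops_to_asm is written purely against this interface, so another assembly target would
# only need its own AssemblyLanguage subclass.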

def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
  kernel:List[str] = []
  bufs = []

  def ptr_ar(root, uops):
    assert root.arg in {'.shared', '.global', None}
    if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global'  # move this to the arg
    val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
    if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
    else:
      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
      root.vin = (fptr, zero) + root.vin[2:]
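  # ptr_ar turns the (buffer, element-index) pair of a LOAD/STORE into something PTX can
  # address directly: the index is scaled by itemsize to a byte offset, and if that offset
  # does not fold to a CONST it is added into a u64 pointer up front, roughly
  #   LOAD(buf, idx)  ->  LOAD(buf + idx*itemsize, 0)
  # so the emitter can always use the [reg+imm] addressing form.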
  matcher = PatternMatcher([
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
     lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
     lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
    *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
       lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
      for op in lang.asm_for_op.keys() if op not in lang.supports_half],
    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
     lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool, "vin": ({},{})},
     lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
  ])
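  # In words: bool CMPEQ becomes NEG(XOR(a,b)) and bool CMPLT becomes NEG(x)*y, half ops the
  # target lacks are computed in float32 and cast back, and bool loads/stores are routed
  # through 8-bit integers, since PTX predicate registers can't be moved to/from memory.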

  # here we do a pretransform on UOps to fix some shortcomings of PTX
  # all uops must be a register
  replace: Dict[UOp, UOp] = {}
  seen: Set[UOp] = set()
  for u in uops:
    if u in seen: continue
    seen.add(u)
    for o,n in replace.items():
      if o in u.vin and u is not n:
        u.vin = tuple(n if x == o else x for x in u.vin)
    if rew := matcher.rewrite(u): replace[u] = rew
  for o,n in replace.items():
    queue = [n]
    while queue:
      if all([qq in uops.uops for qq in queue[-1].vin]):
        q = queue.pop()
        new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
        for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
      else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
    uops.uops.remove(o)
  for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
  uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
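  # remove_childless prunes uops whose results are never consumed, keeping the listed
  # side-effecting roots (stores, control-flow ends, global defines) and their dependencies.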

  def kk(*s: str): kernel.append("\n".join(s))

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, Union[List[str], str]] = {}
  def ssa(u, prefix="t", dtype=None) -> str:
    nonlocal c, r
    prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
    c[prefix] += 1
    if u: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(u, prefix):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]

  def const(x:Union[float,int,bool], dtype, mov=False):
    if mov or dtype in lang.const_requires_mov:
      kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
      return out
    return lang.render_const(x, dtype)

  def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
    return ret
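  # ssa hands out fresh virtual registers named %<prefix>_<type>_<n>: the first float32 ALU
  # result is "%alu_f32_0", the next "%alu_f32_1", and so on. The per-prefix counts in `c`
  # later become the .reg declarations that render_kernel prepends to the body.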

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop is UOps.IF:
      assert vin[0].dtype is not None
      kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
    elif uop is UOps.ENDLOOP:
      kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
         lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
      kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
    elif uop is UOps.ENDIF:
      kk(f"{r_label[vin[0]]}:")
    elif uop is UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      if vin[2].dtype.count > 1:
        kk((f"@{r[vin[3]]} " if len(vin)>3 else "") +
           f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
      else:
        kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
    else:
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
      elif uop is UOps.ALU:
        assert vin[0].dtype is not None
        if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
          # pass in the other dtype here
          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
        else:
          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
      elif uop is UOps.DEFINE_ACC:
        if dtype.count > 1:
          r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
          for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
        else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
      elif uop is UOps.SPECIAL:
        assert args[1][0] != "i", "idx not supported"
        kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        r[u] = "%" + args[1]
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop is UOps.CONST:
        if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
        else: r[u] = const(args, dtype, mov=True)
      elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
      elif uop is UOps.LOAD:
        assert vin[1].dtype is not None
        if dtype.count > 1:
          r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
          if len(vin) > 3:
            for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
          kk((f"@{r[vin[2]]}" if len(vin) > 3 else "")
             + f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
        else:
          kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
                               alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
      elif uop is UOps.PHI:
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop in {UOps.CAST, UOps.BITCAST}:
        assert vin[0].dtype is not None
        if dtype.count>1: r[u] = [r[x] for x in vin]  # type: ignore
        else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
      elif uop is UOps.DEFINE_LOCAL:
        # TODO: we should sum these, and fetch 0xC000 from somewhere
        assert args[1]*dtype.itemsize <= 0xC000, "too large local"
        kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop is UOps.DEFINE_VAR:
        bufs.append((args.expr, dtype))
        r[u] = f"%{args.expr}"
        if lang.load_global: kk(*lang.render_load(args.expr, ssa(u, 'dat', dtype=lang.types[dtype]), dtype, ss=".param"))
      elif uop is UOps.DEFINE_GLOBAL:
        bufs.append((args[1], dtype))
        r[u] = f"%{args[1]}"
        if lang.load_global:
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
      else: raise NotImplementedError(f"no code for {uop}")

  return lang.render_kernel(kernel, function_name, bufs, c.items())

class PTXLanguage(AssemblyLanguage):
  kernel_prefix = """.version VERSION
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  has_pred = True
  load_global = True
  label_prefix = "$"
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  asm_for_op = {
    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
    BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
  }
  supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
  types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
            dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
            dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
  const_requires_mov = [dtypes.half, dtypes.bool]

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
    val = render_val(x, dtype)
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]

  def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
    assert dtype is not dtypes.bool
    if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
    else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]

  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
            "\n}")
PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
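
# Minimal usage sketch. This is illustrative rather than taken from this file: the real
# call site is tinygrad's device layer, and the Linearizer details here are assumptions.
#   from tinygrad.codegen.linearizer import Linearizer
#   lin = Linearizer(*ast)                        # ast: scheduled kernel op(s)
#   lin.linearize()                               # populates lin.uops (a UOpGraph)
#   ptx_src = PTXRenderer("my_kernel", lin.uops)  # returns PTX source as a str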