"""Assembly (PTX) renderer: turns a linearized list of UOps into kernel source.

`AssemblyLanguage` is the abstract interface; `PTXLanguage` implements it for
NVIDIA PTX. `uops_to_asm` is the driver that walks the UOps and calls the
language's render_* hooks. `PTXRenderer` at the bottom is the public entry.
"""
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT

# Big-endian hex of a float32/float64 bit pattern (struct packs little-endian "f"/"d",
# then [::-1] reverses the bytes). Used for PTX 0f.../0d... float literals below.
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
# Round-trip x through the given struct format to truncate it to that precision.
def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]

def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

class AssemblyLanguage(NamedTuple):
  """Describes one assembly dialect: syntax tables plus render_* code emitters.

  Subclasses (e.g. PTXLanguage) override the fields/methods as class attributes.
  """
  kernel_prefix: str = ""           # text emitted before the kernel signature
  barrier: str = ""                 # barrier instruction; empty means no barrier support
  load_global: bool = False         # if True, buffer args must be explicitly loaded (.param space)
  label_prefix: str = ""            # prefix for branch labels
  gid: List[str] = []               # per-axis group/block id registers
  gdim: List[str] = []              # per-axis grid dimension registers
  lid: List[str] = []               # per-axis local/thread id registers
  const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
  asm_for_op: Dict[Op, Callable[...,str]] = {}  # per-op instruction formatters
  types: Dict[DType, str] = INVERSE_DTYPES_DICT # dtype -> assembly type-name

  # Abstract emitters; each returns the instruction line(s) to append to the kernel.
  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
  def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
  def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
  def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
  def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()

def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
  """Render the linearized `uops` into a complete kernel string using `lang`."""
  local_size: List[int] = []
  kernel:List[str] = []   # emitted instruction groups, joined/indented at the end
  bufs = []               # (name, dtype) of kernel buffer arguments

  def kk(*s: str): kernel.append("\n".join(s))

  # SSA register naming: c counts registers per "<prefix>_<type>_" family,
  # r maps each UOp to the register holding its value.
  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t", dtype=None) -> str:
    """Allocate a fresh %<prefix>_<type>_<n> register; record it for u if given."""
    nonlocal c, r
    prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
    c[prefix] += 1
    if u: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  # Same bookkeeping for branch labels (loops, ifs).
  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(u, prefix):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]

  def const(x:Union[float,int,bool], dtype, mov=False):
    """Render constant x. Emits a mov into a fresh register when required,
    returning the register name; otherwise returns the immediate string."""
    if mov or dtype in lang.const_requires_mov:
      kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
      return out
    return lang.render_const(x, dtype)

  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    """Cast register a from atype to dtype; no-op (aliases the register) when equal."""
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
    return ret

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.IF:
      assert vin[0].dtype is not None
      # branch on the condition cast to bool; fall through into the _true label
      kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
    elif uop == UOps.END:
      if vin[0].uop == UOps.LOOP:
        # close a loop: increment the index, compare against the bound, branch back
        kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
           lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
        kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
      else: kk(f"{r_label[vin[0]]}:")  # close an IF: just place its label
    elif uop == UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      # vin = [buffer, index, value, (optional gate)]
      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
      if len(vin) > 3:
        assert vin[3].dtype is not None
        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
    else:
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
      elif uop == UOps.ALU:
        assert vin[0].dtype is not None
        if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
          # PTX setp can't compare predicates directly: widen bool operands to s16 first
          regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
          dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
          kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
        elif args == TernaryOps.MULACC:
          assert vin[1].dtype is not None
          # MULACC uses the multiplicand's type name (vin[1]) for the instruction suffix
          kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype]))
        else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
      elif uop == UOps.SPECIAL:
        # args = (axis, name, size); name[0] distinguishes "i" (global = gid*gdim+lid),
        # "g" (group id) and "l" (local id)
        if args[1][0] == "i":
          kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
             f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
             f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
        else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
        if args[1][0] == "l": local_size.append(args[2])
        r[u] = "%" + args[1]
        # the named register must be declared before any use: prepend its .reg
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
      elif uop == UOps.LOAD:
        assert vin[1].dtype is not None
        val = ssa(u, 'val')
        # vin = [buffer, index, (optional gate, alt value for the gated-off path)]
        if len(vin) > 3:
          assert vin[2].dtype is not None
          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]], dtypes.uint, vin[1].dtype), dtype),
           *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None, alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
      elif uop == UOps.PHI:
        # a PHI is a plain register copy; the result aliases the accumulator register
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop == UOps.CAST:
        assert vin[0].dtype is not None
        # args may be a (dtype, bitcast?) tuple; args[1] selects a bit-pattern cast
        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
      elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop == UOps.DEFINE_GLOBAL:
        bufs.append((args, dtype))
        r[u] = f"%{args}"
        if lang.load_global:
          # pointer args are loaded as 64-bit addresses
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
      else: raise NotImplementedError(f"no code for {uop}")

  # c.items() gives the per-family register counts for the .reg declarations
  return lang.render_kernel(kernel, function_name, bufs, c.items())

class PTXLanguage(AssemblyLanguage):
  """AssemblyLanguage for NVIDIA PTX ISA 7.8 (TARGET is substituted later)."""
  kernel_prefix = """.version 7.8
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  has_pred = True
  load_global = True
  label_prefix = "$"
  # x/y/z axis registers: chr(120+i) -> 'x','y','z'
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  asm_for_op = {
    UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    # predicate add is logical-or; integer mul needs .lo to keep the low half
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
    BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
    # mad.wide when accumulator type differs from operand type (checked via the
    # type tag embedded in the SSA register name, e.g. "%alu_s32_0".split('_')[1])
    TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') + f".{name} {d}, {a}, {b}, {c};"),
    TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};"
  }
  supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.MULACC, TernaryOps.WHERE]
  # dtype -> PTX register type; note 8-bit ints are widened to 16-bit registers
  types = {
    dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
    dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
    dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64",
    dtypes.bool: "pred" }

  const_requires_mov = [dtypes.half, dtypes.bool]

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
    """PTX immediate for x, or mov/setp/cvt lines when a register (mov) is given."""
    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    # half consts go through an f32 temp then cvt.rn.f16.f32
    if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]

  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]:
    # this cast is only required because of ocelot
    if "s32" in offset:
      return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"]
    else:
      return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"]

  def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
    """Load from [loc] into dest; 1-byte dtypes go through an s8 temp, bools
    additionally through an s16 temp + setp; gated loads mov alt when !gate."""
    ret = []
    if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;")
    if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;")
    if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
                         f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"])
    else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];")
    if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;")
    if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};")
    return ret

  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
    # bools are widened to s16 before the store (predicates aren't storable)
    if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"]
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    """cvt/selp/setp/mov lines converting a:atype to d:dtype."""
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"]
    # rounding modifier: .rzi for float->int, .rn where the conversion can lose precision
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    """Assemble the final PTX: .reg declarations + body + ret, wrapped in the
    entry signature with one .param per buffer."""
    # regs entries look like ("alu_f32_", count): [-2] recovers the reg type name
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    # labels (starting with '$') are not indented; opcodes get tab alignment
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) + "\n}")

PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())