# PTX assembly renderer for tinygrad (source: https://github.com/tinygrad/tinygrad.git)

from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]

def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
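
# For illustration: the hex helpers emit the big-endian IEEE-754 bit pattern that PTX
# float immediates expect, e.g.
#   float_to_hex(1.0)  == "3F800000"          -> rendered below as 0f3F800000
#   double_to_hex(1.0) == "3FF0000000000000"  -> rendered below as 0d3FF0000000000000
# trunc_float(x, "f") round-trips x through float32, i.e. it rounds a Python float to
# the nearest representable float32 value.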

class AssemblyLanguage(NamedTuple):
  kernel_prefix: str = ""
  barrier: str = ""
  load_global: bool = False
  label_prefix: str = ""
  gid: List[str] = []
  gdim: List[str] = []
  lid: List[str] = []
  const_requires_mov: List[DType] = []  # list of dtypes for which creating a const requires a move
  asm_for_op: Dict[Op, Callable[...,str]] = {}
  types: Dict[DType, str] = INVERSE_DTYPES_DICT

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
  def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()

  def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
  def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()

  def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
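
# A concrete backend (e.g. PTXLanguage below) overrides these fields and render_* hooks.
# For illustration, an asm_for_op entry is a callable taking the destination register,
# the source registers, the dtype, and the backend's name for that dtype; with the PTX
# table below,
#   lang.asm_for_op[BinaryOps.ADD]("%acc_f32_0", "%acc_f32_0", "%val_f32_0", dtypes.float32, "f32")
# returns the instruction string "add.f32 %acc_f32_0, %acc_f32_0, %val_f32_0;".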

def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
  local_size: List[int] = []
  kernel:List[str] = []
  bufs = []

  def kk(*s: str): kernel.append("\n".join(s))

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t", dtype=None) -> str:
    nonlocal c, r
    prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
    c[prefix] += 1
    if u: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(u, prefix):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]
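
  # For illustration: ssa() names registers "%<prefix>_<dtype>_<counter>", so the first
  # ALU result of a float32 uop becomes "%alu_f32_0" and the next "%alu_f32_1";
  # ssa_label() similarly yields e.g. "$loop_0" with the PTX label_prefix "$".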

  def const(x:Union[float,int,bool], dtype, mov=False):
    if mov or dtype in lang.const_requires_mov:
      kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
      return out
    return lang.render_const(x, dtype)

  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
    return ret
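
  # For illustration, with the PTX backend: const(1.0, dtypes.float32) returns the inline
  # immediate "0f3F800000", while const(True, dtypes.bool) (bool is in const_requires_mov)
  # emits a setp into a fresh "%const_pred_N" register via render_const and returns that
  # register name instead.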

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.IF:
      assert vin[0].dtype is not None
      kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
    elif uop == UOps.END:
      if vin[0].uop == UOps.LOOP:
        # loop back-edge: increment the index, compare against the bound, and branch back while the predicate holds
        kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
           lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
        kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
      else: kk(f"{r_label[vin[0]]}:")
    elif uop == UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
      if len(vin) > 3:
        assert vin[3].dtype is not None
        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
    else:
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
      elif uop == UOps.ALU:
        assert vin[0].dtype is not None
        if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
          regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
          dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
          kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
        elif args == TernaryOps.MULACC:
          assert vin[1].dtype is not None
          kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype]))
        else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
      elif uop == UOps.SPECIAL:
        # an "i" special folds the group and local index into one register (mad.lo); "g"/"l" read lang.gid/lang.lid directly
        if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
                                 f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
                                 f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
        else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
        if args[1][0] == "l": local_size.append(args[2])
        r[u] = "%" + args[1]
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
      elif uop == UOps.LOAD:
        assert vin[1].dtype is not None
        val = ssa(u, 'val')
        if len(vin) > 3:
          assert vin[2].dtype is not None
          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]],
                            dtypes.uint, vin[1].dtype), dtype),
           *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
                             alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
      elif uop == UOps.PHI:
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop == UOps.CAST:
        assert vin[0].dtype is not None
        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
      elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop == UOps.DEFINE_GLOBAL:
        bufs.append((args, dtype))
        r[u] = f"%{args}"
        # kernel arguments arrive as .param entries and are loaded into registers up front
        if lang.load_global:
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
      else: raise NotImplementedError(f"no code for {uop}")

  return lang.render_kernel(kernel, function_name, bufs, c.items())

class PTXLanguage(AssemblyLanguage):
  kernel_prefix = """.version 7.8
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  has_pred = True
  load_global = True
  label_prefix = "$"
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  asm_for_op = {
    UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
    UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
    BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
    TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') +
                                                f".{name} {d}, {a}, {b}, {c};"),
    TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};"
  }
  supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
                   TernaryOps.MULACC, TernaryOps.WHERE]
  types = {
    dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
    dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
    dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64",
    dtypes.bool: "pred"
  }

  const_requires_mov = [dtypes.half, dtypes.bool]

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
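
  # For illustration: render_const(1.0, dtypes.float32) returns the inline immediate
  # "0f3F800000"; render_const(3, dtypes.uint32) returns "3U"; and
  # render_const(True, dtypes.bool, mov="%p0") returns ["setp.ne.s16 %p0, 1, 0;"].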

  def render_local(self, dest, name, size, dtype) -> List[str]:
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]

  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]:
    # this cast is only required because of ocelot
    if "s32" in offset:
      return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"]
    else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"]

  def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
    ret = []
    if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;")
    if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;")
    if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
                         f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"])
    else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];")
    if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;")
    if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};")
    return ret
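
  # For illustration: render_load("%loc_u64_0", "%val_f32_0", dtypes.float32) returns
  # ["ld.f32 %val_f32_0, [%loc_u64_0];"], and a gated load with gate="%pred_pred_0",
  # alt="0f00000000" returns the predicated pair
  # ["@%pred_pred_0 ld.f32 %val_f32_0, [%loc_u64_0];", "@!%pred_pred_0 mov.b32 %val_f32_0, 0f00000000;"].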

  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
    if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"]
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"]
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
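
  # For illustration: render_cast("%cast_f32_0", "%alu_s32_0", dtypes.float32, dtypes.int32)
  # returns ["cvt.rn.f32.s32 %cast_f32_0, %alu_s32_0;"], while a bool source goes through
  # selp and a bool destination through setp.ne instead of cvt.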

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
            "\n}")

PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
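
# Usage sketch (illustrative, names are hypothetical): PTXRenderer is uops_to_asm
# partially applied to a PTXLanguage instance, so given a linearized uop list it returns
# a complete PTX kernel as a string, e.g.
#   ptx_src = PTXRenderer("test_kernel", uops)  # uops: List[UOp] from tinygrad's Linearizer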