PTX - implement float4, pointer arithmetic and other speed improvements (#3775)

* ptx float4 implementation

* remove from cache when trimming uops

* Gate for float4

* Linting fix

* disable test reasonable time for ptx

* import getenv

* Update uops.py

* linter

* Add div test for half

* upcast if op does not support operation

* fix offset

* Run only if dtype supported

* zero out registers when accessing by pred + cleanup

* Remove trailing whitespace

* revert

* spacing fix

* move cache clearing outside loop

* did this suddenly start working?

* unused import removed

* Remove cast

* Use pattern matching

* linting

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Szymon Ożóg
Date: 2024-03-22 16:54:02 +01:00
Committed by: GitHub
Parent: f4055439dc
Commit: 624bc89910
3 changed files with 118 additions and 67 deletions
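
The headline change is vectorized memory access: when the linearizer upcasts to float4, the PTX renderer now keeps one scalar register per lane in a list and emits a single .v4 load or store with the byte offset folded into the address operand, instead of four separate scalar accesses. A rough Python sketch of the store string that path builds (register names and the offset are invented for illustration; the corresponding renderer change is in the STORE and LOAD branches of the second file below):

dtype_count, mem_t = 4, "f32"
regs = ["%val0", "%val1", "%val2", "%val3"]   # one register per float4 lane
base, byte_off = "%rd0", 16                   # base pointer register and constant byte offset
# one vectorized store replaces four scalar st.global.f32 instructions
print(f"st.global.v{dtype_count}.{mem_t} [{base}+{byte_off}], {{{', '.join(regs)}}};")
# -> st.global.v4.f32 [%rd0+16], {%val0, %val1, %val2, %val3};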

View File

@@ -170,6 +170,7 @@ class UOpGraph:
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
self.saved_exprs = {k:v for k,v in self.saved_exprs.items() if v in nu}
# optional
def type_verify(self):
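
The single added line above keeps the common-subexpression cache in step with the dead-uop pass: once self.uops is replaced by the trimmed list nu, any cached entry whose value was eliminated must be dropped, otherwise a later add() could hand back a uop that is no longer in the graph. A toy illustration of that invariant with stand-in values (not the real UOpGraph API):

uops = ["load_a", "mul_ab", "unused_add"]                        # pretend uop list
saved_exprs = {("MUL", "a", "b"): "mul_ab", ("ADD", "a", "b"): "unused_add"}
nu = [u for u in uops if u != "unused_add"]                      # dead-uop elimination keeps only live uops
saved_exprs = {k: v for k, v in saved_exprs.items() if v in nu}  # same filter as the diff line
assert all(v in nu for v in saved_exprs.values())                # cache never points at a removed uop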

View File

@@ -1,14 +1,17 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
from tinygrad.codegen.uops import UOpGraph
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]
def render_val(x, dtype):
if dtypes.is_float(dtype):
if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
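
A quick worked example of the new render_val helper: the float paths pack the value, reverse the bytes, and print them as hex behind a 0d/0f/0x prefix. Re-deriving two cases by hand (standalone snippet, not a call into the module):

import struct
# float32 1.0 is 0x3F800000 in IEEE 754, so render_val(1.0, dtypes.float32) gives "0f3F800000"
print("0f%02X%02X%02X%02X" % tuple(struct.pack("f", 1.0)[::-1]))  # 0f3F800000
# half 1.0 is 0x3C00, so render_val(1.0, dtypes.half) gives "0x3C00"
print("0x%02X%02X" % tuple(struct.pack("e", 1.0)[::-1]))          # 0x3C00
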
@@ -23,6 +26,7 @@ class AssemblyLanguage(NamedTuple):
const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
asm_for_op: Dict[Op, Callable[...,str]] = {}
types: Dict[DType, str] = INVERSE_DTYPES_DICT
supports_half: List[Op] = []
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
@@ -30,47 +34,77 @@ class AssemblyLanguage(NamedTuple):
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
def mem_type(self, dtype) -> str: raise NotImplementedError()
def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
kernel:List[str] = []
bufs = []
def eq_rep(root, x, y):
root.arg = BinaryOps.XOR
new = uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
return new
def lt_rep(x, y):
new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
u.vin = (new, u.vin[1])
u.arg = BinaryOps.MUL
def ld_rep(root, x, y):
root.dtype = dtypes.uint8
new = uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
ptr_ar(root)
return new
def gate_rep(root, x, y, z, k):
new = uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root))
root.vin = (x,y,z,new)
return ld_rep(root,x,y)
def ptr_ar(root):
root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
if root.vin[0].dtype.itemsize > 1:
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
else: ptr = root.vin[1]
if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
else:
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
# all uops must be a register
# TODO: uops class should make these rewrites easier
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
if u.uop is UOps.LOAD and u.dtype is dtypes.bool:
# rewrite load bool
if len(u.vin) == 4:
new = uops.add(UOps.CAST, dtypes.uint8, (u.vin[3],), insert_before=uops.uops.index(u))
u.vin = u.vin[0:3] + (new,)
u.dtype = dtypes.uint8
new = uops.add(UOps.CAST, dtypes.bool, (u,), insert_before=uops.uops.index(u)+1)
replace[u] = new
if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool:
if u.arg == BinaryOps.CMPEQ:
u.arg = BinaryOps.XOR
new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1)
replace[u] = new
if u.arg == BinaryOps.CMPLT:
new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
u.vin = (new, u.vin[1])
u.arg = BinaryOps.MUL
#uops.print()
if rew := matcher.rewrite(u): replace[u] = rew
def kk(*s: str): kernel.append("\n".join(s))
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
r: Dict[UOp, Union[List[str], str]] = {}
def ssa(u, prefix="t", dtype=None) -> str:
nonlocal c, r
prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
@@ -92,7 +126,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
return out
return lang.render_const(x, dtype)
def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype:
if u: r[u] = a
return a
@@ -113,45 +147,70 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
kk(f"{r_label[vin[0]]}:")
elif uop == UOps.STORE:
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
if len(vin) > 3:
assert vin[3].dtype is not None
pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
if vin[2].dtype.count > 1:
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
else:
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
elif uop == UOps.ALU:
assert vin[0].dtype is not None
operands = [r[x] for x in vin]
lab = ssa(u, "alu")
if needs_upcast := dtype == dtypes.half and args not in lang.supports_half:
dtype = dtypes.float32
out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
for i, op in enumerate(operands):
operands[i] = ssa(None, "alu_cast", lang.types[dtype])
kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
# pass in the other dtype here
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
else:
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
if needs_upcast:
kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
elif uop == UOps.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]:
kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
else:
kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
elif uop == UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
r[u] = "%" + args[1]
kernel = [f".reg .u32 %{args[1]};"] + kernel
elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop == UOps.CONST:
if dtype.count > 1:
r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
else:
r[u] = const(args, dtype, mov=True)
elif uop == UOps.GEP:
r[u] = r[vin[0]][u.arg]
elif uop == UOps.LOAD:
assert vin[1].dtype is not None
val = ssa(u, 'val')
if len(vin) > 3:
assert vin[2].dtype is not None
pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]],
dtypes.uint, vin[1].dtype), dtype),
*lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
if dtype.count > 1:
r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
elif uop == UOps.PHI:
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
if dtype.count>1:
r[u] = [r[x] for x in vin] # type: ignore
else:
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop == UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -172,7 +231,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
return lang.render_kernel(kernel, function_name, bufs, c.items())
class PTXLanguage(AssemblyLanguage):
kernel_prefix = """.version 7.8
kernel_prefix = """.version 7.5
.target TARGET
.address_size 64
.visible .entry"""
@@ -210,10 +269,8 @@ class PTXLanguage(AssemblyLanguage):
const_requires_mov = [dtypes.half, dtypes.bool]
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
val = render_val(x, dtype)
if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
def render_local(self, dest, name, size, dtype) -> List[str]:
@@ -223,31 +280,25 @@ class PTXLanguage(AssemblyLanguage):
def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]:
# this cast is only required because of ocelot
if "s32" in offset:
return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"]
else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"]
def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
assert dtype is not dtypes.bool
ret = []
if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];",
f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"])
else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];")
else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];")
return ret
def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"]
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"]
(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val}_cast;"]
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"]
if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"]
if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
'.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

View File

@@ -52,8 +52,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
class PTXCompiler(Compiler):
linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
supports_float4=False, shared_max=49152)
linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
def __init__(self, arch:str):
self.arch = arch
PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
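
One small behavioral note on the last line: has_tensor_cores is derived by slicing the SM architecture string, so only compute capability 8.0 and newer report tensor-core support. A quick check of that parse with example arch strings (the strings are illustrative, not queried from a device):

for arch in ["sm_75", "sm_80", "sm_89"]:
    # arch[3:] drops the "sm_" prefix; compute capability 80+ enables has_tensor_cores
    print(arch, int(arch[3:]) >= 80)
# sm_75 False, sm_80 True, sm_89 True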