mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
PTX - implement float 4, ptr arithmetics and other speed improvements (#3775)
* ptx float4 implementation
* remove from cache when trimming uops
* Gate for float4
* Linting fix
* disable test reasonable time for ptx
* import getenv
* Update uops.py
* linter
* Add div test for half
* upcast if op does not support operation
* fix offset
* Run only if dtype supported
* zero out registers when accessing by pred + cleanup
* Remove trailing whitespace
* revert
* spacing fix
* move cache clearing outside loop
* did this suddenly start working?
* unused import removed
* Remove cast
* Use pattern matching
* linting

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
@@ -170,6 +170,7 @@ class UOpGraph:
       if len(nu) == len(self.uops): break
     if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
     self.uops = nu
+    self.saved_exprs = {k:v for k,v in self.saved_exprs.items() if v in nu}

   # optional
   def type_verify(self):

@@ -1,14 +1,17 @@
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
|
||||
import functools, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
|
||||
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]
|
||||
def render_val(x, dtype):
|
||||
if dtypes.is_float(dtype):
|
||||
if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
|
||||
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
|
||||
|
||||
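PTX writes floating-point immediates as raw bit patterns: 0f plus eight hex digits for f32, 0d plus sixteen for f64, and a plain 0x pattern for f16, which is exactly what these helpers build with struct. A minimal stdlib-only check of the encoding (no tinygrad imports needed):

import struct

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

# 1.0f has bit pattern 0x3F800000, so PTX sees the immediate 0f3F800000
assert float_to_hex(1.0) == "3F800000"
# halves use struct's "e" format: 1.0 in fp16 is 0x3C00 -> immediate 0x3C00
assert "%02X%02X" % tuple(struct.pack("e", 1.0)[::-1]) == "3C00"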
@@ -23,6 +26,7 @@ class AssemblyLanguage(NamedTuple):
   const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
   asm_for_op: Dict[Op, Callable[...,str]] = {}
   types: Dict[DType, str] = INVERSE_DTYPES_DICT
+  supports_half: List[Op] = []

   def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
   def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
@@ -30,47 +34,77 @@ class AssemblyLanguage(NamedTuple):
   def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
   def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
-  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
-  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
-  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
+  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()

   def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
+  def mem_type(self, dtype) -> str: raise NotImplementedError()

 def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   kernel:List[str] = []
   bufs = []

+  def eq_rep(root, x, y):
+    root.arg = BinaryOps.XOR
+    new = uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
+    return new
+
+  def lt_rep(x, y):
+    new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
+    u.vin = (new, u.vin[1])
+    u.arg = BinaryOps.MUL
+
+  def ld_rep(root, x, y):
+    root.dtype = dtypes.uint8
+    new = uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
+    ptr_ar(root)
+    return new
+
+  def gate_rep(root, x, y, z, k):
+    new = uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root))
+    root.vin = (x,y,z,new)
+    return ld_rep(root,x,y)
+
+  def ptr_ar(root):
+    root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global'  # move this to the arg
+    if root.vin[0].dtype.itemsize > 1:
+      val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
+      ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
+    else: ptr = root.vin[1]
+    if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
+    else:
+      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
+      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
+      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
+      root.vin = (fptr, zero) + root.vin[2:]
+
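In effect, ptr_ar turns the typed (buffer, index) pair that LOAD/STORE used to carry into a flat 64-bit byte address plus an offset operand, with the memory space ('.shared' or '.global') stashed in arg. A toy sketch of the arithmetic it materializes as uops (the function name here is hypothetical, not tinygrad API):

def flat_address(base, idx, itemsize):
  # mirrors the inserted uop chain: CONST itemsize -> MUL -> CAST u64 -> ADD base
  return (base + idx * itemsize, 0)

# a float32 access at index 3 becomes [base+12] with a zero offset register
assert flat_address(0x7F0000000000, 3, 4) == (0x7F000000000C, 0)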
+  matcher = PatternMatcher([
+    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
+    ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
+    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
+      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
+    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
+    ({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
+  ])
+
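Each pattern is a dict of UOp fields to match; "__name__" entries bind the matched node (and sub-nodes) to keyword arguments of the rewrite callback, and an empty "vin": {} accepts any inputs. The real PatternMatcher lives in tinygrad's uops.py and is not part of this diff; the sketch below is a simplified guess at its semantics, just enough to see how ld_rep or ptr_ar get dispatched:

def matches(node, pat, binding):
  for k, v in pat.items():
    if k == "__name__": binding[v] = node            # capture for the callback
    elif k == "vin":
      if v == {}: continue                           # wildcard: any inputs
      if len(node.vin) != len(v): return False       # arity must line up
      if not all(matches(n, p, binding) for n, p in zip(node.vin, v)): return False
    elif getattr(node, k, None) != v: return False
  return True

def rewrite(node, patterns):
  for pat, fn in patterns:
    binding = {}
    if matches(node, pat, binding): return fn(**binding)
  return None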
   # here we do a pretransform on UOps to fix some shortcomings of PTX
   # all uops must be a register
   # TODO: uops class should make these rewrites easier
   replace: Dict[UOp, UOp] = {}
+  seen: Set[UOp] = set()
   for u in uops:
+    if u in seen: continue
+    seen.add(u)
     for o,n in replace.items():
       if o in u.vin and u is not n:
         u.vin = tuple(n if x == o else x for x in u.vin)
-    if u.uop is UOps.LOAD and u.dtype is dtypes.bool:
-      # rewrite load bool
-      if len(u.vin) == 4:
-        new = uops.add(UOps.CAST, dtypes.uint8, (u.vin[3],), insert_before=uops.uops.index(u))
-        u.vin = u.vin[0:3] + (new,)
-      u.dtype = dtypes.uint8
-      new = uops.add(UOps.CAST, dtypes.bool, (u,), insert_before=uops.uops.index(u)+1)
-      replace[u] = new
-    if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool:
-      if u.arg == BinaryOps.CMPEQ:
-        u.arg = BinaryOps.XOR
-        new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1)
-        replace[u] = new
-      if u.arg == BinaryOps.CMPLT:
-        new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
-        u.vin = (new, u.vin[1])
-        u.arg = BinaryOps.MUL
+    if rew := matcher.rewrite(u): replace[u] = rew
   #uops.print()

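The loop rewrites in place, so once a node has been replaced, every later consumer has to be repointed at the replacement; that is all the replace/seen bookkeeping does. The step in miniature (Node is a hypothetical stand-in for UOp):

class Node:
  def __init__(self, vin=()): self.vin = tuple(vin)

a, b = Node(), Node()
c = Node(vin=(a, b))
replace = {a: Node()}         # pretend a was rewritten
for o, n in replace.items():  # same propagation as above
  if o in c.vin and c is not n:
    c.vin = tuple(n if x == o else x for x in c.vin)
assert c.vin[0] is replace[a]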
   def kk(*s: str): kernel.append("\n".join(s))

   c: DefaultDict[str, int] = defaultdict(int)
-  r: Dict[UOp, str] = {}
+  r: Dict[UOp, Union[List[str], str]] = {}
   def ssa(u, prefix="t", dtype=None) -> str:
     nonlocal c, r
     prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
@@ -92,7 +126,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       return out
     return lang.render_const(x, dtype)

-  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+  def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
     if atype == dtype:
       if u: r[u] = a
       return a
@@ -113,45 +147,70 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       kk(f"{r_label[vin[0]]}:")
     elif uop == UOps.STORE:
       assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
-      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
-      if len(vin) > 3:
-        assert vin[3].dtype is not None
-        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
-      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+      if vin[2].dtype.count > 1:
+        kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
+           f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+      else:
+        kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
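For a vector dtype this emits a single wide store instead of one store per element, optionally predicated on the gate register. What the generated PTX looks like for a gated float4 store (register names and offsets here are illustrative):

gated_store   = "@%p0 st.global.v4.f32 [%rd1+16], {%f0, %f1, %f2, %f3};"
ungated_store = "st.shared.v2.f32 [%rd2+0], {%f4, %f5};"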
     else:
       assert dtype is not None, f"None dtype for uop {uop}"
       if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
       elif uop == UOps.ALU:
         assert vin[0].dtype is not None
+        operands = [r[x] for x in vin]
+        lab = ssa(u, "alu")
+        if needs_upcast := dtype == dtypes.half and args not in lang.supports_half:
+          dtype = dtypes.float32
+          out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
+          for i, op in enumerate(operands):
+            operands[i] = ssa(None, "alu_cast", lang.types[dtype])
+            kk(*lang.render_cast(operands[i], op, dtype, dtypes.half))  # type: ignore
         if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
           # pass in the other dtype here
-          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
+          kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
         else:
-          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
+          kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
+        if needs_upcast:
+          kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
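PTX has no f16 form for some ops (division among them, hence the new half-div test), so the operands are widened to f32, the op runs in f32, and the result is narrowed back. Roughly the emitted sequence for a half divide (register names illustrative):

half_div = """
cvt.f32.f16    %alu_cast_0, %h0;
cvt.f32.f16    %alu_cast_1, %h1;
div.rn.f32     %alu_cast_2, %alu_cast_0, %alu_cast_1;
cvt.rn.f16.f32 %h2, %alu_cast_2;
"""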
-      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
+      elif uop == UOps.DEFINE_ACC:
+        if dtype.count > 1:
+          r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
+          for uu in r[u]:
+            kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
+        else:
+          kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
       elif uop == UOps.SPECIAL:
         assert args[1][0] != "i", "idx not supported"
         kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
         r[u] = "%" + args[1]
         kernel = [f".reg .u32 %{args[1]};"] + kernel
-      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
+      elif uop == UOps.CONST:
+        if dtype.count > 1:
+          r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
+        else:
+          r[u] = const(args, dtype, mov=True)
+      elif uop == UOps.GEP:
+        r[u] = r[vin[0]][u.arg]
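A vector value never lives in one register here: r maps a float4 uop to a Python list of four scalar registers, which is why r's value type widened to Union[List[str], str]. GEP then becomes plain list indexing and emits no PTX at all. In miniature (register names hypothetical):

r = {}
r["acc4"] = ["%f0", "%f1", "%f2", "%f3"]   # DEFINE_ACC with dtype.count == 4
assert r["acc4"][2] == "%f2"               # GEP with arg=2 costs nothing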
       elif uop == UOps.LOAD:
         assert vin[1].dtype is not None
-        val = ssa(u, 'val')
-        if len(vin) > 3:
-          assert vin[2].dtype is not None
-          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
-          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
-        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]],
-                            off if len(vin)>3 else cast(r[vin[1]], dtypes.uint, vin[1].dtype), dtype),
-           *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
-                             alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+        if dtype.count > 1:
+          r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
+          if len(vin) > 3:
+            for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
+          kk((f"@{r[vin[2]]}" if len(vin) > 3 else "")
+             + f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+        else:
+          kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
+                               alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
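The gated vector load is what "zero out registers when accessing by pred" in the commit message refers to: a scalar gated load can fall back to a per-register alt mov, but a predicated v4 load leaves its destinations untouched when the predicate is false, so the registers are zeroed first to keep their values defined. Roughly the emitted PTX (registers illustrative):

gated_float4_load = """
mov.f32 %val_0, 0f00000000;
mov.f32 %val_1, 0f00000000;
mov.f32 %val_2, 0f00000000;
mov.f32 %val_3, 0f00000000;
@%p0 ld.global.v4.f32 {%val_0, %val_1, %val_2, %val_3}, [%rd1+0];
"""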
       elif uop == UOps.PHI:
         kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
         r[u] = r[vin[0]]
       elif uop in {UOps.CAST, UOps.BITCAST}:
         assert vin[0].dtype is not None
-        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+        if dtype.count>1:
+          r[u] = [r[x] for x in vin]  # type: ignore
+        else:
+          cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
       elif uop == UOps.DEFINE_LOCAL:
         # TODO: we should sum these, and fetch 0xC000 from somewhere
         assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -172,7 +231,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   return lang.render_kernel(kernel, function_name, bufs, c.items())

 class PTXLanguage(AssemblyLanguage):
-  kernel_prefix = """.version 7.8
+  kernel_prefix = """.version 7.5
 .target TARGET
 .address_size 64
 .visible .entry"""
@@ -210,10 +269,8 @@ class PTXLanguage(AssemblyLanguage):
   const_requires_mov = [dtypes.half, dtypes.bool]

   def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
-    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
-    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+    val = render_val(x, dtype)
     if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
     if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
     return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

   def render_local(self, dest, name, size, dtype) -> List[str]:
@@ -223,31 +280,25 @@ class PTXLanguage(AssemblyLanguage):

   def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]

-  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]:
-    # this cast is only required because of ocelot
-    if "s32" in offset:
-      return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"]
-    else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"]
-
   def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

-  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
     assert dtype is not dtypes.bool
     ret = []
-    if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
+    if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];",
                          f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"])
-    else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];")
+    else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];")
     return ret

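Dropping render_gep means an access no longer pays a mad.wide at every load/store: the pretransform builds the flat address as ordinary uops, and constant byte offsets ride in the ld/st addressing mode itself. Roughly, per access (registers illustrative):

before = ["mad.wide.u32 %rd2, %r1, 4, %rd1;", "ld.global.f32 %f0, [%rd2];"]
after  = ["ld.global.f32 %f0, [%rd1+4];"]  # constant index folded into the offset

When the index is a runtime value the multiply/add still happens, but as shared uops that the graph's expression caching can deduplicate across accesses, rather than once per load or store.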
-  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
+  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
     if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
-                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"]
-    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"]
+                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val}_cast;"]
+    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]

   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
     if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
-    if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"]
-    if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"]
+    if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
+    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
     rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
            '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
     return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

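Bool casts round-trip through selp/setp: a bool-to-value cast selects between two immediates on the predicate, and a value-to-bool cast compares against zero. Switching to the bit-size suffixes (selp.b16 rather than selp.f16) matters for dtypes like half that have no typed selp variant, and render_val now supplies the immediates. Illustrative output (register names invented):

bool_to_f32 = "selp.b32 %f0, 0f3F800000, 0f00000000, %p0;"  # 1.0f : 0.0f on %p0
int_to_bool = "setp.ne.b32 %p0, %r0, 0;"                    # true iff %r0 != 0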
@@ -52,8 +52,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)

 class PTXCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
-                                      supports_float4=False, shared_max=49152)
+  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
   def __init__(self, arch:str):
     self.arch = arch
     PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
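Dropping supports_float4=False is the switch that lets the linearizer emit float4 loads and stores for PTX at all; the rest of this diff exists to render them. The tensor-core gate parses the compute capability out of the arch string, e.g. (assuming CUDA-style arch names like "sm_86"):

arch = "sm_86"
assert int(arch[3:]) >= 80  # Ampere and newer get has_tensor_cores=True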