mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
use is to compare with enum (#3993)
* use is to compare with enum currently it's mixed between `==` and `is`, moved all to `is` * more
This commit is contained in:
@@ -66,10 +66,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
|
||||
def ptr_ar(root):
|
||||
assert root.arg in {'.shared', '.global', None}
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
|
||||
if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
|
||||
else:
|
||||
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
|
||||
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
|
||||
@@ -133,17 +133,17 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop == UOps.IF:
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop == UOps.ENDLOOP:
|
||||
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop == UOps.ENDIF:
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
elif uop == UOps.STORE:
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
if vin[2].dtype.count > 1:
|
||||
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
|
||||
@@ -152,8 +152,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
|
||||
elif uop == UOps.ALU:
|
||||
if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
operands = [r[x] for x in vin]
|
||||
lab = ssa(u, "alu")
|
||||
@@ -163,28 +163,28 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
for i, op in enumerate(operands):
|
||||
operands[i] = ssa(None, "alu_cast", lang.types[dtype])
|
||||
kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
|
||||
if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
# pass in the other dtype here
|
||||
kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
|
||||
if needs_upcast:
|
||||
kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
|
||||
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
|
||||
elif uop == UOps.SPECIAL:
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
|
||||
r[u] = "%" + args[1]
|
||||
kernel = [f".reg .u32 %{args[1]};"] + kernel
|
||||
elif uop == UOps.CONST:
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
elif uop == UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
@@ -195,14 +195,14 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
else:
|
||||
kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop == UOps.PHI:
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
|
||||
@@ -131,7 +131,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
val = lang.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'alu')} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
@@ -215,7 +215,7 @@ class MetalLanguage(CStyleLanguage):
|
||||
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
|
||||
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
|
||||
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
|
||||
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
|
||||
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
|
||||
@@ -257,7 +257,7 @@ class CUDALanguage(CStyleLanguage):
|
||||
prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
|
||||
|
||||
# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
|
||||
for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
|
||||
for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
|
||||
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
|
||||
prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
|
||||
@@ -338,7 +338,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
|
||||
prefix += [_make_hip_dtype(*x) for x in vec_dts]
|
||||
|
||||
for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
|
||||
for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
|
||||
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
|
||||
else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
|
||||
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
|
||||
@@ -347,7 +347,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
def get_kernel_modifier(self, uops:UOpGraph) -> str:
|
||||
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop == UOps.SPECIAL and u.arg[1][0] == "l")
|
||||
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
|
||||
|
||||
@@ -109,7 +109,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop == UOps.LOOP:
|
||||
if uop is UOps.LOOP:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
|
||||
@@ -138,7 +138,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
while backward.uop == UOps.PHI: backward = backward.vin[0]
|
||||
while backward.uop is UOps.PHI: backward = backward.vin[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
|
||||
Reference in New Issue
Block a user