[pr] Have PTX share code with LLVM (#7635)

* integrate into ops_cuda

* remove debugging stuff

* lint fix

* mypy fixes

* swap ptx.py

* edit

* simplify wmma

* wip

* space

* refactor

* sync the ops removal changes

* refactor

* rename variables

---------

Co-authored-by: judy <mesozoic.egg@proton.mail>
commit 1a5e896bd4 (parent f2f7384b67)
Author: mesozoic-egg
Date: 2024-11-17 10:53:56 +08:00
Committed by: GitHub

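For orientation before reading the diff: this change replaces most of the imperative per-UOp branching in PTXRenderer.render with a declarative PatternMatcher table (string_rewrite) whose callbacks return PTX strings, in line with the PatternMatcher-driven structure the LLVM and C-style renderers already use. A minimal, self-contained sketch of that idea follows; every name in it is made up for illustration and is not tinygrad's API.

# Toy table-driven renderer (hypothetical names, not tinygrad's API): each rule maps
# an op to a callback that formats one line of PTX-like text, mirroring how
# string_rewrite pairs a UPat with a string-producing lambda.
rules = {
    "ADD": lambda d, a, b: f"add.s32 {d}, {a}, {b};",
    "MUL": lambda d, a, b: f"mul.lo.s32 {d}, {a}, {b};",
}

def render_toy(program):
    # program: list of (op, dest, src_a, src_b) tuples
    out = []
    for op, d, a, b in program:
        if op not in rules: raise NotImplementedError(f"no rule for {op}")
        out.append(rules[op](d, a, b))
    return "\n".join(out)

print(render_toy([("ADD", "%r2", "%r0", "%r1"), ("MUL", "%r3", "%r2", "%r0")]))
# -> add.s32 %r2, %r0, %r1;
#    mul.lo.s32 %r3, %r2, %r0;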

@@ -2,10 +2,10 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tup
import struct
from collections import defaultdict
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
-from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
+from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, flatten
def render_val(x, dtype):
if dtypes.is_float(dtype):
@@ -54,6 +54,84 @@ ptx_matcher = PatternMatcher([
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
])
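# ptx_matcher rewrites the UOp graph ahead of code generation (e.g. casting shift
# amounts to uint as in the rule above); string_rewrite below then turns each
# rewritten UOp into PTX text.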
def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
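# render_store emits an (optionally predicated) st.shared/st.global; vector values
# (dtype.count > 1) are written with a single st.vN of their scalar registers.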
def render_wmma(ctx: "PTXRenderer", x: UOp):
assert ctx.wmma_r, "registry values for wmma must be populated"
_, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
dt_map = { dtypes.half: "f16" }
_i = 0
for vv in x.src[:2]:
for i in range(0, len(ctx.r[vv]), 2):
yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
_i += 1
yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
f'{{{", ".join(ctx.r[x.src[2]])}}};'
def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
(a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
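# modifier picks the cvt rounding mode: .rzi (round toward zero to integer) for
# float -> int casts, .rn (round to nearest even) when a float result narrows or
# its source is an int/bool, and no modifier otherwise.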
string_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
(UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
(UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
(UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
(UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
(UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
(UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
(UPat(Ops.CAST, name="x", dtype=dtypes.bool),
lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
(UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
(UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if alt.dtype.count > 1 else [
f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
(UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
lambda ctx, x, pred: flatten([
[f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
(UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
(UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
(UPat(Ops.DEFINE_LOCAL, name="x"),
lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
(UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
(UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
])
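# Each rule above pairs a UPat with a callback returning one PTX line (or a list of
# lines); the render loop below first allocates destination registers into ctx.r,
# then calls string_rewrite.rewrite(u, ctx=self) per UOp and appends the result.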
class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
@@ -80,34 +158,6 @@ class PTXRenderer(Renderer):
mem_types: Dict[DType, str] = types.copy()
mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
val = render_val(x, dtype)
if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
def render_local(self, dest, name, size, dtype) -> List[str]:
return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
def render_bra(self, b1, pred=None, invert=False) -> List[str]:
return [f"@{'!' if invert else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
assert dtype != dtypes.bool
if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
'.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
def render_kernel(self, kernel, function_name, bufs, regs) -> str:
kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
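# render_kernel prepends the collected .reg declarations, appends ret;, and fmt
# tab-indents every line that does not start with "$".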
@@ -120,116 +170,56 @@ class PTXRenderer(Renderer):
kernel:List[str] = []
bufs = []
def kk(*s: str): kernel.append("\n".join(s))
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, Union[List[str], str]] = {}
self.r = r
self.uops = uops
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
nonlocal c, r
prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
c[prefix] += 1
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
return f"%{prefix}{c[prefix]-1}"
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in self.const_requires_mov:
kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
return out
return self.render_const(x, dtype)
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype or isinstance(atype, PtrDType):
if u is not None: r[u] = a
return a
kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
return ret
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is Ops.IF:
pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
elif uop is Ops.ENDRANGE:
kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
elif uop is Ops.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
elif uop is Ops.STORE:
assert src[0].dtype == dtypes.int64, "store isn't int64"
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else ""
if src[1].dtype.count > 1:
kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
else:
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop in GroupOp.ALU:
src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype
kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
elif uop is Ops.SPECIAL:
assert args[0][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is Ops.DEFINE_VAR:
bufs.append((args[0], dtype))
r[u] = f"%{args[0]}"
kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True)
elif uop is Ops.GEP:
assert len(u.arg) == 1
r[u] = r[src[0]][u.arg[0]]
elif uop is Ops.LOAD:
assert src[0].dtype == dtypes.int64, "load isn't int64"
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op in GroupOp.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if has_gate:
for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[src[2]]}" if has_gate else "")
+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+0];")
else:
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
elif uop is Ops.ASSIGN:
if dtype.count > 1:
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
# NOTE: casting to str is fine because you can't vectorize a vectorize
elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {Ops.CAST, Ops.BITCAST}:
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u)
elif uop is Ops.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is Ops.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args}", dtype))
r[u] = f"%{nm}"
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
elif uop is Ops.WMMA:
_, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args
wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
dt_map = { dtypes.half: "f16" }
for vv in src[:2]:
for i in range(0, len(r[vv]), 2):
wmma.append(ssa("wmma", dtype="b32"))
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
kk(f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32\
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:n_operands[0]])}}}, {{{", ".join(wmma[-n_operands[1]:])}}}, {{{", ".join(r[src[2]])}}};')
else: raise NotImplementedError(f"no code for {uop}")
if u.op is Ops.VECTORIZE:
r[u] = [cast(str,r[x]) for x in u.src]
continue
if u.op is Ops.GEP:
assert len(u.arg) == 1
r[u] = r[u.src[0]][u.arg[0]]
continue
if u.op in {Ops.CAST, Ops.BITCAST}:
if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
r[u] = r[u.src[0]]
continue
r[u] = ssa('cast', u, self.types[u.dtype])
elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
elif u.op is Ops.DEFINE_ACC:
if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
elif u.op is Ops.DEFINE_VAR:
bufs.append((u.arg[0], u.dtype))
r[u] = ssa("dat", u, self.types[u.dtype])
elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
elif u.op is Ops.DEFINE_GLOBAL:
bufs.append((f"data{u.arg}", u.dtype))
r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
elif u.op is Ops.WMMA:
self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")
kernel.extend([l] if isinstance(l, str) else l)
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
return self.render_kernel(kernel, name, bufs, c.items())
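
For completeness, a hedged sketch of how this renderer is typically driven; the constructor argument and the origin of the uops list are assumptions about the surrounding tinygrad pipeline, not part of this diff.

# Hypothetical usage sketch (assumptions: PTXRenderer takes an arch string, and
# `uops` is the linearized List[UOp] produced by tinygrad's lowering passes):
renderer = PTXRenderer("sm_80")          # arch string is an assumption
ptx_src = renderer.render("E_4", uops)   # "E_4" is an example kernel name
print(ptx_src.splitlines()[0])           # first PTX line, e.g. a .version directive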