diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index 0c78a226f3..46e77e74e0 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -2,10 +2,10 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tup
 import struct
 from collections import defaultdict
 from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
-from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
+from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, flatten
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -54,6 +54,84 @@ ptx_matcher = PatternMatcher([
   (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
 ])
 
+def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
+
+def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
+  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
+  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
+    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
+
+def render_wmma(ctx: "PTXRenderer", x: UOp):
+  assert ctx.wmma_r, "registry values for wmma must be populated"
+  _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
+  n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
+  dt_map = { dtypes.half: "f16" }
+  _i = 0
+  for vv in x.src[:2]:
+    for i in range(0, len(ctx.r[vv]), 2):
+      yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
+      _i += 1
+  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
+    f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
+    f'{{{", ".join(ctx.r[x.src[2]])}}};'
+
+def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
+  (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
+
+string_rewrite = PatternMatcher([
+  (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+  (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
+  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
+   lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
+  (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
+  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+   lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
+   lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
+  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
+   lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
+    [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
+    [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
+  ]) if alt.dtype.count > 1 else [
+    f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
+    f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
+   lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
+     if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
+   lambda ctx, x, pred: flatten([
+     [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
+      f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
+   lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
+     f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
+    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+    f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
+   lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
+  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
+  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
+  (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
+    ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
+    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
+    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
+  (UPat(Ops.DEFINE_LOCAL, name="x"),
+   lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
+  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
+  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
+  (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
+])
+
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
@@ -80,34 +158,6 @@ class PTXRenderer(Renderer):
   mem_types: Dict[DType, str] = types.copy()
   mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
 
-  const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
-
-  def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
-    val = render_val(x, dtype)
-    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
-    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
-
-  def render_local(self, dest, name, size, dtype) -> List[str]:
-    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
-
-  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
-
-  def render_bra(self, b1, pred=None, invert=False) -> List[str]:
-    return [f"@{'!' if invert else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
-
-  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
-    assert dtype != dtypes.bool
-    if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
-    return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
-
-  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
-    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
-    if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
-    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
-    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
-           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
-    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
-
   def render_kernel(self, kernel, function_name, bufs, regs) -> str:
     kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
@@ -120,116 +170,56 @@ class PTXRenderer(Renderer):
     kernel:List[str] = []
     bufs = []
 
-    def kk(*s: str): kernel.append("\n".join(s))
-
     c: DefaultDict[str, int] = defaultdict(int)
     r: Dict[UOp, Union[List[str], str]] = {}
+    self.r = r
+    self.uops = uops
+
     def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
       nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
       c[prefix] += 1
-      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
       return f"%{prefix}{c[prefix]-1}"
 
-    def const(x:ConstType, dtype:DType, mov=False):
-      if mov or dtype in self.const_requires_mov:
-        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
-        return out
-      return self.render_const(x, dtype)
-
-    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
-      if atype == dtype or isinstance(atype, PtrDType):
-        if u is not None: r[u] = a
-        return a
-      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
-      return ret
-
     for u in uops:
-      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
-      if uop is Ops.IF:
-        pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)
-        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
-      elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
-      elif uop is Ops.ENDRANGE:
-        kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
-           self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
-        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
-      elif uop is Ops.ENDIF:
-        kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
-      elif uop is Ops.STORE:
-        assert src[0].dtype == dtypes.int64, "store isn't int64"
-        mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
-        gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else ""
-        if src[1].dtype.count > 1:
-          kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
-        else:
-          kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
-      else:
-        if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
-        elif uop in GroupOp.ALU:
-          src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype
-          kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
-        elif uop is Ops.DEFINE_ACC:
-          if dtype.count > 1:
-            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
-          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
-        elif uop is Ops.SPECIAL:
-          assert args[0][0] != "i", "idx not supported"
-          kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
-          r[u] = "%" + args[0]
-          kernel = [f".reg .u32 %{args[0]};"] + kernel
-        elif uop is Ops.DEFINE_VAR:
-          bufs.append((args[0], dtype))
-          r[u] = f"%{args[0]}"
-          kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
-        elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True)
-        elif uop is Ops.GEP:
-          assert len(u.arg) == 1
-          r[u] = r[src[0]][u.arg[0]]
-        elif uop is Ops.LOAD:
-          assert src[0].dtype == dtypes.int64, "load isn't int64"
-          mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
-          has_gate = len(src) > 2 and src[2].op in GroupOp.ALU
-          if dtype.count > 1:
-            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            if has_gate:
-              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-            kk((f"@{r[src[2]]}" if has_gate else "")
-              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+0];")
-          else:
-            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
-                                 alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
-        elif uop is Ops.ASSIGN:
-          if dtype.count > 1:
-            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
-          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
-          r[u] = r[src[0]]
-        # NOTE: casting to str is fine because you can't vectorize a vectorize
-        elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
-        elif uop in {Ops.CAST, Ops.BITCAST}:
-          _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u)
-        elif uop is Ops.DEFINE_LOCAL:
-          # TODO: we should sum these, and fetch 0xC000 from somewhere
-          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
-          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
-        elif uop is Ops.DEFINE_GLOBAL:
-          bufs.append((nm:=f"data{args}", dtype))
-          r[u] = f"%{nm}"
-          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
-          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
-        elif uop is Ops.WMMA:
-          _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args
-          wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
-          dt_map = { dtypes.half: "f16" }
-          for vv in src[:2]:
-            for i in range(0, len(r[vv]), 2):
-              wmma.append(ssa("wmma", dtype="b32"))
-              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
-          r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-          kk(f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32\
-            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:n_operands[0]])}}}, {{{", ".join(wmma[-n_operands[1]:])}}}, {{{", ".join(r[src[2]])}}};')
-        else: raise NotImplementedError(f"no code for {uop}")
+      if u.op is Ops.VECTORIZE:
+        r[u] = [cast(str,r[x]) for x in u.src]
+        continue
+      if u.op is Ops.GEP:
+        assert len(u.arg) == 1
+        r[u] = r[u.src[0]][u.arg[0]]
+        continue
+      if u.op in {Ops.CAST, Ops.BITCAST}:
+        if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
+          r[u] = r[u.src[0]]
+          continue
+        r[u] = ssa('cast', u, self.types[u.dtype])
+      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
+      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
+      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
+      elif u.op is Ops.DEFINE_ACC:
+        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
+          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
+        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
+      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+      elif u.op is Ops.DEFINE_VAR:
+        bufs.append((u.arg[0], u.dtype))
+        r[u] = ssa("dat", u, self.types[u.dtype])
+      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+      elif u.op is Ops.LOAD:
+        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
+        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
+      elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
+      elif u.op is Ops.DEFINE_GLOBAL:
+        bufs.append((f"data{u.arg}", u.dtype))
+        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+      elif u.op is Ops.WMMA:
+        self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
+        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
+      if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+      kernel.extend([l] if isinstance(l, str) else l)
+      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+      elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
 
     return self.render_kernel(kernel, name, bufs, c.items())
-
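
Note (not part of the patch): for readers unfamiliar with the table-driven style this diff moves PTXRenderer to, the sketch below is a minimal stand-alone illustration. ToyUOp, ToyMatcher, and toy_rewrite are hypothetical stand-ins, not tinygrad classes; the real code matches UPat patterns via string_rewrite.rewrite(u, ctx=self), but the control flow is the same: each uop is tried against an ordered rule table, the first hit returns the PTX string (or list of strings) appended to the kernel, and a None result means the uop could not be rendered.

# illustrative sketch only -- hypothetical names, assuming nothing beyond the diff above
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

@dataclass(frozen=True)
class ToyUOp:
  op: str     # op name, e.g. "CONST"
  dest: str   # destination register, e.g. "%f0"

Rule = Tuple[Callable[[ToyUOp], bool], Callable[[ToyUOp], Union[str, List[str]]]]

class ToyMatcher:
  def __init__(self, rules: List[Rule]): self.rules = rules
  def rewrite(self, u: ToyUOp) -> Optional[Union[str, List[str]]]:
    # first matching rule wins, like PatternMatcher trying its UPats in order
    for pred, fn in self.rules:
      if pred(u): return fn(u)
    return None

toy_rewrite = ToyMatcher([
  (lambda u: u.op == "CONST", lambda u: f"mov.b32 {u.dest}, 0f3F800000;"),
  (lambda u: u.op == "BARRIER", lambda u: "bar.sync 0;"),
  (lambda u: u.op == "RANGE", lambda u: [f"mov.u32 {u.dest}, 0;", f"LOOP_{u.dest[1:]}:"]),
])

kernel: List[str] = []
for u in [ToyUOp("CONST", "%f0"), ToyUOp("RANGE", "%ridx0"), ToyUOp("BARRIER", "")]:
  l = toy_rewrite.rewrite(u)
  if l is None: raise RuntimeError(f"failed to render {u.op}")  # same error path as the diff
  kernel.extend([l] if isinstance(l, str) else l)
print("\n".join(kernel))

Compared with the removed if/elif chain, the table keeps each op's PTX template in one place and lets new ops be added as data rather than control flow, which is the design the patch adopts via string_rewrite.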