Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 23:08:06 -05:00)
[pr] Have PTX share code with LLVM (#7635)
* integrate into ops_cuda
* remove debugging stuff
* lint fix
* mypy fixes
* swap ptx.py
* edit
* simplify wmma
* wip
* space
* refactor
* sync the ops removal changes
* refactor
* rename variables

---------

Co-authored-by: judy <mesozoic.egg@proton.mail>
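The gist of this change: instead of building PTX through a long if/elif chain inside PTXRenderer.render, each UOp is now rendered by looking it up in a PatternMatcher table (string_rewrite in the diff below), the same per-UOp dispatch style the LLVM renderer uses. A toy sketch of that dispatch style follows; it is a standalone illustration with made-up ToyOp/ToyMatcher classes, not tinygrad's actual UOp/PatternMatcher API:

from dataclasses import dataclass
from typing import Callable, Optional

# stand-ins for tinygrad's UOp/PatternMatcher, only to show first-match string dispatch
@dataclass(frozen=True)
class ToyOp:
  op: str                 # e.g. "CONST", "ADD"
  arg: object = None
  srcs: tuple = ()

class ToyMatcher:
  def __init__(self, rules: list[tuple[Callable[[ToyOp], bool], Callable[[dict, ToyOp], str]]]):
    self.rules = rules
  def rewrite(self, u: ToyOp, ctx: dict) -> Optional[str]:
    for pred, fn in self.rules:      # first matching rule renders the op
      if pred(u): return fn(ctx, u)
    return None

matcher = ToyMatcher([
  (lambda u: u.op == "CONST", lambda ctx, u: f"mov.b32 {ctx[u]}, {u.arg};"),
  (lambda u: u.op == "ADD",   lambda ctx, u: f"add.s32 {ctx[u]}, {ctx[u.srcs[0]]}, {ctx[u.srcs[1]]};"),
])

a, b = ToyOp("CONST", 1), ToyOp("CONST", 2)
add = ToyOp("ADD", srcs=(a, b))
regs = {a: "%r0", b: "%r1", add: "%r2"}  # register names; the real renderer assigns these via ssa()
print("\n".join(matcher.rewrite(u, regs) for u in (a, b, add)))

In the diff below, render() keeps only the register-allocation bookkeeping (ssa, r, wmma_r) and delegates the actual instruction text to string_rewrite.rewrite(u, ctx=self).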
@@ -2,10 +2,10 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tup
import struct
from collections import defaultdict
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.helpers import prod
from tinygrad.helpers import prod, flatten

def render_val(x, dtype):
  if dtypes.is_float(dtype):
@@ -54,6 +54,84 @@ ptx_matcher = PatternMatcher([
  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
])

def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'

def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]

def render_wmma(ctx: "PTXRenderer", x: UOp):
  assert ctx.wmma_r, "registry values for wmma must be populated"
  _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
  n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
  dt_map = { dtypes.half: "f16" }
  _i = 0
  for vv in x.src[:2]:
    for i in range(0, len(ctx.r[vv]), 2):
      yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
      _i += 1
  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
    f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
    f'{{{", ".join(ctx.r[x.src[2]])}}};'

def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
  (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''

string_rewrite = PatternMatcher([
  (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
  (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
   lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
  (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
   lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
   lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
   lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
    [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
    [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
  ]) if alt.dtype.count > 1 else [
    f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
    f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
   lambda ctx, x, loc: f"ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
    if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
   lambda ctx, x, pred: flatten([
    [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
     f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
   lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
    f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
   lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
  (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
    ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
  (UPat(Ops.DEFINE_LOCAL, name="x"),
   lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
  (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
  (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
])

class PTXRenderer(Renderer):
  device = "CUDA"
  suffix = "PTX"
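For reference on what these patterns emit: each string_rewrite entry formats one (or a few) PTX instructions from the registers recorded in ctx.r. Below is a small standalone sketch of the same f-string shape used by the CONST rules above, with a hand-rolled float formatter and made-up register/type names; the real render_val handles the other dtypes as well:

import struct

types = {"float32": "f32", "int32": "s32"}  # illustrative subset of the renderer's type table

def render_float(x: float) -> str:
  # PTX float immediates are written as 0f<hex of the IEEE-754 bits>
  return "0f%02X%02X%02X%02X" % tuple(struct.pack(">f", x))

def const_to_ptx(reg: str, x: float, dtype: str) -> str:
  # mirrors the Ops.CONST rule: mov.b<width> <reg>, <rendered value>;
  return f"mov.b{types[dtype][1:]} {reg}, {render_float(x)};"

print(const_to_ptx("%const_f32_0", 1.0, "float32"))  # mov.b32 %const_f32_0, 0f3F800000;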
@@ -80,34 +158,6 @@ class PTXRenderer(Renderer):
  mem_types: Dict[DType, str] = types.copy()
  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})

  const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]

  def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
    val = render_val(x, dtype)
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  def render_bra(self, b1, pred=None, invert=False) -> List[str]:
    return [f"@{'!' if invert else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
    assert dtype != dtypes.bool
    if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
    return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
@@ -120,116 +170,56 @@ class PTXRenderer(Renderer):
    kernel:List[str] = []
    bufs = []

    def kk(*s: str): kernel.append("\n".join(s))

    c: DefaultDict[str, int] = defaultdict(int)
    r: Dict[UOp, Union[List[str], str]] = {}
    self.r = r
    self.uops = uops

    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
      nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
      c[prefix] += 1
      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
      return f"%{prefix}{c[prefix]-1}"

    def const(x:ConstType, dtype:DType, mov=False):
      if mov or dtype in self.const_requires_mov:
        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
        return out
      return self.render_const(x, dtype)

    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
      if atype == dtype or isinstance(atype, PtrDType):
        if u is not None: r[u] = a
        return a
      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
      return ret

    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      if uop is Ops.IF:
        pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)
        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
      elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
      elif uop is Ops.ENDRANGE:
        kk(self.code_for_op[Ops.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
           self.code_for_op[Ops.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
      elif uop is Ops.ENDIF:
        kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
      elif uop is Ops.STORE:
        assert src[0].dtype == dtypes.int64, "store isn't int64"
        mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
        gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else ""
        if src[1].dtype.count > 1:
          kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
        else:
          kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
      else:
        if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
        elif uop in GroupOp.ALU:
          src_dtype = src[0].dtype if uop in {Ops.CMPLT, Ops.CMPNE} else dtype
          kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
        elif uop is Ops.DEFINE_ACC:
          if dtype.count > 1:
            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
        elif uop is Ops.SPECIAL:
          assert args[0][0] != "i", "idx not supported"
          kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
          r[u] = "%" + args[0]
          kernel = [f".reg .u32 %{args[0]};"] + kernel
        elif uop is Ops.DEFINE_VAR:
          bufs.append((args[0], dtype))
          r[u] = f"%{args[0]}"
          kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True)
        elif uop is Ops.GEP:
          assert len(u.arg) == 1
          r[u] = r[src[0]][u.arg[0]]
        elif uop is Ops.LOAD:
          assert src[0].dtype == dtypes.int64, "load isn't int64"
          mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
          has_gate = len(src) > 2 and src[2].op in GroupOp.ALU
          if dtype.count > 1:
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            if has_gate:
              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
            kk((f"@{r[src[2]]}" if has_gate else "")
               + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+0];")
          else:
            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
                                 alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
        elif uop is Ops.ASSIGN:
          if dtype.count > 1:
            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
          r[u] = r[src[0]]
        # NOTE: casting to str is fine because you can't vectorize a vectorize
        elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
        elif uop in {Ops.CAST, Ops.BITCAST}:
          _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u)
        elif uop is Ops.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
        elif uop is Ops.DEFINE_GLOBAL:
          bufs.append((nm:=f"data{args}", dtype))
          r[u] = f"%{nm}"
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
        elif uop is Ops.WMMA:
          _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args
          wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
          dt_map = { dtypes.half: "f16" }
          for vv in src[:2]:
            for i in range(0, len(r[vv]), 2):
              wmma.append(ssa("wmma", dtype="b32"))
              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
          r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          kk(f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32\
            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:n_operands[0]])}}}, {{{", ".join(wmma[-n_operands[1]:])}}}, {{{", ".join(r[src[2]])}}};')
        else: raise NotImplementedError(f"no code for {uop}")
      if u.op is Ops.VECTORIZE:
        r[u] = [cast(str,r[x]) for x in u.src]
        continue
      if u.op is Ops.GEP:
        assert len(u.arg) == 1
        r[u] = r[u.src[0]][u.arg[0]]
        continue
      if u.op in {Ops.CAST, Ops.BITCAST}:
        if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
          r[u] = r[u.src[0]]
          continue
        r[u] = ssa('cast', u, self.types[u.dtype])
      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
      elif u.op is Ops.DEFINE_ACC:
        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
      elif u.op is Ops.DEFINE_VAR:
        bufs.append((u.arg[0], u.dtype))
        r[u] = ssa("dat", u, self.types[u.dtype])
      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
      elif u.op is Ops.LOAD:
        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
      elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
      elif u.op is Ops.DEFINE_GLOBAL:
        bufs.append((f"data{u.arg}", u.dtype))
        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
      elif u.op is Ops.WMMA:
        self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
      if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      kernel.extend([l] if isinstance(l, str) else l)

      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
      elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
    return self.render_kernel(kernel, name, bufs, c.items())
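Usage note (not part of this diff, and assuming the usual tinygrad environment knobs): on the CUDA backend the PTX renderer is selected with the PTX=1 environment variable, and a high DEBUG level prints the generated kernel source, so the output of this renderer can be inspected on an NVIDIA machine roughly like this:

import os
# set before importing tinygrad so the env vars are picked up; names assume tinygrad's usual knobs
os.environ["CUDA"] = "1"   # force the CUDA backend
os.environ["PTX"] = "1"    # render kernels with PTXRenderer instead of C-style CUDA
os.environ["DEBUG"] = "4"  # dump generated kernel source as kernels are compiled

from tinygrad import Tensor

out = (Tensor.ones(16) + 2).sum()
print(out.item())  # realizing the graph triggers rendering; the PTX text is printed by DEBUG=4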