mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
src->vin [run_process_replay] (#5036)
This commit is contained in:
@@ -20,8 +20,8 @@ def _uops_to_prg(uops_list, print=False):
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg))
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(src), arg))
|
||||
return uops[-1]
|
||||
|
||||
def _test_single_value(vals, op, dts):
|
||||
|
||||
@@ -517,7 +517,7 @@ class Linearizer(Kernel):
|
||||
for off in range(len(acc)):
|
||||
if input_acc[off] != acc[off]:
|
||||
acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
|
||||
else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
|
||||
else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
|
||||
cache[x] = ret
|
||||
return ret
|
||||
|
||||
|
||||
@@ -65,11 +65,11 @@ class UOp:
|
||||
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
|
||||
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
@staticmethod
|
||||
def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
|
||||
def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
|
||||
@staticmethod
|
||||
def load(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(vin)+tuple(kwargs.values()))
|
||||
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def store(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(vin)+tuple(kwargs.values()))
|
||||
def store(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(src)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
|
||||
@staticmethod
|
||||
@@ -103,7 +103,7 @@ class UPat:
|
||||
@staticmethod
|
||||
def compile(u: UOp, name:Optional[str]=None) -> UPat:
|
||||
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
|
||||
|
||||
T = TypeVar("T")
|
||||
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
|
||||
@@ -458,23 +458,23 @@ class UOpGraph:
|
||||
|
||||
def type_verify(self):
|
||||
for u in self.uops:
|
||||
uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
||||
if uop is UOps.DEFINE_ACC: arg = arg[0]
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(vin[0].dtype) and dtypes.is_int(vin[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
@@ -131,37 +131,37 @@ class PTXRenderer(Renderer):
|
||||
return ret
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
|
||||
assert src[0].dtype is not None
|
||||
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
|
||||
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
|
||||
elif uop is UOps.ENDRANGE:
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
|
||||
kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
assert vin[0].dtype == dtypes.int64, "store isn't int64"
|
||||
assert vin[1].op is UOps.CONST, f"store isn't const {u}"
|
||||
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
|
||||
if vin[2].dtype.count > 1:
|
||||
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
|
||||
f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
|
||||
assert src[0].dtype is not None and src[2].dtype is not None
|
||||
assert src[0].dtype == dtypes.int64, "store isn't int64"
|
||||
assert src[1].op is UOps.CONST, f"store isn't const {u}"
|
||||
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
|
||||
if src[2].dtype.count > 1:
|
||||
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
|
||||
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
|
||||
else:
|
||||
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
|
||||
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[vin[0]], "LOOP_"+loop[1:]))
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
assert src[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
|
||||
# pass in the other dtype here
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
|
||||
else:
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
|
||||
kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
@@ -175,30 +175,30 @@ class PTXRenderer(Renderer):
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
assert vin[0].dtype == dtypes.int64, "load isn't int64"
|
||||
assert vin[1].op is UOps.CONST, f"load isn't const {u}"
|
||||
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
|
||||
assert src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
assert src[1].op is UOps.CONST, f"load isn't const {u}"
|
||||
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
if(len(src)>3):
|
||||
for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
kk((f"@{r[src[2]]}"if len(src) > 3 else "")
|
||||
+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
|
||||
else:
|
||||
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
|
||||
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
|
||||
alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
if dtype.count > 1:
|
||||
for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
|
||||
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
|
||||
else:
|
||||
kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
|
||||
r[u] = r[src[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
assert src[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
|
||||
else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
@@ -214,13 +214,13 @@ class PTXRenderer(Renderer):
|
||||
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
for vv in src[:2]:
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
@@ -107,28 +107,28 @@ class CStyleLanguage(Renderer):
|
||||
child_count = Counter(v for ru in uops for v in ru.src)
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
# these four uops don't have output dtypes
|
||||
if uop is UOps.IF:
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
kk(f"if ({r[src[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.BARRIER: kk(self.barrier)
|
||||
elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
assert src[0].dtype is not None and src[2].dtype is not None
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
|
||||
else: operands = [r[v] for v in vin]
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
|
||||
else: operands = [r[v] for v in src]
|
||||
val = self.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
@@ -138,21 +138,21 @@ class CStyleLanguage(Renderer):
|
||||
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
kk(f"{r[src[0]]} = {r[src[1]]};")
|
||||
r[u] = r[src[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
assert len(src) == 1
|
||||
precast = ssa('precast')
|
||||
kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
|
||||
val = self.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
@@ -164,13 +164,13 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
assert src[0].dtype is not None
|
||||
from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return self.render_kernel(name, kernel, bufs, uops)
|
||||
|
||||
@@ -101,21 +101,21 @@ class LLVMRenderer(Renderer):
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(lvars[vin[3]]):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
|
||||
if len(src) > 3:
|
||||
with bb[-1].if_then(lvars[src[3]]):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
else:
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
elif uop is UOps.ENDRANGE:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE:
|
||||
@@ -130,28 +130,28 @@ class LLVMRenderer(Renderer):
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
|
||||
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args[0], dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(vin) > 2:
|
||||
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
|
||||
if len(src) > 2:
|
||||
aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
lvars[u] = val
|
||||
elif uop is UOps.PHI:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
lvars[u] = lvars[src[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
backward = src[0]
|
||||
while backward.op is UOps.PHI: backward = backward.src[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
|
||||
Reference in New Issue
Block a user