src->vin [run_process_replay] (#5036)

kormann authored 2024-06-18 21:23:49 +02:00; committed by GitHub
parent f171006ded
commit fe332464d2
6 changed files with 89 additions and 89 deletions
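
Note: every hunk below is the same mechanical rename: the local alias/parameter `vin` ("value inputs") becomes `src`, matching the `UOp.src` field it unpacks. A minimal, hedged sketch of the shape involved (an illustrative stand-in, not the full tinygrad class):

  from typing import Any, Optional, Tuple

  class UOp:
    # fields as used throughout this diff: op kind, output dtype, input uops, attribute
    def __init__(self, op, dtype: Optional[Any]=None, src: Tuple["UOp", ...]=(), arg: Any=None):
      self.op, self.dtype, self.src, self.arg = op, dtype, src, arg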

View File

@@ -20,8 +20,8 @@ def _uops_to_prg(uops_list, print=False):
   has_local = Device[Device.DEFAULT].renderer.has_local
   return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
-def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
-  uops.append(UOp(uop, dtype, tuple(vin), arg))
+def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
+  uops.append(UOp(uop, dtype, tuple(src), arg))
   return uops[-1]
 def _test_single_value(vals, op, dts):
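
Note: a hedged usage sketch of the renamed test helper; the ops, dtypes, and values are assumed from the surrounding test module and are made up here:

  uops: List[UOp] = []
  a = uop(uops, UOps.CONST, dtypes.float32, (), 1.0)
  b = uop(uops, UOps.CONST, dtypes.float32, (), 2.0)
  out = uop(uops, UOps.ALU, dtypes.float32, (a, b), BinaryOps.ADD)  # src=(a, b)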

View File

@@ -517,7 +517,7 @@ class Linearizer(Kernel):
         for off in range(len(acc)):
           if input_acc[off] != acc[off]:
             acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
-      else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
+      else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
       cache[x] = ret
       return ret
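
Note: in the hunk above, `values` holds one list of rendered input UOps per operand of `x`; `zip(*values)` regroups them lane-by-lane so each `src` tuple feeds a single ALU uop. A hedged illustration with made-up names:

  values = [[a0, a1], [b0, b1]]   # two operands, unrolled to two lanes
  ret = [UOp.alu(x.op, *src) for src in zip(*values)]
  # equivalent to [UOp.alu(x.op, a0, b0), UOp.alu(x.op, a1, b1)]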

View File

@@ -65,11 +65,11 @@ class UOp:
     if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   @staticmethod
-  def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
+  def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
   @staticmethod
-  def load(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(vin)+tuple(kwargs.values()))
+  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
-  def store(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(vin)+tuple(kwargs.values()))
+  def store(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
   def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
   @staticmethod
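
Note: a hedged sketch of calling the renamed vararg constructors above; `a`, `b`, `buf`, `idx`, and `val` are placeholder UOps, not from this diff:

  cmp = UOp.alu(BinaryOps.CMPLT, a, b)            # output dtype forced to dtypes.bool
  ld  = UOp.load(buf, idx, dtype=dtypes.float32)  # *src collects to (buf, idx)
  st  = UOp.store(buf, idx, val)                  # *src collects to (buf, idx, val)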
@@ -103,7 +103,7 @@ class UPat:
   @staticmethod
   def compile(u: UOp, name:Optional[str]=None) -> UPat:
     if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
-    return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype)
+    return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
 T = TypeVar("T")
 def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
@@ -458,23 +458,23 @@ class UOpGraph:
   def type_verify(self):
     for u in self.uops:
-      uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
+      uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
       if uop in {UOps.CONST, UOps.DEFINE_ACC}:
         if uop is UOps.DEFINE_ACC: arg = arg[0]
         assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
       if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
       if uop is UOps.ALU:
         if arg in UnaryOps:
-          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
+          assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
         elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
           assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
-          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
         elif arg is BinaryOps.IDIV:
-          assert dtypes.is_int(vin[0].dtype) and dtypes.is_int(vin[1].dtype), \
-            f"input dtype mismatch {dtypes.int} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
+            f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
           assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
         elif arg in BinaryOps:
-          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
         elif arg == TernaryOps.WHERE:
-          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
-          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
+          assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
+          assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
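
Note: the WHERE rule verified above, as a hedged example with placeholder UOps (`a`, `b`, `x`, `y`, with `x` and `y` sharing a dtype): the selector src[0] must be bool, and the output dtype must match both choices:

  sel = UOp.alu(BinaryOps.CMPLT, a, b)        # dtypes.bool, by the alu() rule earlier in this file
  out = UOp.alu(TernaryOps.WHERE, sel, x, y)  # dtype taken from src[-1]
  assert out.src[0].dtype == dtypes.bool and out.dtype == out.src[1].dtype == out.src[2].dtype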

View File

@@ -131,37 +131,37 @@ class PTXRenderer(Renderer):
       return ret
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.IF:
-        assert vin[0].dtype is not None
-        kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
+        assert src[0].dtype is not None
+        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDRANGE:
-        kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
-          self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
-        kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
+        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+          self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
       elif uop is UOps.ENDIF:
-        kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
+        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
       elif uop is UOps.STORE:
-        assert vin[0].dtype is not None and vin[2].dtype is not None
-        assert vin[0].dtype == dtypes.int64, "store isn't int64"
-        assert vin[1].op is UOps.CONST, f"store isn't const {u}"
-        mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
-        if vin[2].dtype.count > 1:
-          kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
-            f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+        assert src[0].dtype is not None and src[2].dtype is not None
+        assert src[0].dtype == dtypes.int64, "store isn't int64"
+        assert src[1].op is UOps.CONST, f"store isn't const {u}"
+        mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+        if src[2].dtype.count > 1:
+          kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+            f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
         else:
-          kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
+          kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
-        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[vin[0]], "LOOP_"+loop[1:]))
+        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
         elif uop is UOps.ALU:
-          assert vin[0].dtype is not None
+          assert src[0].dtype is not None
           if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
             # pass in the other dtype here
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
           else:
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
         elif uop is UOps.DEFINE_ACC:
           if dtype.count > 1:
             r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
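
Note: the STORE branch above relies on a fixed src layout, summarized here as a hedged sketch derived from the asserts in the hunk itself:

  # src[0] -> address (asserted int64), src[1] -> constant byte offset (asserted UOps.CONST)
  # src[2] -> value to store, src[3] -> optional predicate, present when len(src) > 3
  is_gated = len(u.src) > 3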
@@ -175,30 +175,30 @@ class PTXRenderer(Renderer):
         elif uop is UOps.CONST:
           if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
           else: r[u] = const(args, dtype, mov=True)
-        elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
+        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
         elif uop is UOps.LOAD:
-          assert vin[0].dtype == dtypes.int64, "load isn't int64"
-          assert vin[1].op is UOps.CONST, f"load isn't const {u}"
-          mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+          assert src[0].dtype == dtypes.int64, "load isn't int64"
+          assert src[1].op is UOps.CONST, f"load isn't const {u}"
+          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
           if dtype.count > 1:
             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            if(len(vin)>3):
+            if(len(src)>3):
               for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-            kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
-              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+            kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
           else:
-            kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
-              alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
+            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+              alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
         elif uop is UOps.PHI:
           if dtype.count > 1:
-            for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
           else:
-            kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
-          r[u] = r[vin[0]]
+            kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+          r[u] = r[src[0]]
         elif uop in {UOps.CAST, UOps.BITCAST}:
-          assert vin[0].dtype is not None
-          if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
-          else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+          assert src[0].dtype is not None
+          if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+          else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
         elif uop is UOps.DEFINE_LOCAL:
           # TODO: we should sum these, and fetch 0xC000 from somewhere
           assert args[1]*dtype.itemsize <= 0xC000, "too large local"
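
Note: the gated LOAD above uses a similar layout, read off the render_load call in the hunk:

  gate = r[src[2]] if len(src) > 3 else None  # predicate register, if any
  alt  = r[src[3]] if len(src) > 3 else None  # fallback value when the gate is false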
@@ -214,13 +214,13 @@ class PTXRenderer(Renderer):
           kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
         elif uop is UOps.WMMA:
           wmma = []
-          for vv in vin[:2]:
+          for vv in src[:2]:
             for i in range(0, len(r[vv]), 2):
               wmma.append(ssa("wmma", dtype="b32"))
               kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
           r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
           kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
-            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
+            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
         else: raise NotImplementedError(f"no code for {uop}")
     return self.render_kernel(kernel, name, bufs, c.items())

View File

@@ -107,28 +107,28 @@ class CStyleLanguage(Renderer):
     child_count = Counter(v for ru in uops for v in ru.src)
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       # these four uops don't have output dtypes
       if uop is UOps.IF:
-        kk(f"if ({r[vin[0]]}) {{")
+        kk(f"if ({r[src[0]]}) {{")
        depth += 1
       elif uop is UOps.BARRIER: kk(self.barrier)
       elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
         depth -= 1
         kk("}")
       elif uop is UOps.STORE:
-        assert vin[0].dtype is not None and vin[2].dtype is not None
-        rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
-        kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
+        assert src[0].dtype is not None and src[2].dtype is not None
+        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
         if uop is UOps.RANGE:
-          kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
+          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
           depth += 1
         elif uop is UOps.ALU:
           # remove parens if ALU types are the same. TODO: can do more here
-          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
-          else: operands = [r[v] for v in vin]
+          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+          else: operands = [r[v] for v in src]
           val = self.code_for_op[args](*operands, dtype)
           assert child_count[u] != 0, f"childless ALU op found {u}"
           # TODO: fix index rendering issue. fix clang nested max macro issue
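
Note: a hedged illustration of the paren-stripping rule above: when an associative op feeds the same op, the child's parenthesized rendering can be inlined safely:

  # v.arg == args == BinaryOps.ADD and r[v] == "(a+b)"
  # -> operand rendered as "a+b", so the outer ALU emits "a+b+c" instead of "(a+b)+c"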
@@ -138,21 +138,21 @@ class CStyleLanguage(Renderer):
           kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
           r[u] = args[1]
         elif uop is UOps.LOAD:
-          val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
+          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
           # NOTE: this relies on the load not happening if it's in the unselected branch
-          if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
+          if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
           kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
         elif uop is UOps.PHI:
-          kk(f"{r[vin[0]]} = {r[vin[1]]};")
-          r[u] = r[vin[0]]
+          kk(f"{r[src[0]]} = {r[src[1]]};")
+          r[u] = r[src[0]]
         elif uop in {UOps.CAST, UOps.BITCAST}:
           if uop is UOps.BITCAST:
-            assert len(vin) == 1
+            assert len(src) == 1
             precast = ssa('precast')
-            kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
+            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
             val = self.render_cast([precast], dtype, bitcast=True)
           else:
-            val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
+            val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
           if child_count[u] <= 1: r[u] = val
           else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
         elif uop is UOps.DEFINE_LOCAL:
@@ -164,13 +164,13 @@ class CStyleLanguage(Renderer):
         elif uop is UOps.DEFINE_GLOBAL:
           bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
           r[u] = nm
-        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
+        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
         elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
         elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
         elif uop is UOps.GEP:
-          assert vin[0].dtype is not None
-          from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-          r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
+          assert src[0].dtype is not None
+          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
         else: raise RuntimeError(f"failed to render {uop}")
     return self.render_kernel(name, kernel, bufs, uops)
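
Note: the GEP rendering above picks a swizzle for short vectors and an index for wide ones; hedged examples with a made-up register name:

  # src[0].dtype.count == 4, args == 2  ->  "val0.z"
  # src[0].dtype.count == 8, args == 5  ->  "val0[5]"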

View File

@@ -101,21 +101,21 @@ class LLVMRenderer(Renderer):
     if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.STORE:
-        element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
-        if len(vin) > 3:
-          with bb[-1].if_then(lvars[vin[3]]):
-            bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+        element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
+        if len(src) > 3:
+          with bb[-1].if_then(lvars[src[3]]):
+            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
         else:
-          bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+          bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
       elif uop is UOps.ENDRANGE:
         loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
+        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
+        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
         for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
         bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
+        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
         if uop is UOps.RANGE:
@@ -130,28 +130,28 @@ class LLVMRenderer(Renderer):
             phis.append((rp, lvars[rp]))
           lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-          lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
+          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
           loop_blocks.append((bb[-1].block, phis))
         elif uop is UOps.DEFINE_ACC:
           lvars[u] = const(args[0], dtype)
           reduce_phis.append(u)
         elif uop is UOps.LOAD:
-          if len(vin) > 2:
-            aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
-            val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
+          if len(src) > 2:
+            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
+            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
           else:
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
           lvars[u] = val
         elif uop is UOps.PHI:
-          lvars[u] = lvars[vin[1]]
+          lvars[u] = lvars[src[1]]
           # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-          backward = vin[0]
+          backward = src[0]
           while backward.op is UOps.PHI: backward = backward.src[0]
           lvars[backward] = lvars[u]
         elif uop is UOps.ALU:
-          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
+          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
+        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
         elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
         elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
         elif uop is UOps.CONST: lvars[u] = const(args, dtype)
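
Note: the PHI backtrace in the hunk above walks src[0] through chained PHIs until it reaches the originating DEFINE_ACC, then aliases that accumulator's LLVM value to the new one (quoted directly from the diff):

  backward = u.src[0]
  while backward.op is UOps.PHI: backward = backward.src[0]
  lvars[backward] = lvars[u]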