From fe332464d260ec5cded61691443e43df12716717 Mon Sep 17 00:00:00 2001
From: kormann <49917710+DKormann@users.noreply.github.com>
Date: Tue, 18 Jun 2024 21:23:49 +0200
Subject: [PATCH] src->vin [run_process_replay] (#5036)

---
 test/test_uops.py              |  4 +-
 tinygrad/codegen/linearizer.py |  2 +-
 tinygrad/codegen/uops.py       | 24 ++++++------
 tinygrad/renderer/assembly.py  | 72 +++++++++++++++++-----------------
 tinygrad/renderer/cstyle.py    | 38 +++++++++---------
 tinygrad/renderer/llvmir.py    | 38 +++++++++---------
 6 files changed, 89 insertions(+), 89 deletions(-)

diff --git a/test/test_uops.py b/test/test_uops.py
index 9a5b8e5d18..3ab6e274ac 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -20,8 +20,8 @@ def _uops_to_prg(uops_list, print=False):
   has_local = Device[Device.DEFAULT].renderer.has_local
   return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
 
-def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
-  uops.append(UOp(uop, dtype, tuple(vin), arg))
+def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
+  uops.append(UOp(uop, dtype, tuple(src), arg))
   return uops[-1]
 
 def _test_single_value(vals, op, dts):
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 5e1051dd61..be0f9ad1b8 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -517,7 +517,7 @@ class Linearizer(Kernel):
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
           acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
-    else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
+    else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
     cache[x] = ret
     return ret
 
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index fddfe2a31d..66e081718b 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -65,11 +65,11 @@ class UOp:
     if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   @staticmethod
-  def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
+  def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
   @staticmethod
-  def load(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(vin)+tuple(kwargs.values()))
+  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
-  def store(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(vin)+tuple(kwargs.values()))
+  def store(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
   def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
   @staticmethod
@@ -103,7 +103,7 @@ class UPat:
   @staticmethod
   def compile(u: UOp, name:Optional[str]=None) -> UPat:
     if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
-    return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype)
+    return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
 
 T = TypeVar("T")
 def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
@@ -458,23 +458,23 @@ class UOpGraph:
 
   def type_verify(self):
     for u in self.uops:
-      uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
+      uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
       if uop in {UOps.CONST, UOps.DEFINE_ACC}:
         if uop is UOps.DEFINE_ACC: arg = arg[0]
         assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
       if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
       if uop is UOps.ALU:
         if arg in UnaryOps:
-          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
+          assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
         elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
           assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
-          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
         elif arg is BinaryOps.IDIV:
-          assert dtypes.is_int(vin[0].dtype) and dtypes.is_int(vin[1].dtype), \
-              f"input dtype mismatch {dtypes.int} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
+              f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
           assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
         elif arg in BinaryOps:
-          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+          assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
         elif arg == TernaryOps.WHERE:
-          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
-          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
+          assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
+          assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index c73f5bc720..437fbef509 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -131,37 +131,37 @@ class PTXRenderer(Renderer):
       return ret
 
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.IF:
-        assert vin[0].dtype is not None
-        kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
+        assert src[0].dtype is not None
+        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDRANGE:
-        kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
-           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
-        kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
+        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
       elif uop is UOps.ENDIF:
-        kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
+        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
       elif uop is UOps.STORE:
-        assert vin[0].dtype is not None and vin[2].dtype is not None
-        assert vin[0].dtype == dtypes.int64, "store isn't int64"
-        assert vin[1].op is UOps.CONST, f"store isn't const {u}"
-        mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
-        if vin[2].dtype.count > 1:
-          kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
-              f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+        assert src[0].dtype is not None and src[2].dtype is not None
+        assert src[0].dtype == dtypes.int64, "store isn't int64"
+        assert src[1].op is UOps.CONST, f"store isn't const {u}"
+        mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+        if src[2].dtype.count > 1:
+          kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+              f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
         else:
-          kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
+          kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
-        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[vin[0]], "LOOP_"+loop[1:]))
+        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
         elif uop is UOps.ALU:
-          assert vin[0].dtype is not None
+          assert src[0].dtype is not None
           if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
             # pass in the other dtype here
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
           else:
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
         elif uop is UOps.DEFINE_ACC:
           if dtype.count > 1:
             r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
@@ -175,30 +175,30 @@ class PTXRenderer(Renderer):
         elif uop is UOps.CONST:
           if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
           else: r[u] = const(args, dtype, mov=True)
-        elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
+        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
         elif uop is UOps.LOAD:
-          assert vin[0].dtype == dtypes.int64, "load isn't int64"
-          assert vin[1].op is UOps.CONST, f"load isn't const {u}"
-          mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+          assert src[0].dtype == dtypes.int64, "load isn't int64"
+          assert src[1].op is UOps.CONST, f"load isn't const {u}"
+          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
           if dtype.count > 1:
             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            if(len(vin)>3):
+            if(len(src)>3):
              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-            kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
-              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+            kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
           else:
-            kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
-                                 alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
+            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+                                 alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
         elif uop is UOps.PHI:
           if dtype.count > 1:
-            for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
           else:
-            kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
-          r[u] = r[vin[0]]
+            kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+          r[u] = r[src[0]]
         elif uop in {UOps.CAST, UOps.BITCAST}:
-          assert vin[0].dtype is not None
-          if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
-          else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+          assert src[0].dtype is not None
+          if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+          else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
         elif uop is UOps.DEFINE_LOCAL:
           # TODO: we should sum these, and fetch 0xC000 from somewhere
           assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -214,13 +214,13 @@ class PTXRenderer(Renderer):
           kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
         elif uop is UOps.WMMA:
           wmma = []
-          for vv in vin[:2]:
+          for vv in src[:2]:
             for i in range(0, len(r[vv]), 2):
               wmma.append(ssa("wmma", dtype="b32"))
               kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
           r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
           kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
-            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
+            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
         else: raise NotImplementedError(f"no code for {uop}")
     return self.render_kernel(kernel, name, bufs, c.items())
 
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index a9733f8731..738e2a843c 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -107,28 +107,28 @@ class CStyleLanguage(Renderer):
     child_count = Counter(v for ru in uops for v in ru.src)
 
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       # these four uops don't have output dtypes
      if uop is UOps.IF:
-        kk(f"if ({r[vin[0]]}) {{")
+        kk(f"if ({r[src[0]]}) {{")
        depth += 1
      elif uop is UOps.BARRIER: kk(self.barrier)
      elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
        depth -= 1
        kk("}")
      elif uop is UOps.STORE:
-        assert vin[0].dtype is not None and vin[2].dtype is not None
-        rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
-        kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
+        assert src[0].dtype is not None and src[2].dtype is not None
+        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
-          kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
+          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
          depth += 1
        elif uop is UOps.ALU:
          # remove parens if ALU types are the same. TODO: can do more here
-          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
-          else: operands = [r[v] for v in vin]
+          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+          else: operands = [r[v] for v in src]
          val = self.code_for_op[args](*operands, dtype)
          assert child_count[u] != 0, f"childless ALU op found {u}"
          # TODO: fix index rendering issue. fix clang nested max macro issue
@@ -138,21 +138,21 @@ class CStyleLanguage(Renderer):
          kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
          r[u] = args[1]
        elif uop is UOps.LOAD:
-          val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
+          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
          # NOTE: this relies on the load not happening if it's in the unselected branch
-          if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
+          if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
        elif uop is UOps.PHI:
-          kk(f"{r[vin[0]]} = {r[vin[1]]};")
-          r[u] = r[vin[0]]
+          kk(f"{r[src[0]]} = {r[src[1]]};")
+          r[u] = r[src[0]]
        elif uop in {UOps.CAST, UOps.BITCAST}:
          if uop is UOps.BITCAST:
-            assert len(vin) == 1
+            assert len(src) == 1
            precast = ssa('precast')
-            kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
+            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
            val = self.render_cast([precast], dtype, bitcast=True)
          else:
-            val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
+            val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
          if child_count[u] <= 1: r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
        elif uop is UOps.DEFINE_LOCAL:
@@ -164,13 +164,13 @@ class CStyleLanguage(Renderer):
        elif uop is UOps.DEFINE_GLOBAL:
          bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
          r[u] = nm
-        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
+        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
        elif uop is UOps.GEP:
-          assert vin[0].dtype is not None
-          from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-          r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
+          assert src[0].dtype is not None
+          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
        else: raise RuntimeError(f"failed to render {uop}")
     return self.render_kernel(name, kernel, bufs, uops)
 
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index eec5706242..5a22e93631 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -101,21 +101,21 @@ class LLVMRenderer(Renderer):
       if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
 
     for u in uops:
-      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.STORE:
-        element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
-        if len(vin) > 3:
-          with bb[-1].if_then(lvars[vin[3]]):
-            bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+        element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
+        if len(src) > 3:
+          with bb[-1].if_then(lvars[src[3]]):
+            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
        else:
-          bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+          bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
      elif uop is UOps.ENDRANGE:
        loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
+        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
+        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
+        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
@@ -130,28 +130,28 @@ class LLVMRenderer(Renderer):
            phis.append((rp, lvars[rp]))
 
          lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-          lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
+          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
          loop_blocks.append((bb[-1].block, phis))
        elif uop is UOps.DEFINE_ACC:
          lvars[u] = const(args[0], dtype)
          reduce_phis.append(u)
        elif uop is UOps.LOAD:
-          if len(vin) > 2:
-            aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
-            val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
+          if len(src) > 2:
+            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
+            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
          else:
-            val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
+            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
          lvars[u] = val
        elif uop is UOps.PHI:
-          lvars[u] = lvars[vin[1]]
+          lvars[u] = lvars[src[1]]
          # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-          backward = vin[0]
+          backward = src[0]
          while backward.op is UOps.PHI: backward = backward.src[0]
          lvars[backward] = lvars[u]
        elif uop is UOps.ALU:
-          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
+          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
+        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
        elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
        elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
        elif uop is UOps.CONST: lvars[u] = const(args, dtype)
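
For orientation, a minimal sketch of what the rename means at a call site, assuming the import paths at this commit (tinygrad.codegen.uops, tinygrad.ops, tinygrad.dtype); the constant values are illustrative and not taken from the patch. Only parameter and local names move from vin to src; the UOp field was already u.src, so accesses like u.src[0] are unchanged.

    # illustrative only, not part of the patch
    from tinygrad.codegen.uops import UOp, UOps
    from tinygrad.ops import BinaryOps
    from tinygrad.dtype import dtypes

    a = UOp(UOps.CONST, dtypes.float32, (), 1.0)  # inputs ride in the src tuple
    b = UOp(UOps.CONST, dtypes.float32, (), 2.0)
    add = UOp.alu(BinaryOps.ADD, a, b)            # varargs parameter is now *src, formerly *vin
    assert add.src == (a, b)                      # one consistent name for UOp inputs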