diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 877e46a824..44b9c2a6ec 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -444,7 +444,7 @@ class UOpGraph: arg = src[0].arg assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg - if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool + if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool if uop is UOps.ALU: if arg in UnaryOps: diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index d0155c00fb..ab7aadbb99 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -181,15 +181,16 @@ class PTXRenderer(Renderer): assert src[0].dtype == dtypes.int64, "load isn't int64" assert src[1].op is UOps.CONST, f"load isn't const {u}" mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global' + has_gate = len(src) > 3 and src[2].op is UOps.ALU if dtype.count > 1: r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - if(len(src)>3): + if has_gate: for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};") - kk((f"@{r[src[2]]}"if len(src) > 3 else "") + kk((f"@{r[src[2]]}"if has_gate else "") + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];") else: - kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None, - alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg)) + kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None, + alt=r[src[3]] if has_gate else None, ss=mem_type, offset=src[1].arg)) elif uop is UOps.PHI: if dtype.count > 1: for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 120a9a8aa7..6c5926b607 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -141,7 +141,7 @@ class CStyleLanguage(Renderer): elif uop is UOps.LOAD: val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) # NOTE: this relies on the load not happening if it's in the unselected branch - if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype) + if len(src) > 3 and src[2].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype) kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};") elif uop is UOps.PHI: kk(f"{r[src[0]]} = {r[src[1]]};")