refactor gated load rendering [run_process_replay] (#5259)

* refactor gated load rendering [run_process_replay]

* hotfix: extra line

* remove llvm diff
This commit is contained in:
qazal
2024-07-02 15:13:10 +03:00
committed by GitHub
parent e050603b4b
commit 59bc837ad1
3 changed files with 7 additions and 6 deletions

View File

@@ -444,7 +444,7 @@ class UOpGraph:
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
if uop is UOps.ALU:
if arg in UnaryOps:

View File

@@ -181,15 +181,16 @@ class PTXRenderer(Renderer):
assert src[0].dtype == dtypes.int64, "load isn't int64"
assert src[1].op is UOps.CONST, f"load isn't const {u}"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 3 and src[2].op is UOps.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(src)>3):
if has_gate:
for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[src[2]]}"if len(src) > 3 else "")
kk((f"@{r[src[2]]}"if has_gate else "")
+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
else:
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
alt=r[src[3]] if has_gate else None, ss=mem_type, offset=src[1].arg))
elif uop is UOps.PHI:
if dtype.count > 1:
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")

View File

@@ -141,7 +141,7 @@ class CStyleLanguage(Renderer):
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
if len(src) > 3 and src[2].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[src[0]]} = {r[src[1]]};")