mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
refactor gated load rendering [run_process_replay] (#5259)
* refactor gated load rendering [run_process_replay] * hotfix: extra line * remove llvm diff
This commit is contained in:
@@ -444,7 +444,7 @@ class UOpGraph:
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
|
||||
@@ -181,15 +181,16 @@ class PTXRenderer(Renderer):
|
||||
assert src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
assert src[1].op is UOps.CONST, f"load isn't const {u}"
|
||||
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
|
||||
has_gate = len(src) > 3 and src[2].op is UOps.ALU
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(src)>3):
|
||||
if has_gate:
|
||||
for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[src[2]]}"if len(src) > 3 else "")
|
||||
kk((f"@{r[src[2]]}"if has_gate else "")
|
||||
+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
|
||||
else:
|
||||
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
|
||||
alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
|
||||
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
|
||||
alt=r[src[3]] if has_gate else None, ss=mem_type, offset=src[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
if dtype.count > 1:
|
||||
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
|
||||
|
||||
@@ -141,7 +141,7 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
|
||||
if len(src) > 3 and src[2].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[src[0]]} = {r[src[1]]};")
|
||||
|
||||
Reference in New Issue
Block a user