UPat for has_valid in load/store [run_process_replay] (#5052)

* fold gated load/store [run_process_replay]

* handle temp loads

* direct store
This commit is contained in:
qazal
2024-06-19 12:20:22 +03:00
committed by GitHub
parent 996788358d
commit 71194df1da
2 changed files with 8 additions and 4 deletions

View File

@@ -156,7 +156,7 @@ class Linearizer(Kernel):
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value))
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
if localtype == localtype.scalar():
@@ -172,7 +172,7 @@ class Linearizer(Kernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(self.render_ops, self)
valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value))
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
@@ -209,8 +209,7 @@ class Linearizer(Kernel):
tuple(x.render(self.render_ops, self) for x in image_idx))
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
return stores
# render loop

View File

@@ -258,6 +258,11 @@ constant_folder = PatternMatcher([
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
# fold gated LOAD/STORE
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var"), UOp.var("barrier")),
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.int, 1)), UOp.store),
])
# *** uop graph ***