mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
UPat for has_valid in load/store [run_process_replay] (#5052)
* fold gated load/store [run_process_replay]
* handle temp loads
* direct store
This commit is contained in:
@@ -156,7 +156,7 @@ class Linearizer(Kernel):
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value))
|
||||
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
|
||||
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
if localtype == localtype.scalar():
|
||||
@@ -172,7 +172,7 @@ class Linearizer(Kernel):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value))
|
||||
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
@@ -209,8 +209,7 @@ class Linearizer(Kernel):
|
||||
tuple(x.render(self.render_ops, self) for x in image_idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
|
||||
stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
|
||||
return stores
|
||||
|
||||
# render loop
|
||||
|
||||
@@ -258,6 +258,11 @@ constant_folder = PatternMatcher([
|
||||
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
# fold gated LOAD/STORE
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var"), UOp.var("barrier")),
|
||||
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.int, 1)), UOp.store),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
Reference in New Issue
Block a user