refactor UOps.LOAD barrier [run_process_replay] (#5258)

This commit is contained in:
qazal
2024-07-02 13:48:47 +03:00
committed by GitHub
parent a1044e6063
commit d3cfb6c2e3

View File

@@ -116,7 +116,7 @@ class Linearizer(Kernel):
# NOTE: once images are loaded, we uop them as their base float
def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Tuple[UOp, ...]=(), loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
buf = self.bufs[i]
localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
const = buf.val if isinstance(buf, ConstBuffer) else None
@@ -158,8 +158,7 @@ class Linearizer(Kernel):
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
if localtype == localtype.scalar():
idx_small = idx%4
res = idx_small.render(render_ops, self.loop_uops)
@@ -174,7 +173,7 @@ class Linearizer(Kernel):
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(render_ops, self.loop_uops)
valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + barrier)
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
@@ -352,7 +351,7 @@ class Linearizer(Kernel):
accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
# load localbufs
loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
@@ -369,7 +368,7 @@ class Linearizer(Kernel):
fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
barrier = UOp(UOps.BARRIER, None, tuple(stores))
accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)