mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
refactor UOps.LOAD barrier [run_process_replay] (#5258)
This commit is contained in:
@@ -116,7 +116,7 @@ class Linearizer(Kernel):
|
||||
# NOTE: once images are loaded, we uop them as their base float
|
||||
def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
|
||||
|
||||
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
|
||||
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Tuple[UOp, ...]=(), loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
|
||||
buf = self.bufs[i]
|
||||
localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
|
||||
const = buf.val if isinstance(buf, ConstBuffer) else None
|
||||
@@ -158,8 +158,7 @@ class Linearizer(Kernel):
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
|
||||
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
|
||||
if localtype == localtype.scalar():
|
||||
idx_small = idx%4
|
||||
res = idx_small.render(render_ops, self.loop_uops)
|
||||
@@ -174,7 +173,7 @@ class Linearizer(Kernel):
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
rendered_idx = idx.render(render_ops, self.loop_uops)
|
||||
valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + barrier)
|
||||
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
|
||||
@@ -352,7 +351,7 @@ class Linearizer(Kernel):
|
||||
accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
||||
loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
|
||||
@@ -369,7 +368,7 @@ class Linearizer(Kernel):
|
||||
fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
|
||||
stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
|
||||
barrier = UOp(UOps.BARRIER, None, tuple(stores))
|
||||
accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
||||
accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,))
|
||||
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
|
||||
Reference in New Issue
Block a user