From d3cfb6c2e35a9412e83b279ce38de7d92e9ad1bc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 2 Jul 2024 13:48:47 +0300 Subject: [PATCH] refactor UOps.LOAD barrier [run_process_replay] (#5258) --- tinygrad/codegen/linearizer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index c92c302ea6..cd1cc2255b 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -116,7 +116,7 @@ class Linearizer(Kernel): # NOTE: once images are loaded, we uop them as their base float def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt - def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]: + def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Tuple[UOp, ...]=(), loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]: buf = self.bufs[i] localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype) const = buf.val if isinstance(buf, ConstBuffer) else None @@ -158,8 +158,7 @@ class Linearizer(Kernel): image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx)) valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple() - self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), - (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) + self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier) if localtype == localtype.scalar(): idx_small = idx%4 res = idx_small.render(render_ops, self.loop_uops) @@ -174,7 +173,7 @@ class Linearizer(Kernel): assert buf_uop is not None, f"buffer {i} wasn't UOped" rendered_idx = idx.render(render_ops, self.loop_uops) valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple() - self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) + self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + barrier) ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) return ret @@ -352,7 +351,7 @@ class Linearizer(Kernel): accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx) # load localbufs - loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) + loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,)) # there's no AST here (and there's no shape for the reduce LazyOp) self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\ @@ -369,7 +368,7 @@ class Linearizer(Kernel): fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]] stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) barrier = UOp(UOps.BARRIER, None, tuple(stores)) - accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) + accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=(barrier,)) return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)