From c0f447d6f76cee2ff3a32fc008e4caa363e0dca9 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 10 Nov 2023 08:17:10 -0800 Subject: [PATCH] Inline barrier (#2255) * put barrier inline for locals * fix pre-commit on m3 * gate if through barrier --- .pre-commit-config.yaml | 2 +- tinygrad/codegen/linearizer.py | 29 +++++++++++++++-------------- tinygrad/graph.py | 3 ++- tinygrad/renderer/cstyle.py | 2 +- tinygrad/runtime/ops_metal.py | 2 +- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77d40565a7..2cc71e2905 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: pass_filenames: false - id: tests name: subset of (CPU) tests - entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py + entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py language: system always_run: true pass_filenames: false diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index f25058bf2c..110f338f2c 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -62,7 +62,7 @@ class Linearizer(Kernel): SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } - def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]: + def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]: buf = self.bufs[i] const = buf.val if isinstance(buf, ConstBuffer) else acc @@ -110,13 +110,13 @@ class Linearizer(Kernel): if valid.min == 0: valid_rendered = valid.render(self.render_ops, self) - self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype))) + self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ())) else: - self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx)) + self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ())) ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) return ret - def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None: + def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]: buf = self.bufs[i] buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" @@ -141,6 +141,7 @@ class Linearizer(Kernel): store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens)) store_offset = store_offset_new + stores = [] for idx, var in store_offset.items(): idx, valid = self.sts[i].expr_idxs(idx) if isinstance(buf.dtype, ImageDType): @@ -148,7 +149,8 @@ class Linearizer(Kernel): rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx)) else: rendered_idx = idx.render(self.render_ops, self) - self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)) + stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))) + return stores kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) def linearize(self): @@ -230,7 +232,6 @@ class Linearizer(Kernel): loaded_buffers = {} acc = [] self.load_cache: Dict[str, UOp] = {} - if_gate: Optional[UOp] = None # reduce op fake_reduce_idxs: List[Variable] = [] @@ -321,13 +322,13 @@ class Linearizer(Kernel): # end the local loop, do the local reduce if self.group_for_reduce: fake_global_idxs = [x*0 for x in global_idxs] - self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators - self.uop(UOps.BARRIER, None, (), cachable=False) + stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators + barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False) if self.opts.has_local: fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape) fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self) - if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False) + barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False) # create new late reduce local loops and replace local_idxs that have been used end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] @@ -352,7 +353,7 @@ class Linearizer(Kernel): loop_ctx = render_loop(end_local_idxs) # load localbufs - loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) + loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) # there's no AST here (and there's no shape for the reduce LazyOp) self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore @@ -369,12 +370,9 @@ class Linearizer(Kernel): # store self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) - # end the if statement if we used it - if if_gate: self.uop(UOps.END, None, (if_gate,)) - # (recursively) remove childless uops # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that - UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL} + UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL} while 1: has_child: Set[UOp] = set() for ru in self.uops: @@ -396,6 +394,7 @@ class Linearizer(Kernel): return sorted(list(deps), key=lambda x: x.num) # add END of loops after the last thing that (recursively) depends on them + # and END any if statements for u in self.uops: if u.uop == UOps.LOOP: last_phi = self.uops.index(get_recursive_deps(u)[-1]) @@ -403,6 +402,8 @@ class Linearizer(Kernel): self.uops = self.uops[:last_phi+1] self.uop(UOps.END, None, (u,), cachable=False) self.uops += at_end + elif u.uop == UOps.IF: + self.uop(UOps.END, None, (u,), cachable=False) # maybe graph the uops if DEBUG >= 5: diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 145a5adc2b..0ca9d2ec37 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -110,9 +110,10 @@ def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i def graph_uops(uops): colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0", UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", - UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"} + UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"} G = nx.DiGraph() for u in uops: + if u.uop == UOps.END: continue G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) for v in u.vin: G.add_edge(v.num, u.num) GRAPHPATH = "/tmp/uops" diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 727e005449..4a9ef73f44 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -183,7 +183,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu elif uop == UOps.LOAD: assert dtype is not None val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL) - if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) + if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};") elif uop == UOps.PHI: kk(f"{r[vin[0]]} = {r[vin[1]]};") diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 584576b6cb..03f03b0ad1 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -56,7 +56,7 @@ class MetalProgram: data = libdispatch.dispatch_data_create(lib, len(lib), None, None) self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None)) self.fxn = self.library.newFunctionWithName_(name) - if DEBUG >= 5: + if DEBUG >= 6: with tempfile.NamedTemporaryFile(delete=True) as shader: shader.write(lib) shader.flush()