Inline barrier (#2255)

* put the barrier inline for locals (pass it as an extra input to the local-buffer LOAD)

* fix pre-commit on M3 (set PYTHONPATH="." for the pytest hook)

* gate the IF through the barrier (the IF uop takes the barrier as an input)
This commit is contained in:
George Hotz
2023-11-10 08:17:10 -08:00
committed by GitHub
parent 75f6e9ab54
commit c0f447d6f7
5 changed files with 20 additions and 18 deletions

View File

@@ -27,7 +27,7 @@ repos:
pass_filenames: false pass_filenames: false
- id: tests - id: tests
name: subset of (CPU) tests name: subset of (CPU) tests
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system language: system
always_run: true always_run: true
pass_filenames: false pass_filenames: false

View File

@@ -62,7 +62,7 @@ class Linearizer(Kernel):
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]: def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
buf = self.bufs[i] buf = self.bufs[i]
const = buf.val if isinstance(buf, ConstBuffer) else acc const = buf.val if isinstance(buf, ConstBuffer) else acc
@@ -110,13 +110,13 @@ class Linearizer(Kernel):
if valid.min == 0: if valid.min == 0:
valid_rendered = valid.render(self.render_ops, self) valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype))) self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
else: else:
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx)) self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret return ret
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None: def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
buf = self.bufs[i] buf = self.bufs[i]
buf_uop = self.buf_uops[i] buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped" assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -141,6 +141,7 @@ class Linearizer(Kernel):
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens)) store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
store_offset = store_offset_new store_offset = store_offset_new
stores = []
for idx, var in store_offset.items(): for idx, var in store_offset.items():
idx, valid = self.sts[i].expr_idxs(idx) idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(buf.dtype, ImageDType): if isinstance(buf.dtype, ImageDType):
@@ -148,7 +149,8 @@ class Linearizer(Kernel):
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx)) rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
else: else:
rendered_idx = idx.render(self.render_ops, self) rendered_idx = idx.render(self.render_ops, self)
self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)) stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
return stores
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self): def linearize(self):
@@ -230,7 +232,6 @@ class Linearizer(Kernel):
loaded_buffers = {} loaded_buffers = {}
acc = [] acc = []
self.load_cache: Dict[str, UOp] = {} self.load_cache: Dict[str, UOp] = {}
if_gate: Optional[UOp] = None
# reduce op # reduce op
fake_reduce_idxs: List[Variable] = [] fake_reduce_idxs: List[Variable] = []
@@ -321,13 +322,13 @@ class Linearizer(Kernel):
# end the local loop, do the local reduce # end the local loop, do the local reduce
if self.group_for_reduce: if self.group_for_reduce:
fake_global_idxs = [x*0 for x in global_idxs] fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, (), cachable=False) barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
if self.opts.has_local: if self.opts.has_local:
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape) fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self) if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False) barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
# create new late reduce local loops and replace local_idxs that have been used # create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
@@ -352,7 +353,7 @@ class Linearizer(Kernel):
loop_ctx = render_loop(end_local_idxs) loop_ctx = render_loop(end_local_idxs)
# load localbufs # load localbufs
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
# there's no AST here (and there's no shape for the reduce LazyOp) # there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
@@ -369,12 +370,9 @@ class Linearizer(Kernel):
# store # store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the if statement if we used it
if if_gate: self.uop(UOps.END, None, (if_gate,))
# (recursively) remove childless uops # (recursively) remove childless uops
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL} UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1: while 1:
has_child: Set[UOp] = set() has_child: Set[UOp] = set()
for ru in self.uops: for ru in self.uops:
@@ -396,6 +394,7 @@ class Linearizer(Kernel):
return sorted(list(deps), key=lambda x: x.num) return sorted(list(deps), key=lambda x: x.num)
# add END of loops after the last thing that (recursively) depends on them # add END of loops after the last thing that (recursively) depends on them
# and END any if statements
for u in self.uops: for u in self.uops:
if u.uop == UOps.LOOP: if u.uop == UOps.LOOP:
last_phi = self.uops.index(get_recursive_deps(u)[-1]) last_phi = self.uops.index(get_recursive_deps(u)[-1])
@@ -403,6 +402,8 @@ class Linearizer(Kernel):
self.uops = self.uops[:last_phi+1] self.uops = self.uops[:last_phi+1]
self.uop(UOps.END, None, (u,), cachable=False) self.uop(UOps.END, None, (u,), cachable=False)
self.uops += at_end self.uops += at_end
elif u.uop == UOps.IF:
self.uop(UOps.END, None, (u,), cachable=False)
# maybe graph the uops # maybe graph the uops
if DEBUG >= 5: if DEBUG >= 5:

View File

@@ -110,9 +110,10 @@ def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i
def graph_uops(uops): def graph_uops(uops):
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0", colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"} UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
G = nx.DiGraph() G = nx.DiGraph()
for u in uops: for u in uops:
if u.uop == UOps.END: continue
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
for v in u.vin: G.add_edge(v.num, u.num) for v in u.vin: G.add_edge(v.num, u.num)
GRAPHPATH = "/tmp/uops" GRAPHPATH = "/tmp/uops"

View File

@@ -183,7 +183,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
elif uop == UOps.LOAD: elif uop == UOps.LOAD:
assert dtype is not None assert dtype is not None
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL) val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};") kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
elif uop == UOps.PHI: elif uop == UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};") kk(f"{r[vin[0]]} = {r[vin[1]]};")

View File

@@ -56,7 +56,7 @@ class MetalProgram:
data = libdispatch.dispatch_data_create(lib, len(lib), None, None) data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None)) self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name) self.fxn = self.library.newFunctionWithName_(name)
if DEBUG >= 5: if DEBUG >= 6:
with tempfile.NamedTemporaryFile(delete=True) as shader: with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib) shader.write(lib)
shader.flush() shader.flush()