@@ -62,7 +62,7 @@ class Linearizer(Kernel):
     SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
     AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

-  def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
+  def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
     buf = self.bufs[i]
     const = buf.val if isinstance(buf, ConstBuffer) else acc
@@ -110,13 +110,13 @@ class Linearizer(Kernel):
         if valid.min == 0:
           valid_rendered = valid.render(self.render_ops, self)
-          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
+          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
         else:
-          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
+          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
       ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret

-  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
+  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
     buf = self.bufs[i]
     buf_uop = self.buf_uops[i]
     assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -141,6 +141,7 @@ class Linearizer(Kernel):
         store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
       store_offset = store_offset_new

+    stores = []
     for idx, var in store_offset.items():
       idx, valid = self.sts[i].expr_idxs(idx)
       if isinstance(buf.dtype, ImageDType):
@@ -148,7 +149,8 @@ class Linearizer(Kernel):
         rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
       else:
         rendered_idx = idx.render(self.render_ops, self)
-      self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
+      stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+    return stores

   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self):
@@ -230,7 +232,6 @@ class Linearizer(Kernel):
     loaded_buffers = {}
     acc = []
     self.load_cache: Dict[str, UOp] = {}
-    if_gate: Optional[UOp] = None

     # reduce op
     fake_reduce_idxs: List[Variable] = []
@@ -321,13 +322,13 @@ class Linearizer(Kernel):
       # end the local loop, do the local reduce
       if self.group_for_reduce:
         fake_global_idxs = [x*0 for x in global_idxs]
-        self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
-        self.uop(UOps.BARRIER, None, (), cachable=False)
+        stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
+        barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
         if self.opts.has_local:
           fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
           fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
           if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
-          if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
+          barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)

         # create new late reduce local loops and replace local_idxs that have been used
         end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
@@ -352,7 +353,7 @@ class Linearizer(Kernel):
         loop_ctx = render_loop(end_local_idxs)

         # load localbufs
-        loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+        loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)

         # there's no AST here (and there's no shape for the reduce LazyOp)
         self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
@@ -369,12 +370,9 @@ class Linearizer(Kernel):
     # store
     self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

-    # end the if statement if we used it
-    if if_gate: self.uop(UOps.END, None, (if_gate,))
-
     # (recursively) remove childless uops
     # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
-    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
+    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
     while 1:
       has_child: Set[UOp] = set()
       for ru in self.uops:
@@ -396,6 +394,7 @@ class Linearizer(Kernel):
       return sorted(list(deps), key=lambda x: x.num)

     # add END of loops after the last thing that (recursively) depends on them
+    # and END any if statements
     for u in self.uops:
       if u.uop == UOps.LOOP:
         last_phi = self.uops.index(get_recursive_deps(u)[-1])
@@ -403,6 +402,8 @@ class Linearizer(Kernel):
         self.uops = self.uops[:last_phi+1]
         self.uop(UOps.END, None, (u,), cachable=False)
         self.uops += at_end
+      elif u.uop == UOps.IF:
+        self.uop(UOps.END, None, (u,), cachable=False)

     # maybe graph the uops
     if DEBUG >= 5:
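
Taken together, the hunks thread synchronization through the uop graph as ordinary data dependencies: global_store now returns its STORE uops, the BARRIER lists those stores as sources, the IF gate wraps the barrier, and global_load attaches the resulting uop to each LOAD. As far as the hunks show, that is also why UOps.END can come out of UOPS_W_SIDE_EFFECTS and why the END for an IF is placed in the same late pass as loop ENDs. Below is a minimal sketch of why this keeps the childless-uop pruning from dropping or reordering the chain; the UOp dataclass and remove_childless helper are illustrative stand-ins, not tinygrad's actual classes or function names.

# toy model of the pruning idea above, not tinygrad's API
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class UOp:
  op: str
  src: Tuple["UOp", ...] = ()

def remove_childless(uops, side_effects=frozenset({"STORE", "BARRIER", "DEFINE_GLOBAL"})):
  # keep a uop only if something consumes it or it has a side effect
  while True:
    has_child = {s for u in uops for s in u.src}
    kept = [u for u in uops if u in has_child or u.op in side_effects]
    if len(kept) == len(uops): return kept
    uops = kept

# the accumulator store feeds the BARRIER, the BARRIER feeds the late-reduce LOAD,
# and the LOAD feeds the final STORE, so nothing in the chain is childless
store_acc = UOp("STORE")
barrier   = UOp("BARRIER", (store_acc,))
load_loc  = UOp("LOAD", (barrier,))
store_out = UOp("STORE", (load_loc,))
print([u.op for u in remove_childless([store_acc, barrier, load_loc, store_out])])
# ['STORE', 'BARRIER', 'LOAD', 'STORE']

Expressing the barrier as a source of the loads, rather than as a free-floating side-effecting uop, lets ordering fall out of the dependency structure instead of needing special cases in the pruning and END-placement passes.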