mirror of https://github.com/tinygrad/tinygrad.git
This reverts commit 664b563c91.
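
In short: the revert takes the `insert_before` plumbing back out of the Linearizer helpers (`const`, `cast`, `global_load`, `global_store`, `render_loop`, `render_reduceop`, `ast_parse`) and drops `UOpGraph.add`'s acceptance of a `UOp` instance as `insert_before`; internal callers go back to passing an integer index via `self.uops.index(u)`. The deleted `test_insert_before` in the first hunk shows the semantics being removed; the following minimal sketch restates it under the pre-revert API (the import paths are assumptions, and it will not run against the reverted code):

# Sketch of the removed insert_before behaviour (pre-revert API; import paths assumed).
# UOps added with insert_before=three are placed ahead of the `three` uop, so the
# final graph order is arg 0,1,2,3,4.
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.dtype import dtypes

g = UOpGraph()
g.add(UOps.CONST, dtypes.int, arg=0)
three = g.add(UOps.CONST, dtypes.int, arg=3)
g.add(UOps.CONST, dtypes.int, arg=1, insert_before=three)  # lands before `three`
g.add(UOps.CONST, dtypes.int, arg=2, insert_before=three)  # also before `three`
g.add(UOps.CONST, dtypes.int, arg=4)                       # appended at the end
for i, uop in enumerate(g.uops): assert i == uop.arg
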
@@ -46,14 +46,5 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(out.uop, UOps.CONST)
     self.assertEqual(out.arg, 0)
 
-  def test_insert_before(self):
-    g = UOpGraph()
-    g.add(UOps.CONST, dtypes.int, arg=0)
-    three = g.add(UOps.CONST, dtypes.int, arg=3)
-    g.add(UOps.CONST, dtypes.int, arg=1, insert_before=three)
-    g.add(UOps.CONST, dtypes.int, arg=2, insert_before=three)
-    g.add(UOps.CONST, dtypes.int, arg=4)
-    for i,uop in enumerate(g.uops): self.assertEqual(i, uop.arg)
-
 if __name__ == '__main__':
   unittest.main(verbosity=2)
@@ -44,12 +44,11 @@ class Linearizer(Kernel):
     return self.uops.add(UOps.ALU, dtype, (a, render_b), op)
 
   # NOTE: the consts have to be cached for deduping of downstream uops to work
-  def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before:Optional[UOp|int]=None) -> UOp:
+  def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before=None) -> UOp:
     if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0], insert_before=insert_before)
     else: return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
 
-  def cast(self, val: Tuple[UOp], dtype:DType, insert_before:Optional[UOp|int]=None) -> UOp:
-    return self.uops.add(UOps.CAST, dtype, val, insert_before=insert_before)
+  def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
 
   def get_reduce_acc(self, reduceop:LazyOp):
     dtype = reduceop.dtype
@@ -71,7 +70,7 @@ class Linearizer(Kernel):
     AndNode: lambda self,ops,ctx:
       functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
 
-  def global_load(self, i:int, idxs:List[Node], acc=None, barrier:Optional[UOp]=None, insert_before:Optional[UOp]=None) -> List[UOp]:
+  def global_load(self, i:int, idxs:List[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
     buf = self.bufs[i]
     localtype = self.get_base_dtype(buf.dtype if acc is None else self.reduceop.dtype)
     const = buf.val if isinstance(buf, ConstBuffer) else acc
@@ -98,44 +97,39 @@ class Linearizer(Kernel):
       key = f"{acc}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
       if key not in self.load_cache:
         if acc is not None:
-          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), dtypes.as_const(this_const, localtype), False, insert_before)
+          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), dtypes.as_const(this_const, localtype), cachable=False)
         elif this_const is not None:
-          self.load_cache[key] = self.const(this_const, localtype, insert_before)
+          self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], \
-              self.const(invalid_value, localtype, insert_before)), TernaryOps.WHERE, insert_before=insert_before)
+            self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501
         elif isinstance(buf.dtype, ImageDType):
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-          rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx), \
-            insert_before=insert_before)
-          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4), insert_before)) \
-            if valid.min == 0 else tuple()
-          self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4), \
-            (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()), insert_before=insert_before)
+          rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
+                                               (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
           if localtype == localtype.scalar():
             idx_small = idx%4
             res = idx_small.render(self.render_ops, self)
-            out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max, insert_before=insert_before)
+            out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
             for ix in range(idx_small.max, idx_small.min, -1):
-              rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1, insert_before=insert_before)
-              sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT, insert_before=insert_before)
-              out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE, insert_before=insert_before)
+              rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
+              out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
             self.load_cache[key] = out
         else:
          buf_uop = self.buf_uops[i]
          assert buf_uop is not None, f"buffer {i} wasn't UOped"
          rendered_idx = idx.render(self.render_ops, self)
-          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype, insert_before)) if valid.min == 0 else tuple()
-          self.load_cache[key] = \
-            self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()), insert_before=insert_before)
-      ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim], insert_before=insert_before) \
-        if dim is not None else self.load_cache[key])
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+      ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret
 
-  def global_store(self, i:int, idxs:List[Node], store:List[UOp], insert_before:Optional[UOp]=None) -> List[UOp]:
+  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
     buf = self.bufs[i]
     buf_uop = self.buf_uops[i]
     assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -155,7 +149,7 @@ class Linearizer(Kernel):
         amt = len(grouped)
         idx, valid = self.sts[i].expr_idxs(k)
         assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
-        store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped), insert_before=insert_before)
+        store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new
 
     stores = []
@@ -164,25 +158,23 @@ class Linearizer(Kernel):
       if isinstance(buf.dtype, ImageDType):
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
         rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
-          tuple(x.render(self.render_ops, self) for x in image_idx), insert_before=insert_before)
+          tuple(x.render(self.render_ops, self) for x in image_idx))
       else:
         rendered_idx = idx.render(self.render_ops, self)
-      if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var), insert_before=insert_before))
-      else:
-        stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)), insert_before=insert_before))
+      if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+      else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
     return stores
 
   # render loop
-  def render_loop(self, xx:List[Variable], insert_before:Optional[UOp]=None) -> Tuple[UOp, ...]:
+  def render_loop(self, xx:List[Variable]) -> Tuple[UOp, ...]:
     new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
-      self.const(x.min, dtypes.int32, insert_before) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
-      self.const(x.max+1, dtypes.int32, insert_before) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)
-    ), cachable=False, insert_before=insert_before) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
+      self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
+      self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
     self.loop_uops.update(new_loops)
     return tuple(new_loops.values())
 
   def render_reduceop(self, reduceop: LazyOp, loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], \
-      global_idxs, local_idxs, upcast_idxs, insert_before:Optional[UOp]=None):
+      global_idxs, local_idxs, upcast_idxs):
     # define indicies
     full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
     reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
@@ -227,14 +219,13 @@ class Linearizer(Kernel):
       for n in range(len(replace_acc_idxs)-len(tc.threads)):
         upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
       if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
-    acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop), insert_before=insert_before)
+    acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
 
     # reduce loop
-    loop_ctx = self.render_loop(reduce_idxs, insert_before=insert_before)
+    loop_ctx = self.render_loop(reduce_idxs)
 
     # store local aliases
-    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs, insert_before=insert_before)) \
-      for i, localbuf_idx, buf_idxs in alias_buf_idxs]
+    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
 
     if (tc:=self.tensor_core):
       # run tensor cores AST
@@ -247,26 +238,22 @@ class Linearizer(Kernel):
         return strides
       upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
       for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
-        offs: List[int] = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
-        ops = (
-          self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]]), None, True, insert_before),
-          self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]]), None, True, insert_before),
-          self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]]), None, True, insert_before))
-        ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, \
-          tuple(map(prod, tc.thread_local_sizes)), dev), insert_before=insert_before)
+        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
+        ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+               self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+               self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
+        ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
         for z in range(wmma_sz[2]): # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
-          acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], \
-            self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z, insert_before=insert_before)) + loop_ctx, insert_before=insert_before)
+          acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
     else:
       assert not locals_to_store, "storing locals isn't supported here"
 
     # load earlybufs
     loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
-      global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,insert_before=insert_before) for i,b in enumerate(self.bufs) if b in self.earlybufs})
+      global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
 
     # run early AST (with reduce)
-    self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, \
-      do_reduce=True, loop_ctx=loop_ctx, insert_before=insert_before)
+    self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
 
     # end the reduce loop
     self.load_cache.clear()
@@ -274,14 +261,13 @@ class Linearizer(Kernel):
     # end the local loop, do the local reduce
     if self.group_for_reduces:
       fake_global_idxs = [x*0 for x in global_idxs]
-      # store accumulators
-      stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, insert_before=insert_before)
-      barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False, insert_before=insert_before)
+      stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
+      barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False)
       if self.opts.has_local:
         fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
         fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
         if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
-        barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False, insert_before=insert_before)
+        barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False)
 
       # create new late reduce local loops and replace local_idxs that have been used
       end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
@@ -300,17 +286,16 @@ class Linearizer(Kernel):
       # NOTE: this structure is the same as the reduce op above
 
       # define late accumulator
-      acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop), insert_before=insert_before)
+      acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
 
       # late reduce loop
-      loop_ctx = self.render_loop(end_local_idxs, insert_before=insert_before)
+      loop_ctx = self.render_loop(end_local_idxs)
 
       # load localbufs
-      loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, None, barrier, insert_before)
+      loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
 
       # there's no AST here (and there's no shape for the reduce LazyOp)
-      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), \
-        acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx, insert_before=insert_before)
+      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
 
       # end the late reduce loop
       self.load_cache.clear()
@@ -427,30 +412,28 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self
 
-  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],\
-      do_reduce=False, loop_ctx=tuple(), cache=None, insert_before:Optional[UOp]=None) -> List[UOp]:
+  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0], \
-      insert_before=insert_before) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers, insert_before=insert_before)]
+    if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
+      for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
 
-    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache, insert_before=insert_before) for v in x.src]
+    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
     ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)], insert_before=insert_before)
+        acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
        ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
-          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx), insert_before=insert_before)
+          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
     else:
-      ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, \
-        val, x.op, insert_before=insert_before) for val in zip(*values)]
+      ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)]
     cache[x] = ret
     return ret
@@ -157,7 +157,6 @@ class UOpGraph:
       ret = rewritten
     key = (ret.uop, ret.dtype, ret.vin, ret.arg)
     if insert_before is None: insert_before = len(self.uops)
-    elif isinstance(insert_before, UOp): insert_before = self.uops.index(insert_before)
     # check if the cached expr is valid with the given insert place.
     if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
     self.uops.insert(insert_before, ret)
@@ -284,7 +283,7 @@ class UOpGraph:
     loop_length = loop_op.vin[1].arg - loop_op.vin[0].arg
     for u in self.uops:
       if u.arg is BinaryOps.ADD and len(wheres.intersection(get_recursive_parents(u))) and len(phis.intersection(self.get_recursive_children(u))):
-        u.vin = tuple([const(vin.arg*loop_length, insert_before=u) if vin.uop is UOps.CONST else vin for vin in list(u.vin)])
+        u.vin = tuple([const(vin.arg*loop_length, insert_before=self.uops.index(u)) if vin.uop is UOps.CONST else vin for vin in list(u.vin)])
     for where in sorted(wheres, key=lambda x: self.uops.index(x)):
       comp_lt, comp_gt = where.vin[0].vin[0], where.vin[0].vin[1]
       factored = loop_factor(comp_lt, NumNode(int(comp_gt.arg)), loop_op, round_up=(comp_gt.arg > 0))
@@ -333,10 +332,10 @@ class UOpGraph:
       del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
       # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
       loop_len = self.add(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB,
-                          insert_before=u)
+                          insert_before=self.uops.index(u))
       if loop_len.dtype != u.dtype: loop_len = self.add(UOps.CAST, u.dtype, (loop_len,),
-                                                        insert_before=u)
-      new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=u)
+                                                        insert_before=self.uops.index(u))
+      new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))
       self.replace_op(u, new)
       return True
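
After the revert, UOpGraph.add still takes an integer insert_before index (defaulting to the end of the list) and still refuses to dedupe against a cached expression that sits after the requested insertion point, which is why the call sites above switch from insert_before=u to insert_before=self.uops.index(u). A rough, self-contained sketch of that surviving path, pieced together from the context lines of the @@ -157,7 +157,6 @@ hunk (the UOp stand-in, the class name, and the final caching of the new uop are assumptions, not shown in the diff):

# Rough sketch of UOpGraph.add's insert/dedupe path after the revert; not verbatim.
from typing import NamedTuple, Any, Optional, Tuple

class UOp(NamedTuple):          # simplified stand-in for tinygrad's UOp
  uop: Any
  dtype: Optional[Any]
  vin: Tuple[Any, ...]
  arg: Any

class UOpGraphSketch:           # hypothetical name, for illustration only
  def __init__(self): self.uops, self.saved_exprs = [], {}
  def add(self, uop, dtype=None, vin=tuple(), arg=None, cachable=True, insert_before=None):
    ret = UOp(uop, dtype, vin, arg)
    key = (ret.uop, ret.dtype, ret.vin, ret.arg)
    if insert_before is None: insert_before = len(self.uops)   # default: append at the end
    # reuse a cached expr only if it already sits at or before the insert point
    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
    self.uops.insert(insert_before, ret)
    if cachable: self.saved_exprs[key] = ret                   # assumed caching of the new uop
    return ret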