From 267bbb57f9e767cd39804acbf0e0334031c12d35 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 4 May 2024 22:50:21 +0800 Subject: [PATCH] Revert "Add `insert_before` to Linearizer Functions (#4320)" (#4421) This reverts commit 664b563c91ced21eff115d26ae296e39133329db. --- test/test_uop_graph.py | 9 --- tinygrad/codegen/linearizer.py | 119 ++++++++++++++------------------- tinygrad/codegen/uops.py | 9 ++- 3 files changed, 55 insertions(+), 82 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index cb65300ae7..0996e87068 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -46,14 +46,5 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.uop, UOps.CONST) self.assertEqual(out.arg, 0) - def test_insert_before(self): - g = UOpGraph() - g.add(UOps.CONST, dtypes.int, arg=0) - three = g.add(UOps.CONST, dtypes.int, arg=3) - g.add(UOps.CONST, dtypes.int, arg=1, insert_before=three) - g.add(UOps.CONST, dtypes.int, arg=2, insert_before=three) - g.add(UOps.CONST, dtypes.int, arg=4) - for i,uop in enumerate(g.uops): self.assertEqual(i, uop.arg) - if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 88e73d5d77..34a0bbf0d8 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -44,12 +44,11 @@ class Linearizer(Kernel): return self.uops.add(UOps.ALU, dtype, (a, render_b), op) # NOTE: the consts have to be cached for deduping of downstream uops to work - def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before:Optional[UOp|int]=None) -> UOp: + def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before=None) -> UOp: if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0], insert_before=insert_before) else: return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) - def cast(self, val: Tuple[UOp], dtype:DType, insert_before:Optional[UOp|int]=None) -> UOp: - return self.uops.add(UOps.CAST, dtype, val, insert_before=insert_before) + def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val def get_reduce_acc(self, reduceop:LazyOp): dtype = reduceop.dtype @@ -71,7 +70,7 @@ class Linearizer(Kernel): AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } - def global_load(self, i:int, idxs:List[Node], acc=None, barrier:Optional[UOp]=None, insert_before:Optional[UOp]=None) -> List[UOp]: + def global_load(self, i:int, idxs:List[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]: buf = self.bufs[i] localtype = self.get_base_dtype(buf.dtype if acc is None else self.reduceop.dtype) const = buf.val if isinstance(buf, ConstBuffer) else acc @@ -98,44 +97,39 @@ class Linearizer(Kernel): key = f"{acc}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: - self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), dtypes.as_const(this_const, localtype), False, insert_before) + self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), dtypes.as_const(this_const, localtype), cachable=False) elif this_const is not None: - self.load_cache[key] = self.const(this_const, localtype, insert_before) + self.load_cache[key] = self.const(this_const, localtype) if valid.min == 0 and valid.max == 1: valid_rendered = valid.render(self.render_ops, self) - self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], \ - self.const(invalid_value, localtype, insert_before)), TernaryOps.WHERE, insert_before=insert_before) + self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501 elif isinstance(buf.dtype, ImageDType): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx), \ - insert_before=insert_before) - valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4), insert_before)) \ - if valid.min == 0 else tuple() - self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4), \ - (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()), insert_before=insert_before) + rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx)) + valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple() + self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4), + (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) if localtype == localtype.scalar(): idx_small = idx%4 res = idx_small.render(self.render_ops, self) - out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max, insert_before=insert_before) + out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max) for ix in range(idx_small.max, idx_small.min, -1): - rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1, insert_before=insert_before) - sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT, insert_before=insert_before) - out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE, insert_before=insert_before) + rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1) + sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT) + out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE) self.load_cache[key] = out else: buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" rendered_idx = idx.render(self.render_ops, self) - valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype, insert_before)) if valid.min == 0 else tuple() - self.load_cache[key] = \ - self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()), insert_before=insert_before) - ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim], insert_before=insert_before) \ - if dim is not None else self.load_cache[key]) + valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple() + self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) + ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) return ret - def global_store(self, i:int, idxs:List[Node], store:List[UOp], insert_before:Optional[UOp]=None) -> List[UOp]: + def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]: buf = self.bufs[i] buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" @@ -155,7 +149,7 @@ class Linearizer(Kernel): amt = len(grouped) idx, valid = self.sts[i].expr_idxs(k) assert idx == ((idx//amt)*amt), "float4 stores are always aligned" - store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped), insert_before=insert_before) + store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped)) store_offset = store_offset_new stores = [] @@ -164,25 +158,23 @@ class Linearizer(Kernel): if isinstance(buf.dtype, ImageDType): image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \ - tuple(x.render(self.render_ops, self) for x in image_idx), insert_before=insert_before) + tuple(x.render(self.render_ops, self) for x in image_idx)) else: rendered_idx = idx.render(self.render_ops, self) - if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var), insert_before=insert_before)) - else: - stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)), insert_before=insert_before)) + if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var))) + else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)))) return stores # render loop - def render_loop(self, xx:List[Variable], insert_before:Optional[UOp]=None) -> Tuple[UOp, ...]: + def render_loop(self, xx:List[Variable]) -> Tuple[UOp, ...]: new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, ( - self.const(x.min, dtypes.int32, insert_before) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self), - self.const(x.max+1, dtypes.int32, insert_before) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self) - ), cachable=False, insert_before=insert_before) for x in xx if not isinstance(x, NumNode) and x.expr is not None} + self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self), + self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501 self.loop_uops.update(new_loops) return tuple(new_loops.values()) def render_reduceop(self, reduceop: LazyOp, loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], \ - global_idxs, local_idxs, upcast_idxs, insert_before:Optional[UOp]=None): + global_idxs, local_idxs, upcast_idxs): # define indicies full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])] reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501 @@ -227,14 +219,13 @@ class Linearizer(Kernel): for n in range(len(replace_acc_idxs)-len(tc.threads)): upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}") - acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop), insert_before=insert_before) + acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop)) # reduce loop - loop_ctx = self.render_loop(reduce_idxs, insert_before=insert_before) + loop_ctx = self.render_loop(reduce_idxs) # store local aliases - locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs, insert_before=insert_before)) \ - for i, localbuf_idx, buf_idxs in alias_buf_idxs] + locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs] if (tc:=self.tensor_core): # run tensor cores AST @@ -247,26 +238,22 @@ class Linearizer(Kernel): return strides upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]: - offs: List[int] = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)] - ops = ( - self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]]), None, True, insert_before), - self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]]), None, True, insert_before), - self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]]), None, True, insert_before)) - ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, \ - tuple(map(prod, tc.thread_local_sizes)), dev), insert_before=insert_before) + offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)] + ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])), + self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])), + self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]]))) + ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev)) for z in range(wmma_sz[2]): # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid - acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], \ - self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z, insert_before=insert_before)) + loop_ctx, insert_before=insert_before) + acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx) else: assert not locals_to_store, "storing locals isn't supported here" # load earlybufs loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, - global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,insert_before=insert_before) for i,b in enumerate(self.bufs) if b in self.earlybufs}) + global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs}) # run early AST (with reduce) - self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, \ - do_reduce=True, loop_ctx=loop_ctx, insert_before=insert_before) + self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # end the reduce loop self.load_cache.clear() @@ -274,14 +261,13 @@ class Linearizer(Kernel): # end the local loop, do the local reduce if self.group_for_reduces: fake_global_idxs = [x*0 for x in global_idxs] - # store accumulators - stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, insert_before=insert_before) - barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False, insert_before=insert_before) + stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators + barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False) if self.opts.has_local: fake_idxs = [NumNode(0)]*len(self.sts[-1].shape) fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self) - barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False, insert_before=insert_before) + barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False) # create new late reduce local loops and replace local_idxs that have been used end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501 @@ -300,17 +286,16 @@ class Linearizer(Kernel): # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop), insert_before=insert_before) + acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop)) # late reduce loop - loop_ctx = self.render_loop(end_local_idxs, insert_before=insert_before) + loop_ctx = self.render_loop(end_local_idxs) # load localbufs - loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, None, barrier, insert_before) + loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), \ - acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx, insert_before=insert_before) + self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501 # end the late reduce loop self.load_cache.clear() @@ -427,30 +412,28 @@ class Linearizer(Kernel): self.applied_opts_cache = self.applied_opts[:] return self - def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],\ - do_reduce=False, loop_ctx=tuple(), cache=None, insert_before:Optional[UOp]=None) -> List[UOp]: + def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501 if cache is None: cache = {} if x in cache: return cache[x] if x.op in BufferOps: return loaded_buffers[x.arg] - if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0], \ - insert_before=insert_before) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers, insert_before=insert_before)] + if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \ + for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)] if x.op in ReduceOps and not do_reduce: assert offs is None, "not available if we aren't doing reduce" return acc - values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache, insert_before=insert_before) for v in x.src] + values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src] ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX} if x.op in ops: ret: List[UOp] = [] input_acc = acc[:] for val, off in zip(zip(*values), cast(List[int], offs)): - acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)], insert_before=insert_before) + acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)]) ret.append(acc[off]) for off in range(len(acc)): if input_acc[off] != acc[off]: - acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx), insert_before=insert_before) + acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx)) else: - ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, \ - val, x.op, insert_before=insert_before) for val in zip(*values)] + ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)] cache[x] = ret return ret diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index d3ac933cb9..f965a3cf80 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -157,7 +157,6 @@ class UOpGraph: ret = rewritten key = (ret.uop, ret.dtype, ret.vin, ret.arg) if insert_before is None: insert_before = len(self.uops) - elif isinstance(insert_before, UOp): insert_before = self.uops.index(insert_before) # check if the cached expr is valid with the given insert place. if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr self.uops.insert(insert_before, ret) @@ -284,7 +283,7 @@ class UOpGraph: loop_length = loop_op.vin[1].arg - loop_op.vin[0].arg for u in self.uops: if u.arg is BinaryOps.ADD and len(wheres.intersection(get_recursive_parents(u))) and len(phis.intersection(self.get_recursive_children(u))): - u.vin = tuple([const(vin.arg*loop_length, insert_before=u) if vin.uop is UOps.CONST else vin for vin in list(u.vin)]) + u.vin = tuple([const(vin.arg*loop_length, insert_before=self.uops.index(u)) if vin.uop is UOps.CONST else vin for vin in list(u.vin)]) for where in sorted(wheres, key=lambda x: self.uops.index(x)): comp_lt, comp_gt = where.vin[0].vin[0], where.vin[0].vin[1] factored = loop_factor(comp_lt, NumNode(int(comp_gt.arg)), loop_op, round_up=(comp_gt.arg > 0)) @@ -333,10 +332,10 @@ class UOpGraph: del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)] # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype loop_len = self.add(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, - insert_before=u) + insert_before=self.uops.index(u)) if loop_len.dtype != u.dtype: loop_len = self.add(UOps.CAST, u.dtype, (loop_len,), - insert_before=u) - new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=u) + insert_before=self.uops.index(u)) + new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)) self.replace_op(u, new) return True