diff --git a/test/test_outerworld.py b/test/test_outerworld.py index 1d3f39cf53..795b77ee84 100644 --- a/test/test_outerworld.py +++ b/test/test_outerworld.py @@ -38,7 +38,7 @@ class TestOuterRange(unittest.TestCase): vi = UOp.variable("i", i.vmin, i.vmax).bind(i) out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i))) out.realize() - assert all(x == 10.0 for x in out.tolist()) + self.assertEqual(out.tolist(), [10.0]*10) def test_range_matmul(self): vec = Tensor.randn(1, 10).realize() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 8616e1ce15..d5607fff0f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -9,8 +9,8 @@ from tinygrad.engine.realize import ExecItem # **** schedule linearizer -# ScheduleItem = tuple[AST, buffer UOps, metadata, fixedvars, bound_ranges] -ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], dict[str, int], tuple[UOp, ...]] +# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges] +ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]] # unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op) def _unwrap_src(s: UOp) -> UOp: @@ -51,7 +51,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: for k,v in in_degree.items(): if v == 0: queue.append(k) - schedule: list[ScheduleItem|UOp] = [] # ScheduleItem for kernels, UOp for RANGE/END + schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps + sched_item: dict[UOp, ScheduleItem] = {} while len(queue): k = rk = queue.popleft() if k.op is Ops.END: k = k.src[0] @@ -61,32 +62,34 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: ast = (kernel:=cast(Kernel, k.arg)).ast buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND) bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE) - schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges)) + sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges) + schedule.append(k) if rk.op is Ops.END: schedule.append(rk) for x in children.get(rk, []): in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) with cpu_profile(TracingKey("unroll outer ranges")): - pre_schedule, buf_uops_list = unroll_outer_ranges(schedule) + pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item) return pre_schedule, UOp.sink(*buf_uops_list) -def unroll_outer_ranges(schedule:list[ScheduleItem|UOp]) -> tuple[list[ExecItem], list[UOp]]: +def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]: pre_schedule: list[ExecItem] = [] buf_uops_list: list[UOp] = [] sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]() while sched_ptr < len(schedule): - if isinstance(si := schedule[sched_ptr], UOp): - if si.op is Ops.RANGE: - in_ranges[si] = 0 - range_ptrs[si] = sched_ptr + 1 - elif si.op is Ops.END: - if in_ranges[si.src[1]] < si.src[1].vmax: - in_ranges[si.src[1]] += 1 - sched_ptr = range_ptrs[si.src[1]] - continue + si = schedule[sched_ptr] + if si.op is Ops.RANGE: + in_ranges[si] = 0 + range_ptrs[si] = sched_ptr + 1 + elif si.op is Ops.END: + if in_ranges[si.src[1]] < si.src[1].vmax: + in_ranges[si.src[1]] += 1 + sched_ptr = range_ptrs[si.src[1]] + continue else: - ast, buf_uops, metadata, _, bound_ranges = si + assert si.op is Ops.KERNEL, f"unexpected op in schedule: {si.op}" + ast, buf_uops, metadata, bound_ranges = sched_item[si] fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges} pre_schedule.append(ExecItem(ast, [], metadata, fixedvars)) buf_uops_list.append(UOp.sink(*buf_uops))