mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
preschedule all (#3875)
This commit is contained in:
@@ -100,6 +100,14 @@ class TestAssign(unittest.TestCase):
|
||||
new = a + times_a
|
||||
np.testing.assert_allclose(new.numpy(), 5)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_diamond_possible(self):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
times_a = a*3
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
new = a + (times_a+1).contiguous()
|
||||
np.testing.assert_allclose(new.numpy(), 6)
|
||||
|
||||
def test_assign_diamond_alt(self):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
|
||||
@@ -240,6 +240,10 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
assert len(realized_children) == 1
|
||||
reduce_for_op[next(iter(realized_children.keys()))] = r
|
||||
|
||||
# preschedule all buffers in realizes
|
||||
prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
|
||||
|
||||
# breadth first ordering
|
||||
graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list)
|
||||
in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int)
|
||||
queue: Deque[LazyBuffer] = deque()
|
||||
@@ -251,16 +255,17 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
in_degree[buf] += 1
|
||||
if in_degree[buf] == 0: queue.append(buf)
|
||||
|
||||
sorted_realizes: List[LazyBuffer] = []
|
||||
schedule: List[ScheduleItem] = []
|
||||
while queue:
|
||||
buf = queue.popleft()
|
||||
if buf in realizes and buf not in seen:
|
||||
sorted_realizes.append(buf)
|
||||
schedule.append(prescheduled[buf])
|
||||
seen.add(buf)
|
||||
for x in graph[buf]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
sched:List[ScheduleItem] = []
|
||||
for x in sorted_realizes: sched.append(_schedule_one(x, realizes, reduce_for_op))
|
||||
return sched
|
||||
# confirm everything was scheduled
|
||||
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
|
||||
return schedule
|
||||
|
||||
|
||||
Reference in New Issue
Block a user