From 49160951245ef1344856d3d5b696ebfbf0d81fa2 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 22 Oct 2024 19:02:29 +0300 Subject: [PATCH] compute ScheduleItem writable bufs [pr] (#7214) * compute ScheduleItem writable bufs [pr] * don't cache Buffer --- tinygrad/engine/schedule.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 31306c0a70..22446f6fd6 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys, atexit +import sys, atexit, functools from collections import defaultdict, deque from dataclasses import dataclass from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast @@ -28,11 +28,13 @@ class ScheduleItem: @property def outputs(self) -> Tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" - return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1] + return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) @property def inputs(self) -> Tuple[Buffer, ...]: """Read only buffers in the schedule.""" - return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:] + return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) + @functools.cached_property + def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,) @dataclass(frozen=True) class LBScheduleItem: