diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 31306c0a70..22446f6fd6 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys, atexit +import sys, atexit, functools from collections import defaultdict, deque from dataclasses import dataclass from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast @@ -28,11 +28,13 @@ class ScheduleItem: @property def outputs(self) -> Tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" - return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1] + return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) @property def inputs(self) -> Tuple[Buffer, ...]: """Read only buffers in the schedule.""" - return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:] + return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) + @functools.cached_property + def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,) @dataclass(frozen=True) class LBScheduleItem: