mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
compute ScheduleItem writable bufs [pr] (#7214)
* compute ScheduleItem writable bufs [pr] * don't cache Buffer
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import sys, atexit
|
||||
import sys, atexit, functools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
@@ -28,11 +28,13 @@ class ScheduleItem:
|
||||
@property
|
||||
def outputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read/write or write only buffers in the schedule."""
|
||||
return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
|
||||
return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
|
||||
@property
|
||||
def inputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read only buffers in the schedule."""
|
||||
return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
|
||||
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
|
||||
@functools.cached_property
|
||||
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LBScheduleItem:
|
||||
|
||||
Reference in New Issue
Block a user