compute ScheduleItem writable bufs [pr] (#7214)

* compute ScheduleItem writable bufs [pr]

* don't cache Buffer
This commit is contained in:
qazal
2024-10-22 19:02:29 +03:00
committed by GitHub
parent 24ed2ed6c8
commit 4916095124

View File

@@ -1,4 +1,4 @@
import sys, atexit
import sys, atexit, functools
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
@@ -28,11 +28,13 @@ class ScheduleItem:
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
@property
def inputs(self) -> Tuple[Buffer, ...]:
"""Read only buffers in the schedule."""
return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
@functools.cached_property
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
@dataclass(frozen=True)
class LBScheduleItem: