mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
good suggestions from mypy lineprecision-report for schedule.py [pr] (#7823)
* good suggestions from mypy lineprecision-report [pr] * ok if metadata doesn't exist * same for store * that's buf_uop
This commit is contained in:
@@ -35,11 +35,7 @@ class ScheduleItem:
|
||||
@functools.cached_property
|
||||
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
|
||||
|
||||
# **** small wrapper for LazyBuffer -> UOp
|
||||
|
||||
def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
|
||||
@functools.lru_cache(None)
|
||||
def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
|
||||
# **** Schedule context and big graph
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
@@ -50,6 +46,11 @@ class ScheduleContext:
|
||||
allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
|
||||
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
class UPatSrc(UPat):
|
||||
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name": "to_store"})), name="base")
|
||||
@functools.lru_cache(None)
|
||||
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
|
||||
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf is not buf.base:
|
||||
@@ -193,7 +194,7 @@ to_si = PatternMatcher([
|
||||
# ** fusion
|
||||
|
||||
def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
if (metadata:=ctx.lazybufs[b].metadata) is not None: ctx.metadata.add(metadata)
|
||||
if (lbuf:=ctx.lazybufs.get(b)) is not None and (metadata:=lbuf.metadata) is not None: ctx.metadata.add(metadata)
|
||||
return to_store
|
||||
|
||||
lazy = PatternMatcher([
|
||||
@@ -206,9 +207,9 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
|
||||
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src},
|
||||
metadata={mx for x in pre.src if (mx:=ctx.lazybufs[x.buf_uop].metadata) is not None})
|
||||
metadata={l.metadata for x in pre.src if (l:=ctx.lazybufs.get(x.buf_uop)) is not None and l.metadata is not None})
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, si_ctx)
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src) > 1 else lazy, si_ctx)
|
||||
# assert cyclic dependency
|
||||
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in ctx.assigns), key=lambda x:x.buf_uop):
|
||||
if not all_same([x.op for x in ops]):
|
||||
@@ -232,7 +233,7 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
|
||||
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, UOp]] = []
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
def save_process_replay() -> None:
|
||||
for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x.key), (x, {}, ret))
|
||||
|
||||
# **** Schedule grouping
|
||||
@@ -281,7 +282,7 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
|
||||
double_reduces: List[UOp] = []
|
||||
for r, r_uop in ctx.allbufs.items():
|
||||
if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
|
||||
if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
|
||||
if FUSE_CONV_BW and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
|
||||
if r in realizes: continue
|
||||
group: Dict[UOp, None] = {}
|
||||
recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
|
||||
@@ -355,7 +356,7 @@ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) ->
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
|
||||
(UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
# don't realize image to image casts
|
||||
(UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
|
||||
if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
|
||||
|
||||
@@ -334,7 +334,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
|
||||
@property
|
||||
@@ -343,11 +343,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# *** uop movement ops ***
|
||||
|
||||
@property
|
||||
def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
|
||||
def view(self, st:ShapeTracker):
|
||||
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
|
||||
def view(self, st:ShapeTracker) -> UOp:
|
||||
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
|
||||
return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user