diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a963fe009f..59b66959e0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -35,11 +35,7 @@ class ScheduleItem: @functools.cached_property def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) -# **** small wrapper for LazyBuffer -> UOp - -def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base") -@functools.lru_cache(None) -def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 +# **** Schedule context and big graph @dataclass(frozen=True) class ScheduleContext: @@ -50,6 +46,11 @@ class ScheduleContext: allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) +class UPatSrc(UPat): + def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name": "to_store"})), name="base") +@functools.lru_cache(None) +def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 + def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: @@ -193,7 +194,7 @@ to_si = PatternMatcher([ # ** fusion def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp: - if (metadata:=ctx.lazybufs[b].metadata) is not None: ctx.metadata.add(metadata) + if (lbuf:=ctx.lazybufs.get(b)) is not None and (metadata:=lbuf.metadata) is not None: ctx.metadata.add(metadata) return to_store lazy = PatternMatcher([ @@ -206,9 +207,9 @@ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: c def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]: si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src}, - metadata={mx for x in pre.src if (mx:=ctx.lazybufs[x.buf_uop].metadata) is not None}) + metadata={l.metadata for x in pre.src if (l:=ctx.lazybufs.get(x.buf_uop)) is not None and l.metadata is not None}) # fuse and fold store -> loads - sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, si_ctx) + sink = graph_rewrite(pre, lazy+multioutput if len(pre.src) > 1 else lazy, si_ctx) # assert cyclic dependency for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in ctx.assigns), key=lambda x:x.buf_uop): if not all_same([x.op for x in ops]): @@ -232,7 +233,7 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, UOp]] = [] if getenv("RUN_PROCESS_REPLAY"): @atexit.register - def save_process_replay(): + def save_process_replay() -> None: for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x.key), (x, {}, ret)) # **** Schedule grouping @@ -281,7 +282,7 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO double_reduces: List[UOp] = [] for r, r_uop in ctx.allbufs.items(): if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue - if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r) + if FUSE_CONV_BW and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r) if r in realizes: continue group: Dict[UOp, None] = {} recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={}) @@ -355,7 +356,7 @@ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> do_realize = PatternMatcher([ # always realize meta ops - (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize), + (UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), # don't realize image to image casts (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st) if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 356bc3e1af..166264c358 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -334,7 +334,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start, UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False)) - def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) + def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,)) @property @@ -343,11 +343,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop movement ops *** @property - def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self - def view(self, st:ShapeTracker): + def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self + def view(self, st:ShapeTracker) -> UOp: assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base" return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st) - def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg)) + def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg)) # *** uop Buffer stuff ***