prereq for scheduler contiguous_child [pr] (#8163)

* the whole context is fine here [pr]

* fix that
This commit is contained in:
qazal
2024-12-11 20:02:22 +02:00
committed by GitHub
parent 3a8e8ac6c2
commit 047a6dabc3
2 changed files with 11 additions and 11 deletions

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.lazy import LazyBuffer
from extra.models.llama import precompute_freqs_cis
@@ -1938,7 +1938,7 @@ class TestView(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
# Test helper: runs graph_rewrite over big_graph with the do_realize matcher and returns the resulting UOp.
# NOTE(review): this is a rendered diff whose +/- markers were lost; the two `def` lines below appear to be
# the removed and the added version of the same one-line function (only the context argument differs).
# NOTE(review): realizes={} is a mutable default, so the same dict object is reused by every call that omits
# the argument; if the rewrite writes into it, entries would persist between calls — confirm this is intended.
@track_rewrites(named=True)
# presumably the pre-change line: the realizes dict itself is passed as the rewrite context
def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, realizes)
# presumably the post-change line: the realizes dict is wrapped in a ScheduleContext before being passed
def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, ScheduleContext(realizes=realizes))
class TestBigGraph(unittest.TestCase):
def test_sink_childless_const(self):
x = UOp.const(dtypes.int, 0)

View File

@@ -358,23 +358,23 @@ ops_folding = PatternMatcher([
# ** this decides which ops get realized
# Records `to_store` under key `b` in the realizes mapping, unless to_store is a CONST or BIND op.
# Always returns None; extra keyword arguments are accepted and ignored.
# NOTE(review): rendered diff with lost +/- markers and lost indentation; the first def/if pair appears to be
# the removed version (ctx is the realizes dict) and the second pair the added version (ctx is a
# ScheduleContext and the dict lives at ctx.realizes).
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, **kwargs) -> None:
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)])
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None:
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.realizes.update([(b, to_store)])
# Decides whether the source `src` of a view must be realized (via realize(ctx, b, src)) or can be left alone.
# Returns None in every path; when realization is needed the side effect is the entry realize() adds for `b`.
# NOTE(review): rendered diff with lost +/- markers and lost indentation. The signature and the two
# can_pad(...) return lines each appear twice: the first of each pair looks like the removed line (ctx is the
# realizes dict), the second like the added line (ctx is a ScheduleContext, can_pad receives ctx.realizes).
def realize_view(ctx:Dict[UOp, UOp], view:UOp, src:UOp, b:UOp, **kwargs) -> None:
def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
# nothing to decide when the source has no shape tracker
if src.st is None: return None
st = unwrap(view.st)
# fold simple pads
# applies only to a single masked view over an all-int source shape that is at least as large as the masked region
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx, set()) else realize(ctx, b, src)
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
# early realize before expand
# i.e. the view has more elements than its source
if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src)
# otherwise safety check pads
# unmasked views never force a realize; masked ones do unless can_pad accepts the source
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx, set())) else realize(ctx, b, src)
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
# Drops the realize entry for `b` and returns `to_cast` viewed through view's shape tracker.
# Returns None (no rewrite) unless xb has an ImageDType, both b and xb are currently in the realizes mapping,
# and uval(to_cast).op is not in GroupOp.Meta.
# NOTE(review): rendered diff with lost +/- markers and lost indentation; the first def/if/del triple appears
# to be the removed version (ctx is the realizes dict) and the second triple the added version (ctx is a
# ScheduleContext and the dict lives at ctx.realizes). The final return line is shared by both versions.
def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
del ctx[b]
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(to_cast).op in GroupOp.Meta: return None
del ctx.realizes[b]
return to_cast.view(unwrap(view.st))
def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
@@ -436,7 +436,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx.realizes)
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx)
# create the scheduler context
graph_rewrite(big_graph, create_ctx, ctx)
# group realizes into kernels