mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
prereq for scheduler contiguous_child [pr] (#8163)
* the whole context is fine here [pr] * fix that
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
@@ -1938,7 +1938,7 @@ class TestView(unittest.TestCase):
|
||||
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, realizes)
|
||||
def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, ScheduleContext(realizes=realizes))
|
||||
class TestBigGraph(unittest.TestCase):
|
||||
def test_sink_childless_const(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
|
||||
@@ -358,23 +358,23 @@ ops_folding = PatternMatcher([
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, **kwargs) -> None:
|
||||
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)])
|
||||
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None:
|
||||
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.realizes.update([(b, to_store)])
|
||||
|
||||
def realize_view(ctx:Dict[UOp, UOp], view:UOp, src:UOp, b:UOp, **kwargs) -> None:
|
||||
def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
|
||||
if src.st is None: return None
|
||||
st = unwrap(view.st)
|
||||
# fold simple pads
|
||||
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
|
||||
return None if can_pad(src, ctx, set()) else realize(ctx, b, src)
|
||||
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
|
||||
# early realize before expand
|
||||
if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src)
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx, set())) else realize(ctx, b, src)
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
|
||||
|
||||
def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
|
||||
if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
|
||||
del ctx[b]
|
||||
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
|
||||
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(to_cast).op in GroupOp.Meta: return None
|
||||
del ctx.realizes[b]
|
||||
return to_cast.view(unwrap(view.st))
|
||||
|
||||
def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
|
||||
@@ -436,7 +436,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
buffers: Dict[UOp, Buffer] = {}
|
||||
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
|
||||
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx.realizes)
|
||||
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx)
|
||||
# create the scheduler context
|
||||
graph_rewrite(big_graph, create_ctx, ctx)
|
||||
# group realizes into kernels
|
||||
|
||||
Reference in New Issue
Block a user