diff --git a/test/backend/test_symbolic_jit.py b/test/backend/test_symbolic_jit.py index 85c11ac5b9..0b8e5a716b 100644 --- a/test/backend/test_symbolic_jit.py +++ b/test/backend/test_symbolic_jit.py @@ -80,7 +80,7 @@ class TestSymbolicJit(unittest.TestCase): symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy() expected = f(q, k[:, :i], v[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 4) + assert_jit_cache_len(jf, 5) def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() diff --git a/test/helpers.py b/test/helpers.py index 9ec5bbe008..deda500a1b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -39,7 +39,7 @@ def assert_jit_cache_len(fxn, expected_len): assert len(fxn.jit_cache) == 1, len(fxn.jit_cache) # until we have a better way of typing the prg in ExecItem assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') - assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len + assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}" def rand_for_dtype(dt:DType, size:int, allow_subnormal=True): if dtypes.is_unsigned(dt): diff --git a/test/null/test_real_world.py b/test/null/test_real_world.py index b63789cf9f..9cfbfdb1de 100644 --- a/test/null/test_real_world.py +++ b/test/null/test_real_world.py @@ -98,7 +98,7 @@ class TestRealWorld(unittest.TestCase): @TinyJit def test(t, v): with Context(JIT=0): return model(t, v).realize() - helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 160, all_jitted=True) + helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 168, all_jitted=True) @slow def test_train_mnist(self): diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index b425f4bde6..4c73e323e3 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -107,7 +107,8 @@ 
def append_after(ctx:AllocCtx, x:UOp): def replace_input_buffer(ctx:AllocCtx, b:UOp): ctx.replacements.append(b) - return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b._device, b._min_max if b.op is Ops.BIND else None) + return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b._device, + b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None) pm_finalize_call = PatternMatcher([ (UPat(Ops.ASSIGN, name="x"), untag_and_append), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e24264ffe6..1f19e99eae 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,10 +1,10 @@ import time, inspect from typing import cast from collections import deque -from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink +from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, unwrap +from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR from tinygrad.engine.realize import ExecItem from tinygrad.engine.allocations import allocate_global_buffers @@ -63,40 +63,18 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: return pre_schedule, UOp.sink(*buf_uops_list) from tinygrad.engine.memory import memory_planner -from tinygrad.schedule.rangeify import get_rangeify, resolve_call +from tinygrad.schedule.rangeify import get_rangeify from tinygrad.schedule.multi import multi_pm +from tinygrad.uop.ops import PatternMatcher, UPat -def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp): - if (ret:=ctx[0].get(b, None)) is None: - # replace BUFFER with PARAM for cache key normalization 
(same as CALL) - ctx[0][b] = ret = UOp.param(ctx[2][0], b.dtype, b.shape, b.device) - ctx[2][0] += 1 - return ret - -def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp): - var, val = b.src[0], b.src[1].arg - assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}" - ctx[1][var.expr] = val - return ctx[0].setdefault(b, b.replace(src=(b.src[0],))) - -pm_pre_sched_cache = PatternMatcher([ - # replace BUFFER with PARAM for cache key normalization - (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer), - # strip value from BIND for cache key normalization, so different values hit same cache - (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind), -]) - -def create_new_buffer(ctx:dict[UOp, UOp], b:UOp): - if (ret:=ctx.get(b, None)) is None: ctx[b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype) +def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp): + if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype) return ret pm_post_sched_cache = PatternMatcher([ + (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]), # create new BUFFERs for LUNIQUE BUFFERs from rangeify (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer), - # restore PARAM back to original BUFFER - (UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)), - # restore BIND value stripped in pm_pre_sched_cache - (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)), ]) schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {} @@ -104,20 +82,20 @@ schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {} def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]: # big_sink srcs are all the Tensors st = time.perf_counter() - big_sink, 
buffer_map = allocate_global_buffers(big_sink) - # HACK: apply the call for now - big_sink = unwrap(resolve_call(big_sink)) - - # replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals - input_buffers: dict[UOp, UOp] = {} + # get var_vals var_vals: dict[str, int] = {} - big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals, [0], [0]), name="rewrite for sched cache") - sched_cache_key = big_sink_cache.key + for i,b in enumerate(big_sink.src[1:]): + if b.op is Ops.BIND: + nm = b.src[0].expr + val = b.src[1].arg + assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}" + var_vals[nm] = val + big_sink_cache = big_sink.src[0] + sched_cache_key = big_sink_cache.key if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: - # verify Tensors match the spec (on big_sink, we only need to do this if cache misses) if SPEC: type_verify(big_sink, tensor_spec) big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True) pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache)) @@ -125,11 +103,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li else: # schedule cache hit pre_schedule, buf_uops_sink = sc_ret - del big_sink, big_sink_cache - - # replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything) - input_buffers_inverse = {v:k for k,v in input_buffers.items()} - buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined") + # it's a call that we late apply + buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="apply buffers") # add bufs to pre_schedule schedule: list[ExecItem] = [] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index bdddd32306..8275ca6ce3 100644 --- 
a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -75,7 +75,7 @@ mop_cleanup = PatternMatcher([ ]) pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ]) -def resolve_call(c:UOp) -> UOp|None: +def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None: # don't resolve real kernel calls, sink or program if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None if c.src[0].op is Ops.PROGRAM: return None @@ -84,8 +84,9 @@ def resolve_call(c:UOp) -> UOp|None: params = sorted(params, key=lambda x: x.arg) args = c.src[1:] # TODO: this check belongs in spec, not here - if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}") - if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}") + if not allow_param_mismatch: + if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}") + if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}") for i, (p, a) in enumerate(zip(params, args)): if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}") if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}") @@ -410,6 +411,10 @@ to_define_global = PatternMatcher([ (UPat(Ops.STORE, name="x"), find_bufs), (UPat(Ops.BUFFER, name="buf"), debuf), (UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="buf"), debuf), + (UPat(Ops.PARAM, src=(UPat(), UPat(), UPat.cvar('vmin'), UPat.cvar('vmax'), UPat.var("nm")), name="v"), + lambda v, vmin, vmax, nm: UOp.variable(nm.arg, vmin.arg, vmax.arg, v.dtype)), + (UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v), + (UPat(Ops.BIND, name="b"), unbind_kernel), (UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after), diff --git 
a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8f32d43084..ff84d51dc1 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -859,10 +859,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # TODO: this should replace placeholder @staticmethod - def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None): + def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None): src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \ (UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),) if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1])) + if name is not None: src += (UOp(Ops.NOOP, arg=name),) return UOp(Ops.PARAM, dtype, src, arg=slot) def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp: