allocate generates a call (#14958)

* allocate generates a call

* symbolic works too

* DEFINE_VAR is param

* replace param later

* apply buffers

* name

* update

* this was a bug...
This commit is contained in:
George Hotz
2026-02-23 15:59:20 +08:00
committed by GitHub
parent dd8302a6d0
commit b824490e3f
7 changed files with 33 additions and 51 deletions

View File

@@ -80,7 +80,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
expected = f(q, k[:, :i], v[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 4)
assert_jit_cache_len(jf, 5)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()

View File

@@ -39,7 +39,7 @@ def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
# until we have a better way of typing the prg in ExecItem
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}"
def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
if dtypes.is_unsigned(dt):

View File

@@ -98,7 +98,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t, v):
with Context(JIT=0): return model(t, v).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 160, all_jitted=True)
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 168, all_jitted=True)
@slow
def test_train_mnist(self):

View File

@@ -107,7 +107,8 @@ def append_after(ctx:AllocCtx, x:UOp):
def replace_input_buffer(ctx:AllocCtx, b:UOp):
ctx.replacements.append(b)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b._device, b._min_max if b.op is Ops.BIND else None)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b._device,
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
pm_finalize_call = PatternMatcher([
(UPat(Ops.ASSIGN, name="x"), untag_and_append),

View File

@@ -1,10 +1,10 @@
import time, inspect
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, unwrap
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.allocations import allocate_global_buffers
@@ -63,40 +63,18 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
return pre_schedule, UOp.sink(*buf_uops_list)
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify, resolve_call
from tinygrad.schedule.rangeify import get_rangeify
from tinygrad.schedule.multi import multi_pm
from tinygrad.uop.ops import PatternMatcher, UPat
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
  """Swap an input BUFFER for a slot-numbered PARAM so the cache key is input-independent (same as CALL).

  ctx[0] memoizes BUFFER->PARAM, ctx[2][0] holds the next free slot number.
  """
  cached = ctx[0].get(b)
  if cached is None:
    # first time we see this buffer: mint a PARAM in the next slot, then bump the counter
    cached = UOp.param(ctx[2][0], b.dtype, b.shape, b.device)
    ctx[0][b] = cached
    ctx[2][0] += 1
  return cached
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
  """Record the BIND's bound value in ctx[1] and return the BIND with its CONST value removed."""
  var = b.src[0]
  val = b.src[1].arg
  if var.expr in ctx[1]:
    # the same variable must not be bound to two different values in one sink
    assert ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
  ctx[1][var.expr] = val
  # memoize the stripped BIND in ctx[0] so repeated matches return the identical UOp
  return ctx[0].setdefault(b, b.replace(src=(b.src[0],)))
# Rewrites applied to the sink BEFORE it is hashed as a schedule-cache key: external
# inputs (BUFFER) and bound values (BIND) are normalized away so graphs that differ
# only in their inputs/bindings share one cache entry.
pm_pre_sched_cache = PatternMatcher([
# replace BUFFER with PARAM for cache key normalization
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
])
def create_new_buffer(ctx:dict[UOp, UOp], b:UOp):
if (ret:=ctx.get(b, None)) is None: ctx[b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
  """Return the replacement BUFFER for b, allocating a fresh one on first sight (memoized in ctx[0])."""
  if b in ctx[0]:
    return ctx[0][b]
  new_buf = UOp.new_buffer(b.device, b.arg, b.dtype)
  ctx[0][b] = new_buf
  return new_buf
# Rewrites applied AFTER the (possibly cached) schedule is produced: PARAM placeholders
# are substituted back with real UOps and LUNIQUE BUFFERs get fresh allocations.
pm_post_sched_cache = PatternMatcher([
# look a PARAM up by slot in the caller-supplied tuple ctx[1]
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
# restore PARAM back to original BUFFER
# NOTE(review): if earlier patterns take precedence, the unconstrained PARAM rule above
# shadows this src-constrained one — confirm which rule set is the live one
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
# restore BIND value stripped in pm_pre_sched_cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
])
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@@ -104,20 +82,20 @@ schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
big_sink, buffer_map = allocate_global_buffers(big_sink)
# HACK: apply the call for now
big_sink = unwrap(resolve_call(big_sink))
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
input_buffers: dict[UOp, UOp] = {}
# get var_vals
var_vals: dict[str, int] = {}
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals, [0], [0]), name="rewrite for sched cache")
sched_cache_key = big_sink_cache.key
for i,b in enumerate(big_sink.src[1:]):
if b.op is Ops.BIND:
nm = b.src[0].expr
val = b.src[1].arg
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
var_vals[nm] = val
big_sink_cache = big_sink.src[0]
sched_cache_key = big_sink_cache.key
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
@@ -125,11 +103,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
else:
# schedule cache hit
pre_schedule, buf_uops_sink = sc_ret
del big_sink, big_sink_cache
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
# it's a call that we late apply
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="apply buffers")
# add bufs to pre_schedule
schedule: list[ExecItem] = []

View File

@@ -75,7 +75,7 @@ mop_cleanup = PatternMatcher([
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
def resolve_call(c:UOp) -> UOp|None:
def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
@@ -84,8 +84,9 @@ def resolve_call(c:UOp) -> UOp|None:
params = sorted(params, key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
if not allow_param_mismatch:
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
for i, (p, a) in enumerate(zip(params, args)):
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
@@ -410,6 +411,10 @@ to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="x"), find_bufs),
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="buf"), debuf),
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat.cvar('vmin'), UPat.cvar('vmax'), UPat.var("nm")), name="v"),
lambda v, vmin, vmax, nm: UOp.variable(nm.arg, vmin.arg, vmax.arg, v.dtype)),
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),

View File

@@ -859,10 +859,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None):
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name="None"):
# Build a PARAM UOp for argument slot `slot`; optional metadata is encoded positionally
# in src as (shape-or-NOOP, device-or-NOOP[, vmin-const, vmax-const][, NOOP(name)]).
# NOTE(review): the default for `name` is the *string* "None", not the None object, so the
# `if name is not None` guard below also appends a name NOOP for default calls — confirm intended.
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
(UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
# NOTE(review): vmin is a const of `dtype` but vmax of `dtype.scalar()` — verify the asymmetry is deliberate
if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))
if name is not None: src += (UOp(Ops.NOOP, arg=name),)
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp: