mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
allocate generates a call (#14958)
* allocate generates a call * symbolic works too * DEFINE_VAR is param * replace param later * apply buffers * name * upd * this was a bug...
This commit is contained in:
@@ -80,7 +80,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
|
||||
expected = f(q, k[:, :i], v[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 4)
|
||||
assert_jit_cache_len(jf, 5)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
|
||||
@@ -39,7 +39,7 @@ def assert_jit_cache_len(fxn, expected_len):
|
||||
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache[0].prg.jit_cache)}"
|
||||
|
||||
def rand_for_dtype(dt:DType, size:int, allow_subnormal=True):
|
||||
if dtypes.is_unsigned(dt):
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
@TinyJit
|
||||
def test(t, v):
|
||||
with Context(JIT=0): return model(t, v).realize()
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 160, all_jitted=True)
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23, 168, all_jitted=True)
|
||||
|
||||
@slow
|
||||
def test_train_mnist(self):
|
||||
|
||||
@@ -107,7 +107,8 @@ def append_after(ctx:AllocCtx, x:UOp):
|
||||
|
||||
def replace_input_buffer(ctx:AllocCtx, b:UOp):
  """Record `b` as a call argument and return the PARAM that stands in for it.

  `b` is appended to `ctx.replacements`; the returned PARAM's slot is the index `b` received in that list.
  When `b` is a BIND, its `_min_max` and `b.src[0].arg[0]` are forwarded to `UOp.param` as the bounds and
  the name; for any other op both are passed as None.
  """
  slot = len(ctx.replacements)  # index b is about to occupy
  ctx.replacements.append(b)
  is_bind = b.op is Ops.BIND
  bounds = b._min_max if is_bind else None
  var_name = b.src[0].arg[0] if is_bind else None
  return UOp.param(slot, b.dtype, b.shape, b._device, bounds, var_name)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import time, inspect
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, unwrap
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.allocations import allocate_global_buffers
|
||||
|
||||
@@ -63,40 +63,18 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify, resolve_call
|
||||
from tinygrad.schedule.rangeify import get_rangeify
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||
|
||||
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
|
||||
if (ret:=ctx[0].get(b, None)) is None:
|
||||
# replace BUFFER with PARAM for cache key normalization (same as CALL)
|
||||
ctx[0][b] = ret = UOp.param(ctx[2][0], b.dtype, b.shape, b.device)
|
||||
ctx[2][0] += 1
|
||||
return ret
|
||||
|
||||
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
|
||||
var, val = b.src[0], b.src[1].arg
|
||||
assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
|
||||
ctx[1][var.expr] = val
|
||||
return ctx[0].setdefault(b, b.replace(src=(b.src[0],)))
|
||||
|
||||
pm_pre_sched_cache = PatternMatcher([
|
||||
# replace BUFFER with PARAM for cache key normalization
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
|
||||
])
|
||||
|
||||
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
  """Return the buffer created for `b`, creating and memoizing it on first sight.

  `ctx[0]` is the memo dict keyed by the matched BUFFER UOp; `ctx[1]` is not used here.
  """
  memo = ctx[0]
  cached = memo.get(b, None)
  if cached is not None: return cached
  # first time this UOp is seen: make a fresh buffer with the same device, arg and dtype
  memo[b] = fresh = UOp.new_buffer(b.device, b.arg, b.dtype)
  return fresh
|
||||
|
||||
pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||
# restore PARAM back to original BUFFER
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
# restore BIND value stripped in pm_pre_sched_cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
])
|
||||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
@@ -104,20 +82,20 @@ schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
big_sink, buffer_map = allocate_global_buffers(big_sink)
|
||||
|
||||
# HACK: apply the call for now
|
||||
big_sink = unwrap(resolve_call(big_sink))
|
||||
|
||||
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
|
||||
input_buffers: dict[UOp, UOp] = {}
|
||||
# get var_vals
|
||||
var_vals: dict[str, int] = {}
|
||||
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals, [0], [0]), name="rewrite for sched cache")
|
||||
sched_cache_key = big_sink_cache.key
|
||||
for i,b in enumerate(big_sink.src[1:]):
|
||||
if b.op is Ops.BIND:
|
||||
nm = b.src[0].expr
|
||||
val = b.src[1].arg
|
||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||
var_vals[nm] = val
|
||||
|
||||
big_sink_cache = big_sink.src[0]
|
||||
sched_cache_key = big_sink_cache.key
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
|
||||
@@ -125,11 +103,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
else:
|
||||
# schedule cache hit
|
||||
pre_schedule, buf_uops_sink = sc_ret
|
||||
del big_sink, big_sink_cache
|
||||
|
||||
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
|
||||
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
# it's a call that we late apply
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="apply buffers")
|
||||
|
||||
# add bufs to pre_schedule
|
||||
schedule: list[ExecItem] = []
|
||||
|
||||
@@ -75,7 +75,7 @@ mop_cleanup = PatternMatcher([
|
||||
])
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
@@ -84,8 +84,9 @@ def resolve_call(c:UOp) -> UOp|None:
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
|
||||
if not allow_param_mismatch:
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
|
||||
for i, (p, a) in enumerate(zip(params, args)):
|
||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
@@ -410,6 +411,10 @@ to_define_global = PatternMatcher([
|
||||
(UPat(Ops.STORE, name="x"), find_bufs),
|
||||
(UPat(Ops.BUFFER, name="buf"), debuf),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="buf"), debuf),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat.cvar('vmin'), UPat.cvar('vmax'), UPat.var("nm")), name="v"),
|
||||
lambda v, vmin, vmax, nm: UOp.variable(nm.arg, vmin.arg, vmax.arg, v.dtype)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
|
||||
|
||||
|
||||
@@ -859,10 +859,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
||||
# TODO: this should replace placeholder
|
||||
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None,
          name:str|None=None):
  """Build a PARAM UOp for argument slot `slot`.

  The sources are, in order:
    * the shape arg, or a NOOP when `shape` is None
    * a DEVICE UOp, or a NOOP when `device` is None
    * two CONSTs holding `vmin_vmax[0]` and `vmin_vmax[1]`, only when `vmin_vmax` is given
    * a NOOP whose arg is `name`, only when `name` is given

  `name` defaults to None (not the string "None"), so callers that omit it get no trailing name source.
  """
  src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
                         (UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
  if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))
  # only named params carry the extra NOOP source
  if name is not None: src += (UOp(Ops.NOOP, arg=name),)
  return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
|
||||
Reference in New Issue
Block a user