mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Tensor.callify will be the JIT (#14983)
* close
* simple callify, support linear in the scheduler
* all tests pass
* everyone is happy
* dumb test
* Remove unnecessary blank line in rangeify.py
This commit is contained in:
111
test/unit/test_callify.py
Normal file
111
test/unit/test_callify.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes
|
||||
|
||||
class TestCallify(unittest.TestCase):
  """Value checks for tensors after Tensor.callify has been applied to them."""

  def test_basic(self):
    # elementwise add of two float vectors
    lhs = Tensor([1.,2,3])
    rhs = Tensor([4.,5,6])
    total = lhs + rhs
    total.callify()
    self.assertListEqual(total.tolist(), [5.0, 7.0, 9.0])

  def test_const(self):
    # scalar + scalar
    result = Tensor(2.0) + Tensor(3.0)
    result.callify()
    self.assertEqual(result.item(), 5.0)

  def test_sum(self):
    # reduction over a contiguous vector of ones
    total = Tensor.ones(16).contiguous().sum()
    total.callify()
    self.assertEqual(total.item(), 16.0)

  def test_multi_output(self):
    # one callify covering two outputs that share inputs
    lhs = Tensor([1.,2,3])
    rhs = Tensor([4.,5,6])
    added = lhs + rhs
    multiplied = lhs * rhs
    added.callify(multiplied)
    self.assertListEqual(added.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(multiplied.tolist(), [4.0, 10.0, 18.0])

  def test_two_callify_independent(self):
    # two separate callify calls on unrelated graphs
    x0 = Tensor([1.,2,3])
    x1 = Tensor([4.,5,6])
    first = x0 + x1
    first.callify()

    y0 = Tensor([10.,20,30])
    y1 = Tensor([1.,1,1])
    second = y0 - y1
    second.callify()

    self.assertListEqual(first.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(second.tolist(), [9.0, 19.0, 29.0])

  def test_two_callify_shared_input(self):
    # two separate callify calls that both read the same realized input
    shared = Tensor([1.,2,3]).contiguous().realize()
    plus_one = shared + 1
    plus_one.callify()
    doubled = shared * 2
    doubled.callify()
    self.assertListEqual(plus_one.tolist(), [2.0, 3.0, 4.0])
    self.assertListEqual(doubled.tolist(), [2.0, 4.0, 6.0])

  def test_chained_callify(self):
    # the output of a realized callify feeds a second callify
    start = Tensor([1.,2,3])
    mid = start + 1
    mid.callify()
    mid.realize()
    end = mid + 1
    end.callify()
    self.assertListEqual(end.tolist(), [3.0, 4.0, 5.0])

  def test_gemm(self):
    # ones(8,8) @ eye(8) should be all ones
    N = 8
    ones = Tensor.ones(8, 8).contiguous()
    ident = Tensor.eye(8).contiguous()
    prod_t = ones @ ident
    prod_t.callify()
    rows = prod_t.tolist()
    # walk the matrix row-major, same order as a nested y/x loop
    for flat in range(N * N):
      self.assertEqual(rows[flat // N][flat % N], 1.0)

  def test_int_dtype(self):
    # same as test_basic but with an integer dtype
    lhs = Tensor([1,2,3], dtype=dtypes.int)
    rhs = Tensor([4,5,6], dtype=dtypes.int)
    total = lhs + rhs
    total.callify()
    self.assertListEqual(total.tolist(), [5, 7, 9])

  def test_reduce(self):
    # reduction straight off a list-backed tensor
    total = Tensor([1.,2,3,4]).sum()
    total.callify()
    self.assertEqual(total.item(), 10.0)

  def test_multiple_ops(self):
    # several elementwise ops combined in one expression
    lhs = Tensor([1.,2,3])
    rhs = Tensor([4.,5,6])
    combined = (lhs + rhs) * (lhs - rhs)
    combined.callify()
    self.assertListEqual(combined.tolist(), [-15.0, -21.0, -27.0])

  def test_double_callify(self):
    # calling callify twice on the same tensor must not change the result
    lhs = Tensor([1.,2,3])
    rhs = Tensor([4.,5,6])
    total = lhs + rhs
    for _ in range(2):
      total.callify()
    self.assertListEqual(total.tolist(), [5.0, 7.0, 9.0])

  def test_double_callify_multi_output(self):
    # calling the multi-output form twice must not change either result
    lhs = Tensor([1.,2,3])
    rhs = Tensor([4.,5,6])
    added = lhs + rhs
    multiplied = lhs * rhs
    for _ in range(2):
      added.callify(multiplied)
    self.assertListEqual(added.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(multiplied.tolist(), [4.0, 10.0, 18.0])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,6 +2,7 @@ import time, inspect
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
|
||||
from tinygrad.uop.ops import _remove_all_tags
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
||||
@@ -22,7 +23,7 @@ def create_schedule(sched_sink:UOp) -> UOp:
|
||||
for u in sched_sink.toposort(gate_kernel_sink):
|
||||
if u.op is not Ops.AFTER: continue
|
||||
k = u.src[1]
|
||||
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
assert k.op in {Ops.CALL, Ops.END, Ops.LINEAR}, f"AFTER src[1] should be CALL or END, not {k.op}"
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
# WAR deps from rangeify are stored in AFTER src[2:]
|
||||
@@ -49,10 +50,13 @@ def create_schedule(sched_sink:UOp) -> UOp:
|
||||
linearized: list[UOp] = []
|
||||
while len(queue):
|
||||
rk = queue.popleft()
|
||||
k = rk.src[0] if rk.op is Ops.END else rk
|
||||
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
|
||||
if rk.op is Ops.LINEAR:
|
||||
linearized.extend(rk.src)
|
||||
else:
|
||||
k = rk.src[0] if rk.op is Ops.END else rk
|
||||
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
@@ -88,7 +92,7 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
||||
return ret
|
||||
|
||||
pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg].rtag() if x.tag is None else None),
|
||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||
])
|
||||
@@ -100,6 +104,8 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
|
||||
if isinstance(function.arg, KernelInfo): return None
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
# support recursive CALLs
|
||||
function = graph_rewrite(function, pm_schedule, name="schedule to linear")
|
||||
linear = create_schedule(get_kernel_graph(function))
|
||||
if SCACHE: schedule_cache[function.key] = linear
|
||||
else:
|
||||
@@ -115,7 +121,8 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
|
||||
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||
return graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
return graph_rewrite(linear, _remove_all_tags, name="remove tags")
|
||||
|
||||
pm_schedule = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SINK),), allow_any_len=True, name="big_sink"), lower_schedule_to_linear),
|
||||
|
||||
@@ -181,7 +181,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op is Ops.CALL: continue
|
||||
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
||||
|
||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
# *** allreduce implementation ***
|
||||
@@ -163,6 +163,8 @@ def assign_multi(dest:UOp, src:UOp):
|
||||
def passthrough_multi(root:UOp, multi:UOp):
  """Rebuild `root` on unwrapped sources and wrap the result as a MULTI on `multi.axis`.

  The first source becomes `multi.src[0]`; each remaining source that is itself a MULTI is
  replaced by its `src[0]`, every other source is passed through unchanged.
  """
  unwrapped = [multi.src[0]]
  for s in root.src[1:]:
    unwrapped.append(s.src[0] if s.op is Ops.MULTI else s)
  return UOp(root.op, root.dtype, tuple(unwrapped), root.arg).multi(multi.axis)
|
||||
|
||||
def rewrite_into_call(call:UOp):
  """Run `multi_pm` over the body (src[0]) of a CALL, leaving the remaining sources untouched.

  Returns None (no rewrite) when `should_resolve_call` rejects the call.
  """
  if not should_resolve_call(call): return None
  rewritten_body = graph_rewrite(call.src[0], multi_pm)
  return call.replace(src=(rewritten_body,)+call.src[1:])
|
||||
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
|
||||
@@ -177,6 +179,8 @@ multi_pm = PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
# rewrite into calls explicitly for MULTI
|
||||
(UPat(Ops.CALL, name="call"), rewrite_into_call),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
|
||||
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
@@ -77,9 +77,7 @@ mop_cleanup = PatternMatcher([
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
if not should_resolve_call(c): return None
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
@@ -481,7 +479,7 @@ split_kernels = PatternMatcher([
|
||||
|
||||
@profile_matches
|
||||
def get_kernel_graph(sink:UOp) -> UOp:
|
||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
@@ -511,4 +509,4 @@ def get_kernel_graph(sink:UOp) -> UOp:
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
return tsink
|
||||
return tsink
|
||||
|
||||
@@ -250,6 +250,12 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
|
||||
def callify(self, *lst:Tensor) -> Tensor:
  """
  Sinks the uops of this Tensor and any Tensors in `lst`, passes that sink through `transform_to_call`,
  and applies the returned buffer map to tensors, with every replacement ordered after the transformed sink.

  Returns `self`.
  """
  sink = UOp.sink(*(t.uop for t in (self,)+lst))
  call_sink, buffer_map = transform_to_call(sink)
  replacements = {old:new.after(call_sink) for old,new in buffer_map.items()}
  _apply_map_to_tensors(replacements, name="callify")
  return self
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||
"""
|
||||
Creates the schedule needed to realize these Tensor(s), with Variables.
|
||||
|
||||
@@ -895,6 +895,13 @@ class CallInfo:
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
|
||||
def should_resolve_call(c:UOp) -> bool:
  """Return whether the CALL `c` should be resolved.

  False when the body (src[0]) is a SINK carrying a KernelInfo arg, a PROGRAM, or a COPY; True otherwise.
  """
  body = c.src[0]
  # don't resolve real kernel calls: a SINK is only resolvable when it has no KernelInfo
  if body.op is Ops.SINK: return not isinstance(body.arg, KernelInfo)
  # PROGRAM and COPY bodies are never resolved
  return not (body.op is Ops.PROGRAM or body.op is Ops.COPY)
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
def safe_exp2(x):
|
||||
@@ -1239,13 +1246,12 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
SENTINEL: Final[UOp] = cast(UOp, object())
|
||||
class BottomUpGate(Exception): pass
|
||||
class RewriteContext:
|
||||
def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
|
||||
def __init__(self, pm, bpm, ctx=None):
|
||||
self.pm: PatternMatcher|None = pm
|
||||
self.bpm: PatternMatcher|None = bpm
|
||||
self.bpm_cache: dict[UOp, UOp|None] = {}
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
self.rewrite_into_calls = rewrite_into_calls
|
||||
|
||||
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
|
||||
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
|
||||
@@ -1283,7 +1289,7 @@ class RewriteContext:
|
||||
# NOTE: CALL is handled as a special case.
|
||||
# The function that is called is not included in the graph_rewrite.
|
||||
# If you want to graph_rewrite a call, add an explicit Ops.CALL pattern that rewrites call.src[0].
|
||||
if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
if new_n.op is Ops.CALL: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
@@ -1322,8 +1328,8 @@ class RewriteContext:
|
||||
return self.replace[root]
|
||||
|
||||
@profile_matches
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
return rewrite_ctx.unified_rewrite(sink)
|
||||
|
||||
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp:
  """Turn a `sint` into a UOp of `dtype`: a Python int becomes a const, anything else is cast."""
  if isinstance(x, int): return UOp.const(dtype, x)
  return x.cast(dtype)
|
||||
|
||||
@@ -258,7 +258,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# VECTORIZE/CONST
|
||||
|
||||
Reference in New Issue
Block a user