diff --git a/test/unit/test_callify.py b/test/unit/test_callify.py
new file mode 100644
index 0000000000..82b852ee5c
--- /dev/null
+++ b/test/unit/test_callify.py
@@ -0,0 +1,111 @@
+import unittest
+from tinygrad import Tensor, dtypes
+
+class TestCallify(unittest.TestCase):
+  def test_basic(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    out = a + b
+    out.callify()
+    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
+
+  def test_const(self):
+    out = Tensor(2.0) + Tensor(3.0)
+    out.callify()
+    self.assertEqual(out.item(), 5.0)
+
+  def test_sum(self):
+    out = Tensor.ones(16).contiguous().sum()
+    out.callify()
+    self.assertEqual(out.item(), 16.0)
+
+  def test_multi_output(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    c = a + b
+    d = a * b
+    c.callify(d)
+    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
+    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
+
+  def test_two_callify_independent(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    c = a + b
+    c.callify()
+
+    d = Tensor([10.,20,30])
+    e = Tensor([1.,1,1])
+    f = d - e
+    f.callify()
+
+    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
+    self.assertListEqual(f.tolist(), [9.0, 19.0, 29.0])
+
+  def test_two_callify_shared_input(self):
+    a = Tensor([1.,2,3]).contiguous().realize()
+    b = a + 1
+    b.callify()
+    c = a * 2
+    c.callify()
+    self.assertListEqual(b.tolist(), [2.0, 3.0, 4.0])
+    self.assertListEqual(c.tolist(), [2.0, 4.0, 6.0])
+
+  def test_chained_callify(self):
+    a = Tensor([1.,2,3])
+    b = a + 1
+    b.callify()
+    b.realize()
+    c = b + 1
+    c.callify()
+    self.assertListEqual(c.tolist(), [3.0, 4.0, 5.0])
+
+  def test_gemm(self):
+    a = Tensor.ones(8, 8).contiguous()
+    b = Tensor.eye(8).contiguous()
+    out = a @ b
+    out.callify()
+    lst = out.tolist()
+    for y in range(8):
+      for x in range(8):
+        self.assertEqual(lst[y][x], 1.0)
+
+  def test_int_dtype(self):
+    a = Tensor([1,2,3], dtype=dtypes.int)
+    b = Tensor([4,5,6], dtype=dtypes.int)
+    out = a + b
+    out.callify()
+    self.assertListEqual(out.tolist(), [5, 7, 9])
+
+  def test_reduce(self):
+    out = Tensor([1.,2,3,4]).sum()
+    out.callify()
+    self.assertEqual(out.item(), 10.0)
+
+  def test_multiple_ops(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    out = (a + b) * (a - b)
+    out.callify()
+    self.assertListEqual(out.tolist(), [-15.0, -21.0, -27.0])
+
+  def test_double_callify(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    out = a + b
+    out.callify()
+    out.callify()
+    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
+
+  def test_double_callify_multi_output(self):
+    a = Tensor([1.,2,3])
+    b = Tensor([4.,5,6])
+    c = a + b
+    d = a * b
+    c.callify(d)
+    c.callify(d)
+    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
+    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ff2f1bc7db..f6fcc29a76 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -2,6 +2,7 @@ import time, inspect
 from typing import cast
 from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
+from tinygrad.uop.ops import _remove_all_tags
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
@@ -22,7 +23,7 @@ def create_schedule(sched_sink:UOp) -> UOp:
   for u in sched_sink.toposort(gate_kernel_sink):
     if u.op is not Ops.AFTER: continue
     k = u.src[1]
-    assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
+    assert k.op in {Ops.CALL, Ops.END, Ops.LINEAR}, f"AFTER src[1] should be CALL, END or LINEAR, not {k.op}"
     in_degree.setdefault(k, 0)
     if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
     # WAR deps from rangeify are stored in AFTER src[2:]
@@ -49,10 +50,13 @@ def create_schedule(sched_sink:UOp) -> UOp:
   linearized: list[UOp] = []
   while len(queue):
     rk = queue.popleft()
-    k = rk.src[0] if rk.op is Ops.END else rk
-    assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
-    buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
-    linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
+    if rk.op is Ops.LINEAR:
+      linearized.extend(rk.src)
+    else:
+      k = rk.src[0] if rk.op is Ops.END else rk
+      assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
+      buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
+      linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
     for x in children.get(rk, []):
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
@@ -88,7 +92,7 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
   return ret
 
 pm_post_sched_cache = PatternMatcher([
-  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
+  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg].rtag() if x.tag is None else None),
   # create new BUFFERs for LUNIQUE BUFFERs from rangeify
   (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
 ])
@@ -100,6 +104,8 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
   if isinstance(function.arg, KernelInfo): return None
   if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
     if SPEC: type_verify(big_sink, tensor_spec)
+    # support recursive CALLs
+    function = graph_rewrite(function, pm_schedule, name="schedule to linear")
     linear = create_schedule(get_kernel_graph(function))
     if SCACHE: schedule_cache[function.key] = linear
   else:
@@ -115,7 +121,8 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
     print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
       f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
       f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
-  return graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
+  linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
+  return graph_rewrite(linear, _remove_all_tags, name="remove tags")
 
 pm_schedule = PatternMatcher([
   (UPat(Ops.CALL, src=(UPat(Ops.SINK),), allow_any_len=True, name="big_sink"), lower_schedule_to_linear),
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 60ee6882ce..76ed69d3db 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -181,7 +181,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
 
     if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
     # no ranges on kernels, they are internal
-    if x.op is Ops.CALL: continue
+    if x.op in {Ops.CALL, Ops.LINEAR}: continue
    if x.dtype.scalar() == dtypes.index: continue
     # TODO: why do I need this?
     ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 077ac47a15..487e562139 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -1,6 +1,6 @@
 import functools, itertools
 from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
 from tinygrad.dtype import dtypes
 
 # *** allreduce implementation ***
@@ -163,6 +163,8 @@ def assign_multi(dest:UOp, src:UOp):
 def passthrough_multi(root:UOp, multi:UOp):
   return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
 
+def rewrite_into_call(call:UOp): return call.replace(src=(graph_rewrite(call.src[0], multi_pm),)+call.src[1:]) if should_resolve_call(call) else None
+
 # NOTE: this is the same pattern as Ops.UNROLL
 multi_pm = PatternMatcher([
   (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
@@ -177,6 +179,8 @@ multi_pm = PatternMatcher([
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
+  # rewrite into calls explicitly for MULTI
+  (UPat(Ops.CALL, name="call"), rewrite_into_call),
   (UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
   # we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
   (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index c739116489..4661447558 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
-from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches
+from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
 from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -77,9 +77,7 @@ mop_cleanup = PatternMatcher([
 pm_gather_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)),])
 
 def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
-  # don't resolve real kernel calls, sink or program
-  if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
-  if c.src[0].op is Ops.PROGRAM: return None
+  if not should_resolve_call(c): return None
   params: list[UOp] = []
   graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
   params = sorted(params, key=lambda x: x.arg)
@@ -481,7 +479,7 @@ split_kernels = PatternMatcher([
 
 @profile_matches
 def get_kernel_graph(sink:UOp) -> UOp:
-  tsink = graph_rewrite(sink, multi_pm, name="multi_pm", rewrite_into_calls=True)
+  tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
   tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
 
   # convert movement ops to ranges
@@ -511,4 +509,4 @@ def get_kernel_graph(sink:UOp) -> UOp:
       assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
   if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
   if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
-  return tsink
\ No newline at end of file
+  return tsink
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index de4d8c4355..924769b655 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -250,6 +250,12 @@ class Tensor(OpMixin):
     """
     return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
 
+  def callify(self, *lst:Tensor) -> Tensor:
+    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+    big_sink, buffer_map = transform_to_call(big_sink)
+    _apply_map_to_tensors({x:y.after(big_sink) for x,y in buffer_map.items()}, name="callify")
+    return self
+
   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
     """
     Creates the schedule needed to realize these Tensor(s), with Variables.
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 1aed20480f..cb81df22fa 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -895,6 +895,13 @@ class CallInfo:
   def __reduce__(self): return (CallInfo, (None, self.metadata))
   def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
 
+def should_resolve_call(c:UOp) -> bool:
+  # don't resolve real kernel calls, sink or program
+  if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
+  if c.src[0].op is Ops.PROGRAM: return False
+  if c.src[0].op is Ops.COPY: return False
+  return True
+
 # ******** ops in python ********
 
 def safe_exp2(x):
@@ -1239,13 +1246,12 @@ if TRACK_MATCH_STATS or PROFILE:
 SENTINEL: Final[UOp] = cast(UOp, object())
 class BottomUpGate(Exception): pass
 class RewriteContext:
-  def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
+  def __init__(self, pm, bpm, ctx=None):
     self.pm: PatternMatcher|None = pm
     self.bpm: PatternMatcher|None = bpm
     self.bpm_cache: dict[UOp, UOp|None] = {}
     self.ctx = ctx
     self.replace: dict[UOp, UOp] = {}
-    self.rewrite_into_calls = rewrite_into_calls
 
   # no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
   def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
@@ -1283,7 +1289,7 @@ class RewriteContext:
         # NOTE: CALL is handled as a special case.
         # The function that is called is not included in the graph_rewrite.
         # If you want to graph_rewrite a call, you can
-        if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
+        if new_n.op is Ops.CALL: self.replace[new_n.src[0]] = new_n.src[0]
         for x in reversed(new_n.src):
           if x in on_stack: continue
           stack.append((x, 0, x))
@@ -1322,8 +1328,8 @@ class RewriteContext:
     return self.replace[root]
 
 @profile_matches
-def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
-  rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
+def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
+  rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
   return rewrite_ctx.unified_rewrite(sink)
 
 def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index 1e8f6bd31c..4b36226681 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -258,7 +258,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
   # only RANGE/IF/STORE/KERNEL have side effects
   (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
-    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
+    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR} else y.src for y in x.src[1:]])))),
   # after with 1 src is just src[0]
   (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
   # VECTORIZE/CONST