mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 23:48:01 -05:00
This reverts commit 2594e4db15.
@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
-from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, finalize_kernels, merge_views
+from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -605,7 +605,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
     c = Tensor.arange(4).realize().lazydata
     kernel = UOp(Ops.KERNEL, src=(a, b, c), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
-    kernel = graph_rewrite(kernel, finalize_kernels)
+    kernel = graph_rewrite(kernel, create_ast)
     run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
     self.assertEqual(a.buffer.numpy(), [7])
     self.assertEqual(b.buffer.numpy(), [12])
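The test above exercises the rewrite machinery directly: it wraps an AST in a KERNEL UOp and runs graph_rewrite with a PatternMatcher (create_ast once this revert lands). As a rough orientation for that mechanism, here is a minimal standalone sketch of a bottom-up pattern-matcher rewrite in plain Python; Node, the rule list and the constant-folding rule are illustrative stand-ins, not tinygrad's UOp or PatternMatcher API.

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()
  arg: object = None

def rewrite(node: Node, rules) -> Node:
  # rewrite children first (bottom-up), then try every rule on the result
  node = Node(node.op, tuple(rewrite(s, rules) for s in node.src), node.arg)
  for matches, fn in rules:
    if matches(node) and (new := fn(node)) is not None:
      return rewrite(new, rules)
  return node

# toy rule: fold an ADD of two CONSTs into a single CONST
rules = [(lambda n: n.op == "ADD" and all(s.op == "CONST" for s in n.src),
          lambda n: Node("CONST", arg=sum(s.arg for s in n.src)))]

sink = Node("SINK", (Node("ADD", (Node("CONST", arg=3), Node("CONST", arg=4))),))
print(rewrite(sink, rules))   # Node(op='SINK', src=(Node(op='CONST', src=(), arg=7),), arg=None)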
@@ -262,6 +262,19 @@ insert_kernels = merge_views+PatternMatcher([
   (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x) if x in ctx.realizes else None),
 ])
 
+def append_to_kernel(ctx:KernelContext, x:UOp):
+  new_srcs: list[UOp] = []
+  metadata = dict.fromkeys(x.arg.metadata)
+  for s in x.src:
+    if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s)
+    else:
+      new_srcs.extend(s.src)
+      if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.metadata.get(s)): metadata[m] = None
+  if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
+
+# walk back the local graph until we reach a realized parent
+create_kernels = insert_kernels+PatternMatcher([(UPat(Ops.KERNEL, name="x"), append_to_kernel),])
+
 # **** swizzler
 
 def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
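append_to_kernel grows an existing KERNEL by walking back through its sources: anything realized or in DONT_PLACE_IN_KERNEL stays a source, everything else is replaced by its own sources, deduplicated, until only realized parents remain. A minimal standalone sketch of that absorption loop, using a made-up TNode type and op names rather than tinygrad's UOp:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TNode:
  op: str
  src: tuple = ()

DONT_PLACE = {"BUFFER", "CONST", "DEVICE"}

def absorb_once(kernel: TNode, realizes: set) -> TNode:
  new_srcs = []
  for s in kernel.src:
    if s.op in DONT_PLACE or s in realizes: new_srcs.append(s)
    else: new_srcs.extend(s.src)   # step past the unrealized op to its own sources
  # dict.fromkeys keeps insertion order while dropping duplicates, like dedup() in the diff
  return replace(kernel, src=tuple(dict.fromkeys(new_srcs)))

buf = TNode("BUFFER")
add = TNode("ADD", (buf, TNode("CONST")))
k = TNode("KERNEL", (add,))
while (nk := absorb_once(k, realizes=set())) != k: k = nk
print(k.src)   # (TNode(op='BUFFER', src=()), TNode(op='CONST', src=()))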
@@ -376,17 +389,7 @@ fix_kernel_ops = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
 ])
-
-def finalize(ctx:dict[UOp, Metadata|None], k:UOp) -> UOp|None:
-  # walk back the kernel sources until we reach a realized parent
-  new_srcs: list[UOp] = []
-  metadata = dict.fromkeys(k.arg.metadata)
-  for s in k.src:
-    if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
-    else:
-      new_srcs.extend(s.src)
-      if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.get(s)): metadata[m] = None
-  if (new_src:=tuple(dedup(new_srcs))) != k.src: return k.replace(src=new_src, arg=Kernel(k.arg.ast, tuple(metadata)))
+# once we're done fusing, create the local ast with load/stores
 def fix_kernel_ast(k:UOp) -> UOp|None:
   if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
   # replace assign sources with a view of the target buffer
   parents_rep: dict[UOp, UOp] = {}
@@ -402,7 +405,7 @@ def finalize(ctx:dict[UOp, Metadata|None], k:UOp) -> UOp|None:
   if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   return k.replace(arg=Kernel(ast, k.arg.metadata))
 
-finalize_kernels = PatternMatcher([(UPat(Ops.KERNEL, name="k"), finalize),])
+create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
 
 pm_fuse = PatternMatcher([
   # FUSE on CONTIGUOUS removes FUSE
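Both the removed finalize and the new append_to_kernel use dict.fromkeys as an insertion-ordered set for kernel metadata: the kernel's existing entries keep their position, and metadata pulled in from fused sources is appended exactly once. A small illustration with hypothetical metadata strings:

old_metadata = ("relu", "add")
merged = dict.fromkeys(old_metadata)         # the kernel's existing metadata, in order
for extra in ("add", "sum", "relu", "sum"):  # metadata pulled in from fused sources
  merged[extra] = None                       # duplicates collapse, insertion order is kept
print(tuple(merged))                         # ('relu', 'add', 'sum')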
@@ -488,10 +491,8 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
 
   # group into kernels
   realize_map = group_realizes(tensor_map[big_sink])
-  tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
-                                 bottom_up=True, input_map=tensor_map, name="insert_kernels")
-  tensor_map = graph_rewrite_map(tensor_map[big_sink], finalize_kernels, ctx={v:k.metadata for k,v in tensor_map.items()}, bottom_up=True,
-                                 input_map=tensor_map, name="finalize_kernels")
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
+                                 bottom_up=True, input_map=tensor_map, name="create_kernels")
 
   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
   kernel_assign: dict[UOp, UOp] = {}
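The comment that closes this hunk describes an ordering constraint rather than a rewrite: if a kernel reads a buffer and a later kernel assigns to that buffer, the assign has to be sequenced after the reader, otherwise the reader would see the overwritten contents. A tiny standalone illustration of that constraint with hypothetical kernel and buffer names (in the diff this bookkeeping lives in kernel_assign and the fix_assign substitution):

reads   = {"K": {"B"}}   # kernel -> buffers it loads from
assigns = {"A": "B"}     # kernel -> buffer it assigns to in place
# every assign must wait for the kernels that still read the old contents
extra_deps = {a: {k for k, bufs in reads.items() if buf in bufs and k != a}
              for a, buf in assigns.items()}
print(extra_deps)        # {'A': {'K'}}: the assign now depends on the reader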
@@ -507,6 +508,9 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   if assign_rep:
     tensor_map = graph_rewrite_map(tensor_map[big_sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
 
+  # finally, create the AST for kernels
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
+
   # display the final graph
   sched_sink = tensor_map[big_sink]
   if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
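Every pass above is threaded through input_map=tensor_map and the result is assigned back to tensor_map, which suggests the per-pass old-node to new-node maps are composed so the original sink keys always resolve to the newest graph. A minimal sketch of that composition idea with plain dicts (illustrative only; this is not tinygrad's graph_rewrite_map signature):

def compose(input_map: dict, this_pass: dict) -> dict:
  # keep the original keys, but point them at whatever this pass rewrote their values into
  return {orig: this_pass.get(cur, cur) for orig, cur in input_map.items()}

pass1 = {"x": "y", "w": "w"}   # e.g. a create_kernels-like pass rewrote x into y
pass2 = {"y": "z"}             # e.g. a create_ast-like pass rewrote y into z
print(compose(pass1, pass2))   # {'x': 'z', 'w': 'w'}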