mirror of https://github.com/tinygrad/tinygrad.git
start ripping out old scheduler -- no maps (#14909)
* start ripping out old scheduler -- no maps
* no more metadata
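A minimal sketch of the caller-facing change, using names from the test diff below (big_sink stands for any SINK UOp):

    # before: get_rangeify_map returned a tensor map the caller had to substitute back in
    tensor_map = get_rangeify_map(big_sink)
    sched_sink = big_sink.substitute(tensor_map)
    # after: get_rangeify rewrites the graph and returns the new sink directly
    sched_sink = get_rangeify(big_sink)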
@@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
 
 Group UOps into kernels.
 
-::: tinygrad.schedule.rangeify.get_rangeify_map
+::: tinygrad.schedule.rangeify.get_rangeify
     options:
      members: false
      show_labels: false
@@ -266,7 +266,7 @@ class TestCustomKernel(unittest.TestCase):
     The custom_addmul kernel should be at index 3.
     """
     from tinygrad.engine.schedule import create_schedule
-    from tinygrad.schedule.rangeify import get_rangeify_map
+    from tinygrad.schedule.rangeify import get_rangeify
 
     A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
     A2 = (A + 1).contiguous()  # kernel 0: depends on A
@@ -277,8 +277,7 @@ class TestCustomKernel(unittest.TestCase):
     result = (C + D + E).sum()  # kernel 3: custom_addmul, then kernel 4: sum
 
     big_sink = result.uop.sink()
-    tensor_map = get_rangeify_map(big_sink)
-    sched_sink = big_sink.substitute(tensor_map)
+    sched_sink = get_rangeify(big_sink)
     schedule, _ = create_schedule(sched_sink)
 
     # Find the custom_addmul kernel position
test/external/process_replay/process_replay.py (vendored)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs
+import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, functools, base64, codecs
 from dataclasses import replace
 from typing import Callable, Any
 
@@ -8,7 +8,6 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
 if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
 
 try:
-  from tinygrad.schedule.rangeify import get_rangeify_map
   from tinygrad.renderer import Renderer, ProgramSpec
   from tinygrad.engine.realize import get_program
   from tinygrad.uop.ops import UOp, Ops, KernelInfo
@@ -43,14 +42,6 @@ class ProcessReplayWarning(Warning): pass
 
 # *** replay the function and convert return values to string
 
-def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
-  UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
-  new_sink = big_sink.substitute(get_rangeify_map(big_sink))
-  def to_str(ret:UOp) -> str:
-    asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
-    return "\n".join([f"{len(asts)} kernels", *asts])
-  return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
-
 def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
   # the ast.arg is non None if we are inside of search.py
   sink_arg = ast.arg or KernelInfo()
@@ -68,8 +59,6 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]
 
 replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {}
 replayers["get_program"] = replay_get_program
-# disable this for speed, does it ever find things?
-#replayers["get_rangeify_map"] = replay_get_rangeify_map
 
 # *** run replayers on captured rows and print diffs
 
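For context, every replayer in this dict follows the same contract: recompute the function's output from the pickled arguments and return (recomputed, captured, args) so the harness can diff the two strings with difflib. A hedged sketch with a hypothetical no-op replayer (replay_echo is not in the codebase):

    from typing import Any

    def replay_echo(s:str) -> tuple[str, str, tuple[Any, ...]]:
      # hypothetical: "recompute" is the identity, so HEAD never diffs from master
      return s, s, (s,)
    replayers["echo"] = replay_echo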
@@ -1,202 +0,0 @@
-import unittest
-from tinygrad import dtypes
-from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
-from tinygrad.uop.symbolic import symbolic
-
-class TestRewriteMap(unittest.TestCase):
-  def test_substitute(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    e = UOp.variable('e', 0, 10)
-    ret = (a+b)*c
-    sub = {a+b: e}
-    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
-    self.assertIs(sub_map[a+b], e)
-    self.assertIs(sub_map[(a+b)*c], e*c)
-
-  def test_substitute_depth_2(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    d = UOp.variable('d', 0, 10)
-    e = UOp.variable('e', 0, 10)
-    f = UOp.variable('f', 0, 10)
-    ret = (a+b)*c+d
-    sub = {a+b: e, (a+b)*c: f}
-    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
-    self.assertIs(sub_map[a+b], e)
-    self.assertIs(sub_map[(a+b)*c], f)
-
-  def test_multistage_substitute(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    d = UOp.variable('d', 0, 10)
-    sub1 = {a+b:c}
-    start = (a+b)*c
-    # stage 1: (a+b)*c -> c*c
-    sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
-    self.assertIs(sub_map1[(a+b)*c], c*c)
-    # stage 2: c*c -> d
-    sub2 = {c*c:d}
-    sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
-    # (a+b)*c -> c*c -> d
-    self.assertIs(sub_map2[(a+b)*c], d)
-
-  def test_add_zero(self):
-    # Build a small graph: add(0, add(const=0, const=5))
-    zero_node = UOp.const(dtypes.index, 0)
-    five_node = UOp.const(dtypes.index, 5)
-    inner_add = zero_node + five_node
-    root_add = zero_node + inner_add
-
-    # Perform top-down rewrite
-    node_map = graph_rewrite_map(root_add, symbolic)
-
-    # We expect that add(0, add(0, 5)) -> add(0, 5) -> 5
-    # Check the mapping
-    assert node_map[root_add] == five_node
-    assert node_map[inner_add] == five_node
-    # zero_node and five_node map to themselves
-    assert node_map[zero_node] == zero_node
-    assert node_map[five_node] == five_node
-
-  def test_double_neg(self):
-    """
-    Test rewriting neg(neg(5)) => 5 using symbolic.
-    """
-    # In some versions of TinyGrad, you might do: (-(-five_node))
-    five_node = UOp.const(dtypes.index, 5)
-    # If your code allows UOp(...), do that; else you might do something like:
-    # double_neg_five = -(-five_node)
-    # But let's be explicit:
-    neg_five = -five_node
-    double_neg_five = -neg_five
-
-    node_map = graph_rewrite_map(double_neg_five, symbolic)
-
-    # node_map should map double_neg_five -> five_node
-    self.assertEqual(node_map[double_neg_five], five_node)
-    # five_node maps to itself
-    self.assertEqual(node_map[five_node], five_node)
-
-  def test_add_zero_and_double_neg(self):
-    """
-    Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
-    """
-    zero_node = UOp.const(dtypes.index, 0)
-    five_node = UOp.const(dtypes.index, 5)
-    neg_five = -five_node
-    double_neg_five = -neg_five
-    root_add = zero_node + double_neg_five
-
-    node_map = graph_rewrite_map(root_add, symbolic)
-
-    # node_map: root_add -> five_node, double_neg_five -> five_node
-    self.assertEqual(node_map[root_add], five_node)
-    self.assertEqual(node_map[double_neg_five], five_node)
-    # zero_node, five_node map to themselves
-    self.assertEqual(node_map[zero_node], zero_node)
-    self.assertEqual(node_map[five_node], five_node)
-
-  def test_multi_var_rewrites(self):
-    x_var = UOp.variable('x', 0, 10)
-    y_var = UOp.variable('y', -5, 5)
-    zero_node = UOp.const(dtypes.index, 0)
-
-    sum_with_zero = y_var + zero_node  # (y + 0)
-    combined = x_var + sum_with_zero  # x + (y + 0)
-    double_neg = -(-combined)  # neg(neg(x + y))
-    final_expr = zero_node + double_neg  # 0 + (x + y)
-
-    node_map = graph_rewrite_map(final_expr, symbolic)
-
-    # The final root should be (x_var + y_var).
-    expected = x_var + y_var
-
-    # Each sub-expression has its own "final" result.
-    # (y + 0) -> y_var
-    self.assertEqual(node_map[sum_with_zero], y_var)
-    # (x + (y+0)) -> (x + y)
-    self.assertEqual(node_map[combined], expected)
-    # neg(neg(x+y)) -> (x + y)
-    self.assertEqual(node_map[double_neg], expected)
-    # 0 + (x+y) -> (x + y)
-    self.assertEqual(node_map[final_expr], expected)
-
-    # x_var, y_var, zero_node remain unchanged
-    self.assertEqual(node_map[x_var], x_var)
-    self.assertEqual(node_map[y_var], y_var)
-    self.assertEqual(node_map[zero_node], zero_node)
-
-  def test_complex_multi_var_edges(self):
-    """
-    Build a multi-variable expression with multiple intermediates:
-
-    x_var = UOp.variable('x', 1, 10)
-    y_var = UOp.variable('y', -5, 5)
-    z_var = UOp.variable('z', 0, 5)
-    zero_node = UOp.const(dtypes.int, 0)
-    one_node = UOp.const(dtypes.int, 1)
-
-    yz_sum = y_var + z_var
-    yz_sum_zero = yz_sum + zero_node -> rewrites to yz_sum
-    yz_neg = -yz_sum_zero -> -(y+z)
-    yz_dneg = -yz_neg -> y+z (double neg gone)
-    x_plus_yz = x_var + yz_dneg -> x + (y+z)
-    double_neg_x = -(-x_plus_yz) -> x + (y+z)
-    final_expr = double_neg_x * one_node -> x + (y+z)
-
-    We expect the final result to be (x + (y+z)).
-    Each original node should map to the final node that replaces it,
-    which might be structurally equivalent but not the same reference.
-    """
-    x_var = UOp.variable('x', 1, 10)
-    y_var = UOp.variable('y', -5, 5)
-    z_var = UOp.variable('z', 0, 5)
-    zero_node = UOp.const(dtypes.index, 0)
-    one_node = UOp.const(dtypes.index, 1)
-
-    # Build sub-expressions
-    yz_sum = y_var + z_var  # (y + z)
-    yz_sum_zero = yz_sum + zero_node  # (y + z) + 0
-    yz_neg = -yz_sum_zero  # -(y+z)
-    yz_dneg = -yz_neg  # -(-(y+z)) -> (y+z)
-    x_plus_yz = x_var + yz_dneg  # x + (y+z)
-    double_neg_x = -(-x_plus_yz)  # neg(neg(x+(y+z))) -> x+(y+z)
-    final_expr = double_neg_x * one_node  # (x+(y+z)) * 1 -> x+(y+z)
-
-    node_map = graph_rewrite_map(final_expr, symbolic)
-
-    # (y + z) is unchanged
-    self.assertEqual(node_map[yz_sum], yz_sum)
-
-    # (y+z) + 0 => (y+z)
-    self.assertEqual(node_map[yz_sum_zero], yz_sum)
-
-    # -(y+z) remains -(y+z), but might be a new UOp with updated children
-    # Compare structurally to -(y_var + z_var).
-    self.assertEqual(node_map[yz_neg], -yz_sum)
-
-    # -(-(y+z)) => (y+z)
-    self.assertEqual(node_map[yz_dneg], yz_sum)
-
-    # x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
-    self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
-
-    # -(-(x+(y+z))) => x + (y+z)
-    self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
-
-    # (x+(y+z)) * 1 => x+(y+z)
-    self.assertEqual(node_map[final_expr], x_var + yz_sum)
-
-    # Unchanged atomic nodes map to themselves
-    self.assertEqual(node_map[x_var], x_var)
-    self.assertEqual(node_map[y_var], y_var)
-    self.assertEqual(node_map[z_var], z_var)
-    self.assertEqual(node_map[zero_node], zero_node)
-    self.assertEqual(node_map[one_node], one_node)
-
-if __name__ == "__main__":
-  unittest.main()
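The hunk above deletes the whole graph_rewrite_map test file: with the per-node result map gone, its assertions have nothing to check. The closest surviving check is against graph_rewrite itself, which returns only the rewritten root; a minimal sketch (TestRewriteSmoke is hypothetical, assuming symbolic still folds constants):

    import unittest
    from tinygrad import dtypes
    from tinygrad.uop.ops import UOp, graph_rewrite
    from tinygrad.uop.symbolic import symbolic

    class TestRewriteSmoke(unittest.TestCase):
      def test_add_zero(self):
        zero, five = UOp.const(dtypes.index, 0), UOp.const(dtypes.index, 5)
        # no node_map anymore: graph_rewrite hands back only the rewritten sink
        self.assertEqual(graph_rewrite(zero + five, symbolic), five)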
@@ -1,7 +1,7 @@
 import time
 from typing import cast
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
@@ -63,8 +63,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
   return pre_schedule, UOp.sink(*buf_uops_list)
 
 from tinygrad.engine.memory import memory_planner
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.schedule.multi import get_multi_map
+from tinygrad.schedule.rangeify import get_rangeify
+from tinygrad.schedule.multi import multi_pm
 
 def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
   if (ret:=ctx[0].get(b, None)) is None:
@@ -128,20 +128,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   # verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
   if SPEC: type_verify(big_sink, tensor_spec)
 
-  # hack to preserve metadata
-  graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}, [0], [0]), name="preserve metadata")
-
-  # tensor map is what we return
-  tensor_map: dict[UOp, UOp] = {}
-
   if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
-    tensor_map |= get_multi_map(big_sink_cache)
-    big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
+    big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
     big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
 
-  tensor_map |= get_rangeify_map(big_sink_cache)
-  big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
-
+  big_sink = get_rangeify(big_sink_cache)
   pre_schedule, buf_uops_sink = create_schedule(big_sink)
   if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
   else:
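Condensed, the new map-free path through complete_create_schedule_with_vars reads as follows (a sketch eliding the spec check, MULTI flattening, and schedule cache shown above):

    if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
      big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")  # multi-device lowering, applied in place
    big_sink = get_rangeify(big_sink_cache)  # rewritten sink, no tensor_map round-trip
    pre_schedule, buf_uops_sink = create_schedule(big_sink)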
@@ -1,6 +1,6 @@
 import functools, itertools
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes
 
 # *** allreduce implementation ***
@@ -187,9 +187,3 @@ multi_pm = PatternMatcher([
   (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
    lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
 ])+replace_allreduce
-
-def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
-  if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
-  ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
-  if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
-  return ret
@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
 from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
 from tinygrad.uop.symbolic import symbolic
-from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
+from tinygrad.helpers import argsort, prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
 from tinygrad.helpers import PCONTIG, partition, get_single_element
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
@@ -503,7 +503,7 @@ pm_add_range_tags = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
 ])
 
-def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
+def split_store(x:UOp) -> UOp|None:
   # if we have any open ranges here, we don't split
   if x.ranges: return None
 
@@ -511,9 +511,6 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   lctx = LocalAddBufferContext()
   ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
 
-  # gather the metadata
-  metadatas = [ctx[y].metadata for y in lctx.parent_tags]
-
   # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
   if ret.op is Ops.STORE: stored = ret.src[1]
   elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
@@ -521,8 +518,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
   else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
 
-  metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
-  kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
+  kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
   if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
   return kernel
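These two hunks are the "no more metadata" half of the commit message: split_store no longer needs the ctx of numbered UOps to collect Metadata from tagged parents, and kernels are built as bare CALLs, per the hunk above:

    kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())  # no metadata kwarg anymore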
@@ -561,7 +557,7 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
 
-def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
+def get_rangeify(sink:UOp) -> UOp:
   if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
   uop_list: list[UOp] = []
   tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
@@ -586,7 +582,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
   tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
                         name="bufferize to store")
-  tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
+  tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, bottom_up=True, name="split kernels")
 
   # WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
   afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
@@ -603,15 +599,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
 
   # TODO: we can probably get this earlier
-  sink_tags = [s.tag for s in tsink.src]
   tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
-
   if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
-
-  becomes_map: dict[UOp, UOp] = {}
-  for tag, s in zip(sink_tags, tsink.src):
-    assert tag is not None
-    for a in tag:
-      if a is None: continue
-      becomes_map[uop_list[int(a)]] = s
-  return becomes_map
+  return tsink
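Since get_rangeify now returns the rewritten sink instead of a becomes_map, callers pull kernels straight out of the graph. A sketch of the consuming side (the CALL-collecting idiom mirrors the deleted to_str helper in process_replay above):

    sched_sink = get_rangeify(big_sink)  # rewritten SINK UOp
    kernels = [u for u in sched_sink.toposort() if u.op is Ops.CALL]  # one CALL per split kernel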
@@ -1311,18 +1311,6 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
   rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
   return rewrite_ctx.unified_rewrite(sink)
 
-@profile_matches
-def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
-                      input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
-  rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
-  new_map: dict[UOp, UOp] = {}
-  for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
-    new_map[k] = v = rewrite_ctx.unified_rewrite(k)
-    if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
-  if input_map is not None:
-    for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
-  return new_map
-
 def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
 
 def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
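With graph_rewrite_map deleted, only graph_rewrite remains, and it returns just the rewritten root: the per-node result map, the all_metadata propagation, and input_map chaining all go with it. A fixed set of replacements can still be applied through the _substitute pattern, as the fix_assign call in rangeify.py above does; a minimal sketch, where sink, old, and new are stand-in UOps:

    from tinygrad.uop.ops import graph_rewrite, _substitute
    # rewrite every occurrence of `old` in sink's graph to `new`
    new_sink = graph_rewrite(sink, _substitute, ctx={old: new}, bottom_up=True)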