fix tests for rewrite [pr] (#10167)

* fix tests for rewrite [pr]

* cleaner

* delete linearize_uop

* clean up the rest
This commit is contained in:
George Hotz
2025-05-05 19:19:49 -07:00
committed by GitHub
parent 10437904cd
commit 603c03bef2
16 changed files with 56 additions and 171 deletions

View File

@@ -9,8 +9,7 @@ from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import T, unwrap, CI
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen import full_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
def derandomize_model(model):
@@ -59,8 +58,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
rw = full_graph_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(linearize_uop(rw))))
lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(lst)))
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
return out_buf.cast(uop.dtype.fmt).tolist()[0]

View File

@@ -3,7 +3,7 @@ from typing import Any
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.ops import Ops, UOp
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen import full_rewrite_to_sink
import numpy as np
from tinygrad.device import is_dtype_supported
from test.helpers import not_support_multi_device
@@ -105,7 +105,7 @@ class TestBitcastConstFolding(unittest.TestCase):
def t(cases: dict[DType, Any]):
for (from_dt, from_v), (to_dt, to_v) in itertools.product(cases.items(), cases.items()):
if not math.isnan(from_v):
r = full_graph_rewrite(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
r = full_rewrite_to_sink(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
self.assertEqual(r.op, Ops.CONST, msg:=f"{from_dt} -> {to_dt} ({from_v} -> {to_v})")
self.assertEqual(r.dtype, to_dt, msg)
np.testing.assert_equal(r.arg, to_v, msg)
@@ -128,7 +128,7 @@ class TestBitcastConstFolding(unittest.TestCase):
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
def test_vec_bitcast(self):
r = full_graph_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
self.assertEqual(r.op, Ops.VECTORIZE)
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))

View File

@@ -1,8 +1,6 @@
import unittest
from typing import List, cast
import numpy as np
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.device import Buffer, Device, is_dtype_supported
from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
@@ -13,6 +11,7 @@ from tinygrad.runtime.ops_python import PythonRenderer
from tinygrad.ops import UOp, Ops
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.codegen import full_rewrite
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
for x in inputs: x.realize()
@@ -35,7 +34,7 @@ class TestRendererFailures(unittest.TestCase):
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
@@ -46,7 +45,7 @@ class TestRendererFailures(unittest.TestCase):
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
@@ -60,7 +59,7 @@ class TestCStyleFailures(unittest.TestCase):
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
store = UOp.store(a.index(idx), alu)
sink = UOp(Ops.SINK, dtypes.void, (store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
# CPU doesn't use the max function
ret = _test_uop_result([Tensor([1])], uops)[0]
self.assertEqual(ret[0], 1)
@@ -75,7 +74,7 @@ class TestPTXFailures(unittest.TestCase):
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])

View File

@@ -10,9 +10,7 @@ from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported
from tinygrad.ops import Ops, UOp
from tinygrad.runtime.support.compiler_cuda import PTX
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen import full_rewrite
from tinygrad.dtype import DType
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -847,7 +845,7 @@ class TestIdxUpcast(unittest.TestCase):
for s in schedule:
if s.ast.op is Ops.SINK:
renderer = Device[s.bufs[0].device].renderer
uops = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(s.ast, renderer), renderer))
uops = full_rewrite(s.ast, renderer)
renderer.render(uops)
return uops

View File

@@ -1,14 +1,11 @@
from typing import List
import unittest, time, pytest
from tinygrad import dtypes, Device, Variable
import unittest, pytest
from tinygrad import dtypes, Variable
from tinygrad.helpers import DEBUG, Context
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
from tinygrad.codegen.expander import expander, expand_rewrite
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite
from tinygrad.codegen.symbolic import sym
from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
from tinygrad.codegen.expander import expander
simple_pm = PatternMatcher([
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -19,54 +16,10 @@ simple_pm = PatternMatcher([
def to_uops_list(u:List[UOp]) -> List[UOp]:
# we strip the SINK here for legacy reasons
ret = linearize_uop(full_graph_rewrite(UOp.sink(*u)))
ret = full_rewrite(UOp.sink(*u))
assert ret[-1].op is Ops.SINK
return ret[:-1]
class TestGraphRewriteEfficiency(unittest.TestCase):
def test_create_many_uops(self):
c1 = UOp.const(dtypes.int, 1)
c2 = UOp.const(dtypes.int, 2)
st = time.perf_counter()
uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)]
et = time.perf_counter() - st
print(f"created {len(uops)} uops in {et*1000:.2f} ms")
def test_expand_rewrite(self):
sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
offset=0, mask=None, contiguous=False),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
cnt = [0]
old_init = UOp.__init__
def uop_hook(self, *args, **kwargs):
cnt[0] += 1
old_init(self, *args, **kwargs)
UOp.__init__ = uop_hook
st = time.perf_counter()
new_sink = full_graph_rewrite(lower_sink)
et = time.perf_counter() - st
UOp.__init__ = old_init
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort())} -> {len(new_sink.toposort())}, creating {cnt[0]} uops")
class TestGraphRewriteConst(unittest.TestCase):
def test_gep_const(self):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
@@ -572,8 +525,6 @@ class TestUOpGraph(unittest.TestCase):
@track_rewrites()
def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
@track_rewrites()
def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
@@ -735,7 +686,7 @@ class TestIFUOps(unittest.TestCase):
lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier))
store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
sink = UOp(Ops.SINK, dtypes.void, (store,))
sink = full_graph_rewrite(expand_rewrite(sink))
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
@@ -753,7 +704,7 @@ class TestIFUOps(unittest.TestCase):
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(expand_rewrite(sink))
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
@@ -769,7 +720,7 @@ class TestIFUOps(unittest.TestCase):
gate = valid&(lidx.ne(2))
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(sink)
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)

View File

@@ -12,16 +12,15 @@ from tinygrad.spec import spec
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.grouper import fix_kernel_ops
from tinygrad.engine.realize import CompiledRunner, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.symbolic import sym
from tinygrad.device import is_dtype_supported
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return full_rewrite(UOp.sink(*u), opts)
def _uops_to_prg(uops_list):
uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
uops = full_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer)
src = Device[Device.DEFAULT].renderer.render(uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops,
@@ -503,7 +502,7 @@ class TestIndexingOrdering(unittest.TestCase):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
uops = full_rewrite(UOp.sink(st1, st0))
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

View File

@@ -3,7 +3,7 @@ from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen import full_rewrite
from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -144,7 +144,7 @@ class TestUOpsStats(unittest.TestCase):
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
uops = linearize_uop(u5.sink())
uops = full_rewrite(u5.sink())
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
@@ -153,7 +153,7 @@ class TestUOpsStats(unittest.TestCase):
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
uops_fma = linearize_uop(u4.sink())
uops_fma = full_rewrite(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

View File

@@ -2,11 +2,11 @@ import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen import full_rewrite_to_sink
# Helper function to apply the graph rewrite
def apply_rewrite(expr):
return full_graph_rewrite(expr.sink()).src[0]
return full_rewrite_to_sink(expr.sink()).src[0]
def evaluate_uop(uop, variables):
if uop.op == Ops.CONST:
@@ -145,7 +145,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
def test_full_graph_rewrite_transcendental_edge_cases(self):
optimized_sink = full_graph_rewrite(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
optimized_sink = full_rewrite_to_sink(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
optimized_log2_neg, optimized_recip_zero = optimized_sink.src
self.assertTrue(math.isnan(optimized_log2_neg.arg), f"Expected NaN for log2(-1.0), got {optimized_log2_neg.arg}")
self.assertTrue(math.isinf(optimized_recip_zero.arg) and optimized_recip_zero.arg > 0,
@@ -154,14 +154,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
@unittest.skip("broken")
def test_full_graph_rewrite_modulo_negative_dividend(self):
x_var_uop = UOp.variable('x', -5, -1)
optimized_sink = full_graph_rewrite((x_var_uop % 3).sink())
optimized_sink = full_rewrite_to_sink((x_var_uop % 3).sink())
for x_value in range(-5, 0):
self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value}))
@unittest.skip("broken")
def test_full_graph_rewrite_division_negative_divisor(self):
x_var_uop = UOp.variable('x', 1, 5)
optimized_sink = full_graph_rewrite((x_var_uop // -2).sink())
optimized_sink = full_rewrite_to_sink((x_var_uop // -2).sink())
for x_value in range(1, 6):
self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value}))

View File

@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.codegen.symbolic import simplify_valid
@@ -45,7 +45,7 @@ class TestHelpers(unittest.TestCase):
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid):
load = full_graph_rewrite(load.sink()).src[0]
load = full_rewrite_to_sink(load.sink()).src[0]
idx, valid = load.src[0].src[1], load.src[0].src[2]
self.assertEqual(idx.render(simplify=False), sidx)
self.assertEqual(valid.render(simplify=False), svalid)
@@ -167,7 +167,7 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]
load = full_rewrite_to_sink(load.sink()).src[0]
idx = load.src[0].src[1]
self.assertEqual(idx.op, Ops.VECTORIZE)
self.assertEqual(len(idx.src), 2)
@@ -233,7 +233,7 @@ class TestImageSimplification(unittest.TestCase):
# empty -> invalid
load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
load = full_graph_rewrite(load.sink()).src[0]
load = full_rewrite_to_sink(load.sink()).src[0]
self.assertEqual(load.op, Ops.VECTORIZE)
self.assertEqual(load.dtype.count, 4)

View File

@@ -2,8 +2,8 @@
import unittest, pickle
from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite, sym
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.devectorizer import sym
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad import Variable
import functools
@@ -11,7 +11,7 @@ import functools
def render(self) -> tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()))
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
@@ -569,7 +569,7 @@ class TestSymbolic(unittest.TestCase):
# TODO: copied from render, render does not support cast
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()))
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

View File

@@ -72,3 +72,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
return ret
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
  """
  Apply the rewrite steps chosen for a renderer to `sink` and return what `apply_rewrites` produces.

  sink: root UOp of the graph to rewrite.
  opts: renderer used to select the rewrite steps; a default-constructed `Renderer()` is used when None.
  linearizer: forwarded unchanged to `get_rewrites_for_renderer`.
  """
  renderer = Renderer() if opts is None else opts
  steps = get_rewrites_for_renderer(renderer, linearizer)
  return apply_rewrites(sink, steps)
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """
  Rewrite `sink` with the linearizer steps enabled and return the resulting UOps as a list.

  Calls `full_rewrite_to_sink(sink, opts, linearizer=True)` and copies the `.arg.lst` of the returned UOp into a new list.
  """
  final = full_rewrite_to_sink(sink, opts, linearizer=True)
  return list(final.arg.lst)

View File

@@ -1,12 +1,12 @@
from typing import Optional, Any, Callable, cast
from typing import Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
from tinygrad.codegen.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -429,26 +429,3 @@ pm_reduce = PatternMatcher([
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
# *** uop graph ***
# Run a fixed sequence of graph_rewrite passes over a SINK-rooted UOp graph and return the rewritten sink.
# sink: must have op Ops.SINK (asserted below).
# opts: optional Renderer. When None, supported_ops is empty, extra_matcher is an empty PatternMatcher and no pre_matcher pass runs.
# Pass order: remove reduce -> load/store folding (with devectorize depending on DEVECTORIZE) -> optional pre_matcher -> final render rules.
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
# devectorize is optional
# DEVECTORIZE >= 2 drops both devectorize and correct_load_store; a falsy DEVECTORIZE drops only devectorize
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
# optional pre matcher
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
ctx=opts, name="final rewrite")
return sink

View File

@@ -2,8 +2,7 @@
import functools, itertools, operator
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
from tinygrad.codegen.symbolic import sym
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
idx, mul = 0, 1
@@ -137,16 +136,3 @@ pm_delete_ignore = PatternMatcher([
# IGNORE on SELF is nothing
(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat())), lambda x: x),
])
# Apply four graph_rewrite passes to sink in order and return the result:
# sym+migrate_indexing, pm_store_ignore, pm_move_ignore, then pm_delete_ignore+sym+expander.
def expand_rewrite(sink:UOp) -> UOp:
# initial symbolic + migrate indexing (remove this)
sink = graph_rewrite(sink, sym+migrate_indexing)
# store IGNORE
sink = graph_rewrite(sink, pm_store_ignore, name="store_ignore")
# move IGNORE
sink = graph_rewrite(sink, pm_move_ignore, name="move_ignore")
# expand + remove surviving ignores
return graph_rewrite(sink, pm_delete_ignore+sym+expander)

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.engine.grouper import view_left
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
from tinygrad.codegen import full_rewrite
class KernelOptError(Exception): pass
@@ -553,8 +553,7 @@ class Kernel:
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
try:
rewrite_list = get_rewrites_for_renderer(self.opts)
self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)
except RuntimeError:
print("***** LINEARIZE FAILURE *****")
print(f"ast = {self.ast}")

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import heapq
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat, GroupOp
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
from tinygrad.helpers import dedup, partition, all_same, flatten
from tinygrad.spec import type_verify
@@ -243,23 +243,3 @@ def finalize(sink:UOp) -> UOp:
return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
# Turn a SINK-rooted UOp graph into a flat list of UOps by running the block passes
# (create blocks, merge blockends, merge blocks, finalize) and returning list(sink.arg.lst) of the final UOp.
# sink: must have op Ops.SINK (asserted below).
# skip_check: defaults to `not __debug__`; it is not referenced anywhere in the body shown here.
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
# get block context
ctx = BlockContext.from_sink(sink)
# wrap all uops in blocks, already reordered
sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
# merge blockends
sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
# merge blocks
sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
# finalize
sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
return list(sink.arg.lst)

View File

@@ -3,10 +3,9 @@ import itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
from tinygrad.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
from tinygrad.codegen.symbolic import symbolic
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
@@ -233,9 +232,3 @@ pm_quant = symbolic+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])
# Lower ast with pm_lowerer (context built by get_index(ast, opts)) and pass the result through expand_rewrite.
# When QUANTIZE is set and opts.device is "CPU" or "DSP", ast is first rewritten with pm_quant.
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
# expand_rewrite turns this into a vectorized program
return expand_rewrite(sink)