mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix tests for rewrite [pr] (#10167)
* fix tests for rewrite [pr] * cleaner * delete linearize_uop * clean up the rest
This commit is contained in:
@@ -9,8 +9,7 @@ from tinygrad.engine.realize import Runner
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import T, unwrap, CI
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
|
||||
|
||||
def derandomize_model(model):
|
||||
@@ -59,8 +58,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
|
||||
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
|
||||
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
|
||||
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
|
||||
rw = full_graph_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
|
||||
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(linearize_uop(rw))))
|
||||
lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
|
||||
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(lst)))
|
||||
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
|
||||
return out_buf.cast(uop.dtype.fmt).tolist()[0]
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
import numpy as np
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from test.helpers import not_support_multi_device
|
||||
@@ -105,7 +105,7 @@ class TestBitcastConstFolding(unittest.TestCase):
|
||||
def t(cases: dict[DType, Any]):
|
||||
for (from_dt, from_v), (to_dt, to_v) in itertools.product(cases.items(), cases.items()):
|
||||
if not math.isnan(from_v):
|
||||
r = full_graph_rewrite(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
|
||||
r = full_rewrite_to_sink(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
|
||||
self.assertEqual(r.op, Ops.CONST, msg:=f"{from_dt} -> {to_dt} ({from_v} -> {to_v})")
|
||||
self.assertEqual(r.dtype, to_dt, msg)
|
||||
np.testing.assert_equal(r.arg, to_v, msg)
|
||||
@@ -128,7 +128,7 @@ class TestBitcastConstFolding(unittest.TestCase):
|
||||
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
|
||||
|
||||
def test_vec_bitcast(self):
|
||||
r = full_graph_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||
self.assertEqual(r.op, Ops.VECTORIZE)
|
||||
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
|
||||
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import unittest
|
||||
from typing import List, cast
|
||||
import numpy as np
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.device import Buffer, Device, is_dtype_supported
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
@@ -13,6 +11,7 @@ from tinygrad.runtime.ops_python import PythonRenderer
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.codegen import full_rewrite
|
||||
|
||||
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
for x in inputs: x.realize()
|
||||
@@ -35,7 +34,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@@ -46,7 +45,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
|
||||
|
||||
@@ -60,7 +59,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
# CPU doesn't use the max function
|
||||
ret = _test_uop_result([Tensor([1])], uops)[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
@@ -75,7 +74,7 @@ class TestPTXFailures(unittest.TestCase):
|
||||
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ from hypothesis import given, settings, strategies as strat
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.runtime.support.compiler_cuda import PTX
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.dtype import DType
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
@@ -847,7 +845,7 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
for s in schedule:
|
||||
if s.ast.op is Ops.SINK:
|
||||
renderer = Device[s.bufs[0].device].renderer
|
||||
uops = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(s.ast, renderer), renderer))
|
||||
uops = full_rewrite(s.ast, renderer)
|
||||
renderer.render(uops)
|
||||
return uops
|
||||
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from typing import List
|
||||
import unittest, time, pytest
|
||||
from tinygrad import dtypes, Device, Variable
|
||||
import unittest, pytest
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
|
||||
from tinygrad.codegen.expander import expander, expand_rewrite
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
|
||||
from tinygrad.codegen.expander import expander
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
@@ -19,54 +16,10 @@ simple_pm = PatternMatcher([
|
||||
|
||||
def to_uops_list(u:List[UOp]) -> List[UOp]:
|
||||
# we strip the SINK here for legacy reasons
|
||||
ret = linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||
ret = full_rewrite(UOp.sink(*u))
|
||||
assert ret[-1].op is Ops.SINK
|
||||
return ret[:-1]
|
||||
|
||||
class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||
def test_create_many_uops(self):
|
||||
c1 = UOp.const(dtypes.int, 1)
|
||||
c2 = UOp.const(dtypes.int, 2)
|
||||
st = time.perf_counter()
|
||||
uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)]
|
||||
et = time.perf_counter() - st
|
||||
print(f"created {len(uops)} uops in {et*1000:.2f} ms")
|
||||
|
||||
def test_expand_rewrite(self):
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
|
||||
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
|
||||
offset=0, mask=None, contiguous=False),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
|
||||
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
|
||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
|
||||
mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
|
||||
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
|
||||
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
|
||||
cnt = [0]
|
||||
old_init = UOp.__init__
|
||||
def uop_hook(self, *args, **kwargs):
|
||||
cnt[0] += 1
|
||||
old_init(self, *args, **kwargs)
|
||||
UOp.__init__ = uop_hook
|
||||
st = time.perf_counter()
|
||||
new_sink = full_graph_rewrite(lower_sink)
|
||||
et = time.perf_counter() - st
|
||||
UOp.__init__ = old_init
|
||||
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort())} -> {len(new_sink.toposort())}, creating {cnt[0]} uops")
|
||||
|
||||
class TestGraphRewriteConst(unittest.TestCase):
|
||||
def test_gep_const(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
@@ -572,8 +525,6 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
@track_rewrites()
|
||||
def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
|
||||
@track_rewrites()
|
||||
def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
@@ -735,7 +686,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier))
|
||||
store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
sink = full_graph_rewrite(expand_rewrite(sink))
|
||||
sink = full_rewrite_to_sink(sink)
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
@@ -753,7 +704,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(expand_rewrite(sink))
|
||||
sink = full_rewrite_to_sink(sink)
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
@@ -769,7 +720,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(sink)
|
||||
sink = full_rewrite_to_sink(sink)
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
|
||||
@@ -12,16 +12,15 @@ from tinygrad.spec import spec
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.grouper import fix_kernel_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
|
||||
def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
|
||||
def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return full_rewrite(UOp.sink(*u), opts)
|
||||
|
||||
def _uops_to_prg(uops_list):
|
||||
uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer)
|
||||
src = Device[Device.DEFAULT].renderer.render(uops)
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops,
|
||||
@@ -503,7 +502,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
|
||||
uops = full_rewrite(UOp.sink(st1, st0))
|
||||
stores = [st for st in uops if st.op is Ops.STORE]
|
||||
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from tinygrad import Tensor
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
@@ -144,7 +144,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
|
||||
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
|
||||
uops = linearize_uop(u5.sink())
|
||||
uops = full_rewrite(u5.sink())
|
||||
|
||||
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
|
||||
@@ -153,7 +153,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
|
||||
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
|
||||
uops_fma = linearize_uop(u4.sink())
|
||||
uops_fma = full_rewrite(u4.sink())
|
||||
|
||||
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@ import unittest, math
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import all_same
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
|
||||
# Helper function to apply the graph rewrite
|
||||
def apply_rewrite(expr):
|
||||
return full_graph_rewrite(expr.sink()).src[0]
|
||||
return full_rewrite_to_sink(expr.sink()).src[0]
|
||||
|
||||
def evaluate_uop(uop, variables):
|
||||
if uop.op == Ops.CONST:
|
||||
@@ -145,7 +145,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
|
||||
|
||||
class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
|
||||
def test_full_graph_rewrite_transcendental_edge_cases(self):
|
||||
optimized_sink = full_graph_rewrite(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
|
||||
optimized_sink = full_rewrite_to_sink(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
|
||||
optimized_log2_neg, optimized_recip_zero = optimized_sink.src
|
||||
self.assertTrue(math.isnan(optimized_log2_neg.arg), f"Expected NaN for log2(-1.0), got {optimized_log2_neg.arg}")
|
||||
self.assertTrue(math.isinf(optimized_recip_zero.arg) and optimized_recip_zero.arg > 0,
|
||||
@@ -154,14 +154,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
|
||||
@unittest.skip("broken")
|
||||
def test_full_graph_rewrite_modulo_negative_dividend(self):
|
||||
x_var_uop = UOp.variable('x', -5, -1)
|
||||
optimized_sink = full_graph_rewrite((x_var_uop % 3).sink())
|
||||
optimized_sink = full_rewrite_to_sink((x_var_uop % 3).sink())
|
||||
for x_value in range(-5, 0):
|
||||
self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value}))
|
||||
|
||||
@unittest.skip("broken")
|
||||
def test_full_graph_rewrite_division_negative_divisor(self):
|
||||
x_var_uop = UOp.variable('x', 1, 5)
|
||||
optimized_sink = full_graph_rewrite((x_var_uop // -2).sink())
|
||||
optimized_sink = full_rewrite_to_sink((x_var_uop // -2).sink())
|
||||
for x_value in range(1, 6):
|
||||
self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value}))
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest, itertools
|
||||
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.codegen.symbolic import simplify_valid
|
||||
@@ -45,7 +45,7 @@ class TestHelpers(unittest.TestCase):
|
||||
|
||||
class TestValidIdxSimplification(unittest.TestCase):
|
||||
def check(self, load, sidx, svalid):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
load = full_rewrite_to_sink(load.sink()).src[0]
|
||||
idx, valid = load.src[0].src[1], load.src[0].src[2]
|
||||
self.assertEqual(idx.render(simplify=False), sidx)
|
||||
self.assertEqual(valid.render(simplify=False), svalid)
|
||||
@@ -167,7 +167,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
load = full_rewrite_to_sink(load.sink()).src[0]
|
||||
idx = load.src[0].src[1]
|
||||
self.assertEqual(idx.op, Ops.VECTORIZE)
|
||||
self.assertEqual(len(idx.src), 2)
|
||||
@@ -233,7 +233,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
|
||||
# empty -> invalid
|
||||
load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
load = full_rewrite_to_sink(load.sink()).src[0]
|
||||
self.assertEqual(load.op, Ops.VECTORIZE)
|
||||
self.assertEqual(load.dtype.count, 4)
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import unittest, pickle
|
||||
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite, sym
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen.devectorizer import sym
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
|
||||
from tinygrad import Variable
|
||||
import functools
|
||||
@@ -11,7 +11,7 @@ import functools
|
||||
def render(self) -> tuple[str, ConstType, ConstType]:
|
||||
# NOTE: we need STORE so the ALU op has children
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()))
|
||||
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
@@ -569,7 +569,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()))
|
||||
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
||||
@@ -72,3 +72,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||
ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
|
||||
ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
|
||||
return ret
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: return list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Optional, Any, Callable, cast
|
||||
from typing import Any, Callable, cast
|
||||
import functools, operator, itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
|
||||
from tinygrad.codegen.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -429,26 +429,3 @@ pm_reduce = PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
])+sym
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# devectorize is optional
|
||||
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
|
||||
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
||||
# optional pre matcher
|
||||
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
|
||||
ctx=opts, name="final rewrite")
|
||||
return sink
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
|
||||
|
||||
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
|
||||
idx, mul = 0, 1
|
||||
@@ -137,16 +136,3 @@ pm_delete_ignore = PatternMatcher([
|
||||
# IGNORE on SELF is nothing
|
||||
(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat())), lambda x: x),
|
||||
])
|
||||
|
||||
def expand_rewrite(sink:UOp) -> UOp:
|
||||
# initial symbolic + migrate indexing (remove this)
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
|
||||
# store IGNORE
|
||||
sink = graph_rewrite(sink, pm_store_ignore, name="store_ignore")
|
||||
|
||||
# move IGNORE
|
||||
sink = graph_rewrite(sink, pm_move_ignore, name="move_ignore")
|
||||
|
||||
# expand + remove surviving ignores
|
||||
return graph_rewrite(sink, pm_delete_ignore+sym+expander)
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.engine.grouper import view_left
|
||||
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
|
||||
from tinygrad.codegen import full_rewrite
|
||||
|
||||
class KernelOptError(Exception): pass
|
||||
|
||||
@@ -553,8 +553,7 @@ class Kernel:
|
||||
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
|
||||
|
||||
try:
|
||||
rewrite_list = get_rewrites_for_renderer(self.opts)
|
||||
self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
|
||||
self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)
|
||||
except RuntimeError:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(f"ast = {self.ast}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import heapq
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.helpers import dedup, partition, all_same, flatten
|
||||
from tinygrad.spec import type_verify
|
||||
|
||||
@@ -243,23 +243,3 @@ def finalize(sink:UOp) -> UOp:
|
||||
return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
|
||||
|
||||
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
# get block context
|
||||
ctx = BlockContext.from_sink(sink)
|
||||
|
||||
# wrap all uops in blocks, already reordered
|
||||
sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
|
||||
|
||||
# merge blockends
|
||||
sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
|
||||
|
||||
# merge blocks
|
||||
sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
|
||||
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
|
||||
|
||||
return list(sink.arg.lst)
|
||||
@@ -3,10 +3,9 @@ import itertools, operator, math
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.codegen.expander import expand_rewrite
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
@@ -233,9 +232,3 @@ pm_quant = symbolic+PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
|
||||
sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
# expand_rewrite turns this into a vectorized program
|
||||
return expand_rewrite(sink)
|
||||
|
||||
Reference in New Issue
Block a user