SPEC=1 passes all tests

This commit is contained in:
George Hotz
2025-10-25 10:31:38 +08:00
parent a5b0f57067
commit 9cdd284008
7 changed files with 23 additions and 16 deletions

View File

@@ -547,10 +547,10 @@ class TestUopsObject(unittest.TestCase):
class TestUOpRender(unittest.TestCase):
def test_render_vectorize_same(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
self.assertEqual(u.render(), "{0, ...}")
def test_render_vectorize_different(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "{0,1,2}")
if __name__ == '__main__':

View File

@@ -305,19 +305,19 @@ class TestRecurse(unittest.TestCase):
graph_rewrite(a, pm, bottom_up=True)
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
a = UOp.const(dtypes.int, 3)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm)
def test_inf_loop_bottom_up(self):
a = UOp.variable('a', 0, 10)
a = UOp.const(dtypes.int, 3)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm, bottom_up=True)

View File

@@ -50,7 +50,7 @@ class TestPatternMatcher(unittest.TestCase):
def fxn(ctx, x):
ctx.append(True)
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
return x.replace(src=(UOp(Ops.DEVICE, arg="blah"),))
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything

View File

@@ -41,7 +41,7 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, Context
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UPat, track_rewrites, GroupOp, Ops
from tinygrad.uop.upat import _get_code, upat_compile
@@ -14,6 +14,7 @@ def do_compile(up):
if DEBUG >= 2: dis.dis(match)
return match_code[0]
@Context(SPEC=0)
class TestUPatCompile(unittest.TestCase):
def test_double(self):
up = UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1")

View File

@@ -157,11 +157,11 @@ class TestViz(BaseTestViz):
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
def test_inf_loop(self):
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
b = a.replace(op=Ops.CONST)
a = UOp.const(dtypes.int, 3)
b = UOp.const(dtypes.int, 4)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))

View File

@@ -173,7 +173,7 @@ full_spec = PatternMatcher([
# copy on index
(UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
# assign on index. the third op is the shape
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
# expander: unroll/contract/gep/ptrcat/cat
#(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
@@ -195,6 +195,12 @@ full_spec = PatternMatcher([
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
# while BIND is being casted
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(), UPat()), arg=None), lambda: True),
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
# all ifs