mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
SPEC=1 passes all tests
This commit is contained in:
@@ -547,10 +547,10 @@ class TestUopsObject(unittest.TestCase):
|
||||
|
||||
class TestUOpRender(unittest.TestCase):
|
||||
def test_render_vectorize_same(self):
|
||||
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
|
||||
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
|
||||
self.assertEqual(u.render(), "{0, ...}")
|
||||
def test_render_vectorize_different(self):
|
||||
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
self.assertEqual(u.render(), "{0,1,2}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -305,19 +305,19 @@ class TestRecurse(unittest.TestCase):
|
||||
graph_rewrite(a, pm, bottom_up=True)
|
||||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
a = UOp.const(dtypes.int, 3)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
|
||||
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm)
|
||||
|
||||
def test_inf_loop_bottom_up(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
a = UOp.const(dtypes.int, 3)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
|
||||
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm, bottom_up=True)
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
def fxn(ctx, x):
|
||||
ctx.append(True)
|
||||
assert len(x.src) == 0
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
return x.replace(src=(UOp(Ops.DEVICE, arg="blah"),))
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertTrue(f2.is_increasing())
|
||||
self.assertTrue(f3.is_increasing())
|
||||
|
||||
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
|
||||
rng = UOp.range(5, 2)
|
||||
self.assertTrue(rng.is_increasing())
|
||||
self.assertTrue((rng+2).is_increasing())
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UPat, track_rewrites, GroupOp, Ops
|
||||
from tinygrad.uop.upat import _get_code, upat_compile
|
||||
@@ -14,6 +14,7 @@ def do_compile(up):
|
||||
if DEBUG >= 2: dis.dis(match)
|
||||
return match_code[0]
|
||||
|
||||
@Context(SPEC=0)
|
||||
class TestUPatCompile(unittest.TestCase):
|
||||
def test_double(self):
|
||||
up = UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1")
|
||||
|
||||
@@ -157,11 +157,11 @@ class TestViz(BaseTestViz):
|
||||
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
|
||||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
|
||||
b = a.replace(op=Ops.CONST)
|
||||
a = UOp.const(dtypes.int, 3)
|
||||
b = UOp.const(dtypes.int, 4)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
|
||||
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))
|
||||
|
||||
@@ -173,7 +173,7 @@ full_spec = PatternMatcher([
|
||||
# copy on index
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
|
||||
# assign on index. the third op is the shape
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
|
||||
# expander: unroll/contract/gep/ptrcat/cat
|
||||
#(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
@@ -195,6 +195,12 @@ full_spec = PatternMatcher([
|
||||
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
|
||||
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
|
||||
|
||||
# while BIND is being casted
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(), UPat()), arg=None), lambda: True),
|
||||
|
||||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# all loads/stores
|
||||
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
|
||||
# all ifs
|
||||
|
||||
Reference in New Issue
Block a user