check spec in test, cleanup vectorize render

This commit is contained in:
George Hotz
2025-10-07 16:36:24 +08:00
parent 75ce11593c
commit 7e4cf39929
4 changed files with 15 additions and 2 deletions

View File

@@ -267,6 +267,8 @@ jobs:
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
- name: Run unit tests
run: CPU=1 python -m pytest -n=auto test/unit/ --durations=20
- name: Check SPEC=1
run: SPEC=1 python3 test/test_tiny.py
- name: Run targetted tests on NULL backend
run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
# TODO: too slow

View File

@@ -568,5 +568,13 @@ class TestUOpChildren(unittest.TestCase):
del c
self.assertEqual(len(a.children), 0)
class TestUOpRender(unittest.TestCase):
def test_render_vectorize_same(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
self.assertEqual(u.render(), "{0, ...}")
def test_render_vectorize_different(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "{0,1,2}")
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -1042,7 +1042,7 @@ if TRACK_MATCH_STATS or PROFILE:
# *** simple graph rewrite engine ***
SENTINEL = UOp(Ops.SENTINEL)
with Context(SPEC=0): SENTINEL = UOp(Ops.SENTINEL)
class RewriteNotReady(Exception): pass
class BottomUpGate(Exception): pass
class RewriteContext:
@@ -1197,7 +1197,7 @@ renderer = PatternMatcher([
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
(UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"[{','.join([y.arg for y in x.src])}]" if not all_same(x.src) else f"{len(x.src)}x[{x.src[0].arg}]")),
lambda x: UOp(Ops.NOOP, arg=f"{{{','.join([y.arg for y in x.src])}}}" if not all_same(x.src) else f"{{{x.src[0].arg}, ...}}")),
])
renderer_infer = PatternMatcher([
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),

View File

@@ -258,6 +258,9 @@ full_non_rangeify_spec = PatternMatcher([]) if RANGEIFY else PatternMatcher([
])
full_spec = PatternMatcher([
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine