mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
check spec in test, cleanup vectorize render
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -267,6 +267,8 @@ jobs:
|
||||
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
|
||||
- name: Run unit tests
|
||||
run: CPU=1 python -m pytest -n=auto test/unit/ --durations=20
|
||||
- name: Check SPEC=1
|
||||
run: SPEC=1 python3 test/test_tiny.py
|
||||
- name: Run targetted tests on NULL backend
|
||||
run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
|
||||
# TODO: too slow
|
||||
|
||||
@@ -568,5 +568,13 @@ class TestUOpChildren(unittest.TestCase):
|
||||
del c
|
||||
self.assertEqual(len(a.children), 0)
|
||||
|
||||
class TestUOpRender(unittest.TestCase):
|
||||
def test_render_vectorize_same(self):
|
||||
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
|
||||
self.assertEqual(u.render(), "{0, ...}")
|
||||
def test_render_vectorize_different(self):
|
||||
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
self.assertEqual(u.render(), "{0,1,2}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1042,7 +1042,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
SENTINEL = UOp(Ops.SENTINEL)
|
||||
with Context(SPEC=0): SENTINEL = UOp(Ops.SENTINEL)
|
||||
class RewriteNotReady(Exception): pass
|
||||
class BottomUpGate(Exception): pass
|
||||
class RewriteContext:
|
||||
@@ -1197,7 +1197,7 @@ renderer = PatternMatcher([
|
||||
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
|
||||
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"[{','.join([y.arg for y in x.src])}]" if not all_same(x.src) else f"{len(x.src)}x[{x.src[0].arg}]")),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{{{','.join([y.arg for y in x.src])}}}" if not all_same(x.src) else f"{{{x.src[0].arg}, ...}}")),
|
||||
])
|
||||
renderer_infer = PatternMatcher([
|
||||
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),
|
||||
|
||||
@@ -258,6 +258,9 @@ full_non_rangeify_spec = PatternMatcher([]) if RANGEIFY else PatternMatcher([
|
||||
])
|
||||
|
||||
full_spec = PatternMatcher([
|
||||
# SENTINEL should never be in the graph
|
||||
(UPat(Ops.SENTINEL), lambda: False),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
# where on index in rhs position is fine
|
||||
|
||||
Reference in New Issue
Block a user