mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
SPEC=3 tests pyrender
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -264,8 +264,8 @@ jobs:
|
||||
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
|
||||
- name: Run unit tests
|
||||
run: CPU=1 python -m pytest -n=auto test/unit/ --durations=20
|
||||
- name: Check SPEC=2
|
||||
run: SPEC=2 python3 test/test_tiny.py
|
||||
- name: Check SPEC=3
|
||||
run: SPEC=3 python3 test/test_tiny.py
|
||||
- name: Run targetted tests on NULL backend
|
||||
run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
|
||||
# TODO: too slow
|
||||
|
||||
@@ -65,6 +65,12 @@ class UOpMetaClass(type):
|
||||
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
||||
buffers[created] = _buffer
|
||||
if SPEC > 1:
|
||||
if SPEC > 2:
|
||||
with Context(SPEC=0):
|
||||
code = '\n'.join(pyrender(created))
|
||||
lcls:dict[str, UOp] = {}
|
||||
exec(code, None, lcls)
|
||||
if lcls['ast'] is not created: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{created}\nPRODUCED:\n{lcls['ast']}")
|
||||
from tinygrad.uop.spec import full_spec
|
||||
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
|
||||
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
|
||||
@@ -1254,7 +1260,7 @@ pm_pyrender = PatternMatcher([
|
||||
@Context(SPEC=0)
|
||||
def pyrender(ast:UOp) -> list[str]:
|
||||
cmap = ast.get_consumer_map()
|
||||
to_render = set()
|
||||
to_render = set({ast})
|
||||
for u in ast.toposort():
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD} or u.op in {Ops.CONST}: continue
|
||||
|
||||
Reference in New Issue
Block a user