SPEC=2 is full spec, SPEC=1 is default (#12910)

* SPEC=1 passes all tests

* just use SPEC, not __debug__
This commit is contained in:
George Hotz
2025-10-25 11:10:43 +08:00
committed by GitHub
parent 456560c1ff
commit 8a941d95a4
12 changed files with 31 additions and 24 deletions

View File

@@ -264,8 +264,8 @@ jobs:
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
- name: Run unit tests
run: CPU=1 python -m pytest -n=auto test/unit/ --durations=20
- name: Check SPEC=1
run: SPEC=1 python3 test/test_tiny.py
- name: Check SPEC=2
run: SPEC=2 python3 test/test_tiny.py
- name: Run targetted tests on NULL backend
run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py
# TODO: too slow

View File

@@ -547,10 +547,10 @@ class TestUopsObject(unittest.TestCase):
class TestUOpRender(unittest.TestCase):
def test_render_vectorize_same(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
self.assertEqual(u.render(), "{0, ...}")
def test_render_vectorize_different(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "{0,1,2}")
if __name__ == '__main__':

View File

@@ -305,19 +305,19 @@ class TestRecurse(unittest.TestCase):
graph_rewrite(a, pm, bottom_up=True)
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
a = UOp.const(dtypes.int, 3)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm)
def test_inf_loop_bottom_up(self):
a = UOp.variable('a', 0, 10)
a = UOp.const(dtypes.int, 3)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm, bottom_up=True)

View File

@@ -50,7 +50,7 @@ class TestPatternMatcher(unittest.TestCase):
def fxn(ctx, x):
ctx.append(True)
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
return x.replace(src=(UOp(Ops.DEVICE, arg="blah"),))
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything

View File

@@ -41,7 +41,7 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, Context
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UPat, track_rewrites, GroupOp, Ops
from tinygrad.uop.upat import _get_code, upat_compile
@@ -14,6 +14,7 @@ def do_compile(up):
if DEBUG >= 2: dis.dis(match)
return match_code[0]
@Context(SPEC=0)
class TestUPatCompile(unittest.TestCase):
def test_double(self):
up = UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1")

View File

@@ -157,11 +157,11 @@ class TestViz(BaseTestViz):
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
def test_inf_loop(self):
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
b = a.replace(op=Ops.CONST)
a = UOp.const(dtypes.int, 3)
b = UOp.const(dtypes.int, 4)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
])
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))

View File

@@ -1,4 +1,4 @@
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify, program_spec
from tinygrad.renderer import Renderer
@@ -102,5 +102,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
lst = linearize(full_sink)
if __debug__: type_verify(lst, program_spec)
if SPEC: type_verify(lst, program_spec)
return lst

View File

@@ -167,7 +167,7 @@ EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 1)
VIZ = PROFILE = ContextVar("VIZ", 0)
SPEC = ContextVar("SPEC", 0)
SPEC = ContextVar("SPEC", 1)
# TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify

View File

@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, FUSE_ATTENTION
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, FUSE_ATTENTION, SPEC
from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.uop.mathtraits import MathTrait
@@ -229,7 +229,7 @@ class Tensor(MathTrait):
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_spec)
if SPEC: type_verify(list(big_sink.toposort()), tensor_spec)
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")

View File

@@ -64,7 +64,7 @@ class UOpMetaClass(type):
if _buffer is not None:
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
buffers[created] = _buffer
if SPEC:
if SPEC > 1:
from tinygrad.uop.spec import full_spec
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")

View File

@@ -173,7 +173,7 @@ full_spec = PatternMatcher([
# copy on index
(UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
# assign on index. the third op is the shape
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
# expander: unroll/contract/gep/ptrcat/cat
#(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
@@ -195,6 +195,12 @@ full_spec = PatternMatcher([
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
# while BIND is being casted
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(), UPat()), arg=None), lambda: True),
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
# all ifs