From 8a941d95a4caa696021bbbcfd1c91f4c70d00068 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 25 Oct 2025 11:10:43 +0800 Subject: [PATCH] SPEC=2 is full spec, SPEC=1 is default (#12910) * SPEC=1 passes all tests * just use SPEC, not __debug__ --- .github/workflows/test.yml | 4 ++-- test/test_uops.py | 4 ++-- test/unit/test_graph_rewrite.py | 12 ++++++------ test/unit/test_pattern_matcher.py | 2 +- test/unit/test_simplify_valid_idx.py | 2 +- test/unit/test_upat_compile.py | 3 ++- test/unit/test_viz.py | 8 ++++---- tinygrad/codegen/__init__.py | 4 ++-- tinygrad/helpers.py | 2 +- tinygrad/tensor.py | 4 ++-- tinygrad/uop/ops.py | 2 +- tinygrad/uop/spec.py | 8 +++++++- 12 files changed, 31 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3f5191bd8b..72f5e2421f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -264,8 +264,8 @@ jobs: run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT" - name: Run unit tests run: CPU=1 python -m pytest -n=auto test/unit/ --durations=20 - - name: Check SPEC=1 - run: SPEC=1 python3 test/test_tiny.py + - name: Check SPEC=2 + run: SPEC=2 python3 test/test_tiny.py - name: Run targetted tests on NULL backend run: NULL=1 python3 -m unittest test.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step test/device/test_null.py # TODO: too slow diff --git a/test/test_uops.py b/test/test_uops.py index 7a6c5bc6cb..6749115624 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -547,10 +547,10 @@ class TestUopsObject(unittest.TestCase): class TestUOpRender(unittest.TestCase): def test_render_vectorize_same(self): - u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0))) + u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0))) self.assertEqual(u.render(), "{0, ...}") def test_render_vectorize_different(self): - u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2))) + u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2))) self.assertEqual(u.render(), "{0,1,2}") if __name__ == '__main__': diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 46c7c760c8..ea9c271aec 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -305,19 +305,19 @@ class TestRecurse(unittest.TestCase): graph_rewrite(a, pm, bottom_up=True) def test_inf_loop(self): - a = UOp.variable('a', 0, 10) + a = UOp.const(dtypes.int, 3) pm = PatternMatcher([ - (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)), - (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), + (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)), + (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)), ]) with self.assertRaises(RuntimeError): graph_rewrite(a, pm) def test_inf_loop_bottom_up(self): - a = UOp.variable('a', 0, 10) + a = UOp.const(dtypes.int, 3) pm = PatternMatcher([ - (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)), - (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), + (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)), + (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)), ]) with self.assertRaises(RuntimeError): graph_rewrite(a, pm, bottom_up=True) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 156d4d0cb9..1d8beeaf3a 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -50,7 +50,7 @@ class TestPatternMatcher(unittest.TestCase): def fxn(ctx, x): ctx.append(True) assert len(x.src) == 0 - return UOp(Ops.CONST, src=(UOp(Ops.CONST),)) + return x.replace(src=(UOp(Ops.DEVICE, arg="blah"),)) matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) # second rewrite shouldn't match anything diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 619d10e5ca..dfaee9e58d 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -41,7 +41,7 @@ class TestHelpers(unittest.TestCase): self.assertTrue(f2.is_increasing()) self.assertTrue(f3.is_increasing()) - rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=5, src=()),)) + rng = UOp.range(5, 2) self.assertTrue(rng.is_increasing()) self.assertTrue((rng+2).is_increasing()) diff --git a/test/unit/test_upat_compile.py b/test/unit/test_upat_compile.py index 1f68a8830b..c1366b006b 100644 --- a/test/unit/test_upat_compile.py +++ b/test/unit/test_upat_compile.py @@ -1,5 +1,5 @@ import unittest -from tinygrad.helpers import DEBUG +from tinygrad.helpers import DEBUG, Context from tinygrad.dtype import dtypes from tinygrad.uop.ops import UPat, track_rewrites, GroupOp, Ops from tinygrad.uop.upat import _get_code, upat_compile @@ -14,6 +14,7 @@ def do_compile(up): if DEBUG >= 2: dis.dis(match) return match_code[0] +@Context(SPEC=0) class TestUPatCompile(unittest.TestCase): def test_double(self): up = UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1") diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 4edee1323b..38b7e1f44b 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -157,11 +157,11 @@ class TestViz(BaseTestViz): self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw") def test_inf_loop(self): - a = UOp.variable('a', 0, 10, dtype=dtypes.int) - b = a.replace(op=Ops.CONST) + a = UOp.const(dtypes.int, 3) + b = UOp.const(dtypes.int, 4) pm = PatternMatcher([ - (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)), - (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), + (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)), + (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)), ]) with self.assertRaises(RuntimeError): exec_rewrite(a, [pm]) graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0)) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index b2b6b9aa94..99416e44c7 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,4 +1,4 @@ -from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL +from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype from tinygrad.uop.spec import type_verify, program_spec from tinygrad.renderer import Renderer @@ -102,5 +102,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]: full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None) assert len(full_sink.ranges) == 0, "all ranges must end by the sink" lst = linearize(full_sink) - if __debug__: type_verify(lst, program_spec) + if SPEC: type_verify(lst, program_spec) return lst diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 6c0fb1cb13..fe5af45099 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -167,7 +167,7 @@ EMULATE = ContextVar("EMULATE", "") CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1))) CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 1) VIZ = PROFILE = ContextVar("VIZ", 0) -SPEC = ContextVar("SPEC", 0) +SPEC = ContextVar("SPEC", 1) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1) PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c26360852d..57bd08ab2e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup -from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, FUSE_ATTENTION +from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, FUSE_ATTENTION, SPEC from tinygrad.helpers import suppress_finalizing from tinygrad.gradient import compute_gradient from tinygrad.uop.mathtraits import MathTrait @@ -229,7 +229,7 @@ class Tensor(MathTrait): big_sink = UOp.sink(*[x.uop for x in (self,)+lst]) # verify Tensors match the spec - if __debug__: type_verify(list(big_sink.toposort()), tensor_spec) + if SPEC: type_verify(list(big_sink.toposort()), tensor_spec) if any(isinstance(x._device, tuple) for x in big_sink.toposort()): _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 18502fbfd8..f6328ffea1 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -64,7 +64,7 @@ class UOpMetaClass(type): if _buffer is not None: assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}" buffers[created] = _buffer - if SPEC: + if SPEC > 1: from tinygrad.uop.spec import full_spec with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created) if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}") diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 81ea7a8c3c..6af4aa943d 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -173,7 +173,7 @@ full_spec = PatternMatcher([ # copy on index (UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True), # assign on index. the third op is the shape - (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True), + (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True), # expander: unroll/contract/gep/ptrcat/cat #(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), @@ -195,6 +195,12 @@ full_spec = PatternMatcher([ (UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE, Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True), + # while BIND is being casted + (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(), UPat()), arg=None), lambda: True), + + # in progress MSTACK may lose device + (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), + # all loads/stores (UPat((Ops.LOAD, Ops.STORE)), lambda: True), # all ifs