From 550cf2ca7fe1c0fdfebd658dba5c9e479b22d8c0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 2 Sep 2025 13:34:17 -0700 Subject: [PATCH] tests from postopt (#11964) * tests from postopt * reraise is fine --- .pylintrc | 3 ++- test/test_opts.py | 22 ++++++++++++++++++++++ test/test_tiny.py | 12 ++++++++++-- tinygrad/schedule/kernelize.py | 7 ++++++- tinygrad/uop/ops.py | 4 ++-- 5 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 test/test_opts.py diff --git a/.pylintrc b/.pylintrc index 57e05e9386..2f1de51927 100644 --- a/.pylintrc +++ b/.pylintrc @@ -54,11 +54,12 @@ confidence= # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method +disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method,W0707 # E1101 for function binding # W0221 for Function class # W0105 for comment strings # E0401 for missing imports +# W0707 for not reraising # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/test/test_opts.py b/test/test_opts.py new file mode 100644 index 0000000000..eceb73a3c6 --- /dev/null +++ b/test/test_opts.py @@ -0,0 +1,22 @@ +import unittest +from tinygrad import Tensor, Device +from tinygrad.helpers import RANGEIFY +from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.engine.realize import get_program + +@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify") +class TestOpts(unittest.TestCase): + def test_opt_upcast(self): + opts = (Opt(OptOps.UPCAST, 0, 4),) + a = Tensor.empty(16) + b = Tensor.empty(16) + out = (a+b).contiguous(arg=opts) + s = out.schedule() + self.assertEqual(s[-1].ast.arg.opts_to_apply, opts) + if Device.DEFAULT in {"CPU", "GPU", "METAL"}: + prg = get_program(s[-1].ast) + self.assertIn('float4', prg.src) + +if __name__ == '__main__': + unittest.main() + diff --git a/test/test_tiny.py b/test/test_tiny.py index bc133a0dcf..b6e77d05c9 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -1,7 +1,7 @@ # basic self-contained tests of the external functionality of tinygrad import unittest, random from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn -from tinygrad.helpers import IMAGE, CI +from tinygrad.helpers import IMAGE, CI, getenv class TestTiny(unittest.TestCase): @@ -27,7 +27,7 @@ class TestTiny(unittest.TestCase): out = Tensor.ones(256).contiguous().sum() self.assertEqual(out.item(), 256) - def test_gemm(self, N=64, out_dtype=dtypes.float): + def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float): a = Tensor.ones(N,N).contiguous() b = Tensor.eye(N).contiguous() lst = (out:=a@b).tolist() @@ -36,6 +36,14 @@ class TestTiny(unittest.TestCase): self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})") if IMAGE < 2: self.assertEqual(out.dtype, out_dtype) + def test_gemv(self, N=getenv("GEMV_N", 64), out_dtype=dtypes.float): + a = Tensor.ones(1,N).contiguous() + b = Tensor.eye(N).contiguous() + lst = (out:=a@b).tolist() + for x in range(N): + self.assertEqual(lst[0][x], 1.0, msg=f"mismatch at {x}") + if IMAGE < 2: self.assertEqual(out.dtype, out_dtype) + # *** randomness *** def test_random(self): diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index b1e90bc60d..75ea81fe2f 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve -from tinygrad.uop.ops import track_rewrites, _substitute +from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo from tinygrad.uop.spec import type_verify, tensor_uop_spec from tinygrad.uop.symbolic import symbolic_simple from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP @@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop +from tinygrad.codegen.opt.kernel import Opt # creation can recurse a lot import sys @@ -154,6 +155,10 @@ def unbind_view(x:UOp): return None replace_buffers = PatternMatcher([ + # sink on contig creates a KernelInfo + (UPat(Ops.CONTIGUOUS, name="c").sink(name="s"), + lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \ + if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None), # replace ASSIGN with the target BUFFER (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]), # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index ad86abfde3..cf667f5831 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -970,7 +970,7 @@ class RewriteContext: for x in reversed(new_n.src): stack.append((x, 0, x)) elif stage == 1: try: new_src = tuple([self.replace[x] for x in new_n.src]) - except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from + except KeyError: raise RewriteNotReady if new_src == new_n.src: # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None: @@ -985,7 +985,7 @@ class RewriteContext: else: # in stage 2, we link the result of new_n to the result of n try: self.replace[n] = self.replace[new_n] - except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from + except KeyError: raise RewriteNotReady except RewriteNotReady: # retry this later stack.insert(0, (n, stage, new_n))