tests from postopt (#11964)

* tests from postopt

* reraise is fine
This commit is contained in:
George Hotz
2025-09-02 13:34:17 -07:00
committed by GitHub
parent b977ec0813
commit 550cf2ca7f
5 changed files with 42 additions and 6 deletions

View File

@@ -54,11 +54,12 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W"
disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method
disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method,W0707
# E1101 for function binding
# W0221 for Function class
# W0105 for comment strings
# E0401 for missing imports
# W0707 for not reraising
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option

22
test/test_opts.py Normal file
View File

@@ -0,0 +1,22 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import RANGEIFY
from tinygrad.codegen.opt.kernel import Opt, OptOps
from tinygrad.engine.realize import get_program
@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
class TestOpts(unittest.TestCase):
  """Verify that kernel opts passed through contiguous(arg=...) survive scheduling.

  The opts tuple handed to ``contiguous`` should end up on the scheduled
  kernel's AST as ``KernelInfo.opts_to_apply``, and on backends where we can
  compile, an UPCAST by 4 should show up as a float4 vector type in the source.
  """
  def test_opt_upcast(self):
    requested = (Opt(OptOps.UPCAST, 0, 4),)
    x = Tensor.empty(16)
    y = Tensor.empty(16)
    out = (x+y).contiguous(arg=requested)
    sched = out.schedule()
    # the final scheduled kernel must carry exactly the opts we requested
    self.assertEqual(sched[-1].ast.arg.opts_to_apply, requested)
    if Device.DEFAULT in {"CPU", "GPU", "METAL"}:
      program = get_program(sched[-1].ast)
      # UPCAST 0 by 4 on a float kernel vectorizes loads/stores to float4
      self.assertIn('float4', program.src)

if __name__ == '__main__':
  unittest.main()

View File

@@ -1,7 +1,7 @@
# basic self-contained tests of the external functionality of tinygrad
import unittest, random
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn
from tinygrad.helpers import IMAGE, CI
from tinygrad.helpers import IMAGE, CI, getenv
class TestTiny(unittest.TestCase):
@@ -27,7 +27,7 @@ class TestTiny(unittest.TestCase):
out = Tensor.ones(256).contiguous().sum()
self.assertEqual(out.item(), 256)
def test_gemm(self, N=64, out_dtype=dtypes.float):
def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous()
b = Tensor.eye(N).contiguous()
lst = (out:=a@b).tolist()
@@ -36,6 +36,14 @@ class TestTiny(unittest.TestCase):
self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})")
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
def test_gemv(self, N=getenv("GEMV_N", 64), out_dtype=dtypes.float):
a = Tensor.ones(1,N).contiguous()
b = Tensor.eye(N).contiguous()
lst = (out:=a@b).tolist()
for x in range(N):
self.assertEqual(lst[0][x], 1.0, msg=f"mismatch at {x}")
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
# *** randomness ***
def test_random(self):

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
@@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
from tinygrad.codegen.opt.kernel import Opt
# creation can recurse a lot
import sys
@@ -154,6 +155,10 @@ def unbind_view(x:UOp):
return None
replace_buffers = PatternMatcher([
# sink on contig creates a KernelInfo
(UPat(Ops.CONTIGUOUS, name="c").sink(name="s"),
lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \
if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None),
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)

View File

@@ -970,7 +970,7 @@ class RewriteContext:
for x in reversed(new_n.src): stack.append((x, 0, x))
elif stage == 1:
try: new_src = tuple([self.replace[x] for x in new_n.src])
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
except KeyError: raise RewriteNotReady
if new_src == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
@@ -985,7 +985,7 @@ class RewriteContext:
else:
# in stage 2, we link the result of new_n to the result of n
try: self.replace[n] = self.replace[new_n]
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
except KeyError: raise RewriteNotReady
except RewriteNotReady:
# retry this later
stack.insert(0, (n, stage, new_n))