mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -54,11 +54,12 @@ confidence=
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method
|
||||
disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method,W0707
|
||||
# E1101 for function binding
|
||||
# W0221 for Function class
|
||||
# W0105 for comment strings
|
||||
# E0401 for missing imports
|
||||
# W0707 for not reraising
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
|
||||
22
test/test_opts.py
Normal file
22
test/test_opts.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
|
||||
class TestOpts(unittest.TestCase):
|
||||
def test_opt_upcast(self):
|
||||
opts = (Opt(OptOps.UPCAST, 0, 4),)
|
||||
a = Tensor.empty(16)
|
||||
b = Tensor.empty(16)
|
||||
out = (a+b).contiguous(arg=opts)
|
||||
s = out.schedule()
|
||||
self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
|
||||
if Device.DEFAULT in {"CPU", "GPU", "METAL"}:
|
||||
prg = get_program(s[-1].ast)
|
||||
self.assertIn('float4', prg.src)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# basic self-contained tests of the external functionality of tinygrad
|
||||
import unittest, random
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn
|
||||
from tinygrad.helpers import IMAGE, CI
|
||||
from tinygrad.helpers import IMAGE, CI, getenv
|
||||
|
||||
class TestTiny(unittest.TestCase):
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestTiny(unittest.TestCase):
|
||||
out = Tensor.ones(256).contiguous().sum()
|
||||
self.assertEqual(out.item(), 256)
|
||||
|
||||
def test_gemm(self, N=64, out_dtype=dtypes.float):
|
||||
def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
lst = (out:=a@b).tolist()
|
||||
@@ -36,6 +36,14 @@ class TestTiny(unittest.TestCase):
|
||||
self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})")
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
def test_gemv(self, N=getenv("GEMV_N", 64), out_dtype=dtypes.float):
|
||||
a = Tensor.ones(1,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
lst = (out:=a@b).tolist()
|
||||
for x in range(N):
|
||||
self.assertEqual(lst[0][x], 1.0, msg=f"mismatch at {x}")
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
@@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.codegen.opt.kernel import Opt
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
@@ -154,6 +155,10 @@ def unbind_view(x:UOp):
|
||||
return None
|
||||
|
||||
replace_buffers = PatternMatcher([
|
||||
# sink on contig creates a KernelInfo
|
||||
(UPat(Ops.CONTIGUOUS, name="c").sink(name="s"),
|
||||
lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \
|
||||
if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None),
|
||||
# replace ASSIGN with the target BUFFER
|
||||
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
|
||||
|
||||
@@ -970,7 +970,7 @@ class RewriteContext:
|
||||
for x in reversed(new_n.src): stack.append((x, 0, x))
|
||||
elif stage == 1:
|
||||
try: new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
except KeyError: raise RewriteNotReady
|
||||
if new_src == new_n.src:
|
||||
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
|
||||
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
|
||||
@@ -985,7 +985,7 @@ class RewriteContext:
|
||||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
try: self.replace[n] = self.replace[new_n]
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
except KeyError: raise RewriteNotReady
|
||||
except RewriteNotReady:
|
||||
# retry this later
|
||||
stack.insert(0, (n, stage, new_n))
|
||||
|
||||
Reference in New Issue
Block a user