mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add threefry const folding (#15787)
* prim threefry * test fix * clean test * cleanup * cleanup 2 * cleanup 3 * fix conflict markers in test_const_folding.py * update test * fix lint * use const instead of value for test
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import unittest, math
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DTYPES_DICT
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
from test.helpers import not_support_multi_device
|
||||
@@ -163,6 +163,11 @@ class TestMultiConstFolding(unittest.TestCase):
|
||||
np.testing.assert_equal((t ** one).numpy(), np.arange(16))
|
||||
np.testing.assert_equal((one ** t).numpy(), [1] * 16)
|
||||
|
||||
class TestThreefryConstFolding(unittest.TestCase):
|
||||
def test_threefry(self):
|
||||
x = UOp.const(dtypes.uint64, 5, Device.DEFAULT, ()).threefry(UOp.const(dtypes.uint64, 10, Device.DEFAULT, ()))
|
||||
self.assertIs(x.simplify().op, Ops.CONST)
|
||||
|
||||
class TestTautologicalCompare(unittest.TestCase):
|
||||
# without const folding, these would have triggered -Wtautological-compare in clang
|
||||
def test_lt_false(self):
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
from tinygrad.uop.decompositions import threefry2x32, xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
||||
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
|
||||
@@ -104,10 +104,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||
# ** constant folding **
|
||||
# TODO: add const folding for Ops.THREEFRY
|
||||
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
|
||||
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
|
||||
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
|
||||
(UPat(Ops.THREEFRY, src=(UPat.cvar("x"), UPat.cvar("key")), name="a"),
|
||||
lambda a, x, key: a.const_like(threefry2x32(x, key).simplify().arg)),
|
||||
(UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
|
||||
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
|
||||
Reference in New Issue
Block a user