add threefry const folding (#15787)

* prim threefry

* test fix

* clean test

* cleanup

* cleanup 2

* cleanup 3

* fix conflict markers in test_const_folding.py

* update test

* fix lint

* use const instead of value for test
This commit is contained in:
oxrinz
2026-04-20 03:30:03 +02:00
committed by GitHub
parent b05b1010bf
commit f551a4bded
2 changed files with 9 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DTYPES_DICT
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, UOp
from tinygrad.device import is_dtype_supported
import numpy as np
from test.helpers import not_support_multi_device
@@ -163,6 +163,11 @@ class TestMultiConstFolding(unittest.TestCase):
np.testing.assert_equal((t ** one).numpy(), np.arange(16))
np.testing.assert_equal((one ** t).numpy(), [1] * 16)
class TestThreefryConstFolding(unittest.TestCase):
def test_threefry(self):
x = UOp.const(dtypes.uint64, 5, Device.DEFAULT, ()).threefry(UOp.const(dtypes.uint64, 10, Device.DEFAULT, ()))
self.assertIs(x.simplify().op, Ops.CONST)
class TestTautologicalCompare(unittest.TestCase):
# without const folding, these would have triggered -Wtautological-compare in clang
def test_lt_false(self):

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import xpow
from tinygrad.uop.decompositions import threefry2x32, xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -104,10 +104,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding **
# TODO: add const folding for Ops.THREEFRY
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
(UPat(Ops.THREEFRY, src=(UPat.cvar("x"), UPat.cvar("key")), name="a"),
lambda a, x, key: a.const_like(threefry2x32(x, key).simplify().arg)),
(UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly