diff --git a/test/backend/test_const_folding.py b/test/backend/test_const_folding.py index e6267d1b72..2e2095f75a 100644 --- a/test/backend/test_const_folding.py +++ b/test/backend/test_const_folding.py @@ -1,7 +1,7 @@ import unittest, math from tinygrad import Tensor, Device, dtypes from tinygrad.dtype import DTYPES_DICT -from tinygrad.uop.ops import Ops +from tinygrad.uop.ops import Ops, UOp from tinygrad.device import is_dtype_supported import numpy as np from test.helpers import not_support_multi_device @@ -163,6 +163,11 @@ class TestMultiConstFolding(unittest.TestCase): np.testing.assert_equal((t ** one).numpy(), np.arange(16)) np.testing.assert_equal((one ** t).numpy(), [1] * 16) +class TestThreefryConstFolding(unittest.TestCase): + def test_threefry(self): + x = UOp.const(dtypes.uint64, 5, Device.DEFAULT, ()).threefry(UOp.const(dtypes.uint64, 10, Device.DEFAULT, ())) + self.assertIs(x.simplify().op, Ops.CONST) + class TestTautologicalCompare(unittest.TestCase): # without const folding, these would have triggered -Wtautological-compare in clang def test_lt_false(self): diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index f1b7d0e66e..4f8aeb358f 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -4,7 +4,7 @@ from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup -from tinygrad.uop.decompositions import xpow +from tinygrad.uop.decompositions import threefry2x32, xpow from tinygrad.uop.divandmod import div_and_mod_symbolic # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ******** @@ -104,10 +104,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) # ** constant folding ** - # TODO: add const folding for Ops.THREEFRY (UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))), (UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))), + (UPat(Ops.THREEFRY, src=(UPat.cvar("x"), UPat.cvar("key")), name="a"), + lambda a, x, key: a.const_like(threefry2x32(x, key).simplify().arg)), (UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))), # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly