Files
tinygrad/test/unit/test_rewrite_bottom_up_gate.py
Sieds Lykles 572a3c15c6 Move Ops.SPECIAL arg to src (#11918)
* initial moving bound to src

* arg to src

* remove import

* fixup linearizer

* arg to src

* fix test_uop_graph

* fix more tests

* fix python renderer

* get const value from const uop

* ssimplify uop estimates

* fix webgpu locals

* fix old test

* gate Ops.SPECIAL in linearizer

* use ssimplify() for local/global_size

* remove toposort gate_parents_instead_of_self

* fix rendering in comment

* cleanup

* rename and add comments

* add BottomUpGate with test
2025-09-04 09:31:44 +02:00

29 lines
913 B
Python

import unittest
from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, BottomUpGate
def assert_not_reached(): assert False, "This function should not be reached"
def gate(): raise BottomUpGate
class TestBottomUpGate(unittest.TestCase):
def test_basic_bottom_up_gate(self):
"""Test that BottomUpGate stops bottom-up"""
pm = PatternMatcher([
(UPat(Ops.ADD), gate),
(UPat(Ops.MUL), assert_not_reached)
])
a,b,c = UOp.variable("a",0,10), UOp.variable("b",0,10), UOp.variable("c",0,10)
graph_rewrite((a*a)+(b*c), pm, bottom_up=True)
def test_bottom_up_gate_with_rewriting(self):
pm = PatternMatcher([
(UPat.var("a")+UPat.var("a"), lambda a: 2*a),
(UPat(Ops.MUL), gate),
(UPat(Ops.CONST), assert_not_reached)
])
a = UOp.variable("a",0,10)
graph_rewrite(a+a, pm, bottom_up=True)
if __name__ == "__main__":
unittest.main()