mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
* initial moving bound to src * arg to src * remove import * fixup linearizer * arg to src * fix test_uop_graph * fix more tests * fix python renderer * get const value from const uop * ssimplify uop estimates * fix webgpu locals * fix old test * gate Ops.SPECIAL in linearizer * use ssimplify() for local/global_size * remove toposort gate_parents_instead_of_self * fix rendering in comment * cleanup * rename and add comments * add BottomUpGate with test
29 lines
913 B
Python
29 lines
913 B
Python
import unittest
|
|
from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, BottomUpGate
|
|
|
|
def assert_not_reached(): assert False, "This function should not be reached"
|
|
def gate():
  """PatternMatcher callback that raises BottomUpGate at the matched node.

  Used by the tests below to check that a bottom_up=True graph_rewrite stops
  at the node where this is raised.
  """
  raise BottomUpGate()
|
|
|
|
class TestBottomUpGate(unittest.TestCase):
  """Tests that raising BottomUpGate from a pattern callback during a
  bottom_up=True graph_rewrite prevents the sources below that node from
  being matched."""

  def test_basic_bottom_up_gate(self):
    """Test that BottomUpGate stops bottom-up"""
    pm = PatternMatcher([
      (UPat(Ops.ADD), gate),
      # the MULs are sources of the gated ADD root, so they must never be matched
      (UPat(Ops.MUL), assert_not_reached)
    ])
    a,b,c = UOp.variable("a",0,10), UOp.variable("b",0,10), UOp.variable("c",0,10)
    expr = (a*a)+(b*c)
    out = graph_rewrite(expr, pm, bottom_up=True)
    # no pattern ever returns a replacement, so the graph must come back unchanged
    self.assertIs(out, expr)

  def test_bottom_up_gate_with_rewriting(self):
    """Test that a rewrite still applies before the gate, and the gate then stops descent"""
    pm = PatternMatcher([
      (UPat.var("a")+UPat.var("a"), lambda a: 2*a),
      (UPat(Ops.MUL), gate),
      # any CONST under the gated MUL (e.g. the literal 2 introduced by the rewrite)
      # must never be matched
      (UPat(Ops.CONST), assert_not_reached)
    ])
    a = UOp.variable("a",0,10)
    out = graph_rewrite(a+a, pm, bottom_up=True)
    # the ADD pattern fired, so the result is the MUL built by `2*a`
    self.assertIs(out.op, Ops.MUL)
|
|
|
|
# run the test suite when this file is executed directly
if __name__ == "__main__": unittest.main()
|