mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 14:45:35 -05:00
* new uops is an actual graph
* it's way slower
* simpler
* fix define acc
* render_loop unique
* ops test pass
* add pattern matcher back, there's bugs
* rewrite
* use priority queue
* recursive children
* fix tests
* fix tests with SINK
* fix abstractions
* fix assembly
* simpler
* link define_acc
* fix DEFINE_ACC placement
* type verify
* full cmp
* fix cmp
* ACCESS_ACC
* insert DEFINE_ACC
* fix PHI
* recursive rewrite
* fix many tests
* sum collapse
* more patterns
* correct change
* fold arange
* fix that lin test
* space
* big folding rule works
* close
* has more maxes, meh
* cached node replace
* set changed
* simplest folding yet
* works
* works
* DIV
* all tests pass
* del
* fuzz linearizer fails
* sum_collapse
* test depth 2 cf
* fix lin test 14
* fix clang depth
* disable that
* failure 14 is fixed
* fix ptx
* failure 27 is fixed
* fix llama
* run_cnt
* Revert "Optimize PTX gated loads index calculation (#4304)"
This reverts commit d97d5a7689.
* fix uops loop
* fix ptx bugs
* add barrier
* print
* mem_type in ptx direct
* bypass tests that fail in CI but pass locally
* ptx remove ptr_ar
* more ptx passing
* fix ptx tests
* assert compile support
* remove model inference benchmark from red
70 lines · 2.5 KiB · Python
import unittest
|
|
from tinygrad import dtypes, Variable
|
|
from tinygrad.ops import BinaryOps, TernaryOps
|
|
from tinygrad.codegen.uops import UOpGraph, UOps
|
|
|
|
class TestUOpGraph(unittest.TestCase):
  """Checks on the uop list produced by small hand-built UOpGraph instances.

  Every test adds a handful of nodes with ``UOpGraph.add``, terminates the
  graph with a ``UOps.SINK`` node, and then asserts on ``graph.uops``.
  """

  def _assert_single_const(self, graph, expected):
    # Shared tail of the folding tests: the graph must expose exactly one uop,
    # that uop must be a CONST, and its arg must equal `expected`.
    # `graph.uops` is read twice on purpose, mirroring the inline assertions
    # this helper replaces.
    self.assertEqual(len(graph.uops), 1)
    last = graph.uops[-1]
    self.assertEqual(last.uop, UOps.CONST)
    self.assertEqual(last.arg, expected)

  def test_add_constant_fold(self):
    # CONST(1.0) + CONST(2.0) is expected to come out as a single CONST(3.0).
    graph = UOpGraph()
    one = graph.add(UOps.CONST, dtypes.float, arg=1.0)
    two = graph.add(UOps.CONST, dtypes.float, arg=2.0)
    total = graph.add(UOps.ALU, dtypes.float, (one, two), BinaryOps.ADD)
    graph.add(UOps.SINK, None, (total,))
    self._assert_single_const(graph, 3.0)

  def test_where_same_fold(self):
    # WHERE whose two value operands are the same CONST(1.0): the expected
    # result is that constant alone, even though the condition is a compare
    # against a runtime variable.
    graph = UOpGraph()
    var = graph.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
    zero = graph.add(UOps.CONST, dtypes.int, arg=0)
    cond = graph.add(UOps.ALU, dtypes.bool, (var, zero), BinaryOps.CMPEQ)
    one = graph.add(UOps.CONST, dtypes.float, arg=1.0)
    picked = graph.add(UOps.ALU, dtypes.float, (cond, one, one), TernaryOps.WHERE)
    graph.add(UOps.SINK, None, (picked,))
    self._assert_single_const(graph, 1.0)

  def test_where_const_fold(self):
    # WHERE with a CONST(False) condition: the expected result is the third
    # operand, CONST(2.0).
    graph = UOpGraph()
    false_cond = graph.add(UOps.CONST, dtypes.bool, arg=False)
    one = graph.add(UOps.CONST, dtypes.float, arg=1.0)
    two = graph.add(UOps.CONST, dtypes.float, arg=2.0)
    picked = graph.add(UOps.ALU, dtypes.float, (false_cond, one, two), TernaryOps.WHERE)
    graph.add(UOps.SINK, None, (picked,))
    self._assert_single_const(graph, 2.0)

  def test_const_cast(self):
    # CAST of CONST(False) to int is expected to come out as a single CONST(0).
    graph = UOpGraph()
    false_const = graph.add(UOps.CONST, dtypes.bool, arg=False)
    as_int = graph.add(UOps.CAST, dtypes.int, (false_const,))
    graph.add(UOps.SINK, None, (as_int,))
    self._assert_single_const(graph, 0)

  def test_depth_2_const_fold(self):
    # (var + 2) + 4: three uops are expected, the last being an ADD whose
    # second input is the merged CONST(6).
    graph = UOpGraph()
    var = graph.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
    two = graph.add(UOps.CONST, dtypes.int, arg=2)
    four = graph.add(UOps.CONST, dtypes.int, arg=4)
    inner = graph.add(UOps.ALU, dtypes.int, (var, two), BinaryOps.ADD)
    outer = graph.add(UOps.ALU, dtypes.int, (inner, four), BinaryOps.ADD)
    graph.add(UOps.SINK, None, (outer,))
    self.assertEqual(len(graph.uops), 3)
    last = graph.uops[-1]
    self.assertEqual(last.uop, UOps.ALU)
    self.assertEqual(last.arg, BinaryOps.ADD)
    folded = last.vin[1]
    self.assertEqual(folded.uop, UOps.CONST)
    self.assertEqual(folded.arg, 6)
|
|
|
|
if __name__ == '__main__':
  # Script entry point: run this module's tests with per-test (verbose) output.
  unittest.main(verbosity=2)
|