mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 07:05:04 -05:00
Cannot really catch a spec-change error without testing the new spec explicitly, but we don't intend to change the lazy spec lightly. Another possible way to catch a wrong reduce flop-counter shape would be to type-check InterpretedFlopCounter and throw an error if `in` results in `Never`.
68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
#!/usr/bin/env python
|
|
import unittest
|
|
from tinygrad import dtypes
|
|
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
|
|
class TestFlopCounter(unittest.TestCase):
  """Tests for the flop count (and, for reductions, the output shape) reported by get_lazyop_info.

  Each test builds a small LazyOp graph by hand from BufferOps.LOAD leaves and
  compares `info.flops` / `info.shape` against a value worked out by hand in the
  test's docstring.
  """

  @staticmethod
  def _load(idx, shape):
    """Return a float32 BufferOps.LOAD leaf reading buffer `idx` through a ShapeTracker of `shape`."""
    return LazyOp(BufferOps.LOAD, (), MemBuffer(idx, dtypes.float32, ShapeTracker.from_shape(shape)))

  def setUp(self):
    # Two 4-element inputs for the elementwise / 1D reduction tests.
    self.buf0 = self._load(1, (4,))
    self.buf1 = self._load(2, (4,))
    # A 4x4 input for the 2D reduction test. It uses its own buffer index so the
    # same index is never loaded with two incompatible shapes.
    self.buf2 = self._load(3, (4, 4))

  def test_flops_sin(self):
    """A unary op over 4 elements is expected to cost 4 flops."""
    op0 = LazyOp(UnaryOps.SIN, (self.buf0,), None)
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)

  def test_flops_add(self):
    """A binary op over two 4-element inputs is expected to cost 4 flops."""
    op0 = LazyOp(BinaryOps.ADD, (self.buf0, self.buf1), None)
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)

  def test_flops_add_twice(self):
    """Two chained adds: 4 (op0) + 4 (op1) = 8 flops."""
    op0 = LazyOp(BinaryOps.ADD, (self.buf0, self.buf1), None)
    op1 = LazyOp(BinaryOps.ADD, (op0, self.buf1), None)
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 8)

  def test_flops_add_self(self):
    """op0 feeds both operands of op1 but must be counted once: 4 + 4 = 8 flops."""
    op0 = LazyOp(BinaryOps.ADD, (self.buf0, self.buf1), None)
    op1 = LazyOp(BinaryOps.ADD, (op0, op0), None)
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 8)

  def test_flops_add_roundabout_self(self):
    """op0 is reached directly and through op1 but counted once: 4 + 4 + 4 = 12 flops."""
    op0 = LazyOp(BinaryOps.ADD, (self.buf0, self.buf1), None)
    op1 = LazyOp(BinaryOps.ADD, (op0, self.buf1), None)
    op2 = LazyOp(BinaryOps.ADD, (op0, op1), None)
    info = get_lazyop_info(op2)
    self.assertEqual(info.flops, 12)

  def test_flops_red(self):
    """MUL over 4 elements (4) + SUM of those 4 (4) + ADD on the 1-element reduced result (1) = 9 flops.

    The ADD is only charged 1 flop because it runs on the reduced shape, not the
    original 4-element shape; op1 is shared by both operands and counted once.
    """
    op0 = LazyOp(BinaryOps.MUL, (self.buf0, self.buf1), None)
    op1 = LazyOp(ReduceOps.SUM, (op0,), (0,))
    op2 = LazyOp(BinaryOps.ADD, (op1, op1), None)
    info = get_lazyop_info(op2)
    self.assertEqual(info.flops, 9)

  def test_flops_sum1d(self):
    """Summing a (4,) input over axis 0 costs 4 flops and keeps the axis with size 1."""
    op0 = LazyOp(ReduceOps.SUM, (self.buf0,), (0,))
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)
    self.assertEqual(info.shape, (1,))

  def test_flops_sum2d(self):
    """Reduce (4,4) over axis 0 (16 flops -> (1,4)), then over axis 1 (+4 flops -> (1,1))."""
    op0 = LazyOp(ReduceOps.SUM, (self.buf2,), (0,))
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 16)
    self.assertEqual(info.shape, (1, 4))

    op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 16+4)
    self.assertEqual(info.shape, (1, 1))
|
|
if __name__ == '__main__':
  # Running this file directly discovers and runs the TestCase classes defined above.
  unittest.main()
|