Files
tinygrad/test/unit/test_flopcounter.py
2023-05-24 23:17:57 -07:00

50 lines
1.7 KiB
Python

#!/usr/bin/env python
import unittest
from typing import NamedTuple, Tuple
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info
from tinygrad.helpers import DType, dtypes
class TestBuffer(NamedTuple):
__test__ = False # To prevent pytest from collecting this as a test
shape: Tuple[int, ...]
dtype: DType
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)
def test_flops_add_twice(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_self(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_roundabout_self(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 12)
def test_flops_red(self):
op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 9)
if __name__ == '__main__':
unittest.main()