add test case for uop cast constant fold (#3746)

and a expected failed bitcast fold test case. Will fix with UOps.BITCAST refactor
This commit is contained in:
chenyu
2024-03-14 20:00:27 -04:00
committed by GitHub
parent 11c61ae044
commit 9a00a453c7

View File

@@ -1,10 +1,11 @@
from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.device import Buffer, Device, CompiledASTRunner
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner
from tinygrad.realize import create_schedule
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.codegen.uops import exec_alu, UOpGraph
from test.test_dtype import is_dtype_supported
@@ -176,5 +177,24 @@ class TestExecALU(TestUOps):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
si = create_schedule([t.lazydata])
assert len(si) == 1
si = si[0]
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
@unittest.expectedFailure
def test_bitcast_const(self):
# TODO: fix bitcast const should not fold
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
si = create_schedule([t.lazydata])
assert len(si) == 1
si = si[0]
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert any(uop.uop is UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
if __name__ == '__main__':
unittest.main(verbosity=2)