mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add test case for uop cast constant fold (#3746)
and a expected failed bitcast fold test case. Will fix with UOps.BITCAST refactor
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from typing import Optional, Tuple, Any, List
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.device import Buffer, Device, CompiledASTRunner
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.device import CompiledASTRunner
|
||||
from tinygrad.realize import create_schedule
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.codegen.uops import exec_alu, UOpGraph
|
||||
from test.test_dtype import is_dtype_supported
|
||||
@@ -176,5 +177,24 @@ class TestExecALU(TestUOps):
|
||||
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
|
||||
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
|
||||
|
||||
class TestConstantFolding(unittest.TestCase):
|
||||
def test_cast_const(self):
|
||||
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
|
||||
si = create_schedule([t.lazydata])
|
||||
assert len(si) == 1
|
||||
si = si[0]
|
||||
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
|
||||
assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_bitcast_const(self):
|
||||
# TODO: fix bitcast const should not fold
|
||||
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
|
||||
si = create_schedule([t.lazydata])
|
||||
assert len(si) == 1
|
||||
si = si[0]
|
||||
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
|
||||
assert any(uop.uop is UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user