mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix const bitcast should not be constant folded (#3743)
* fix const bitcast should not be constant folded * fixed const bf16 creation * LLVM still broken
This commit is contained in:
@@ -137,16 +137,14 @@ class TestBFloat16(unittest.TestCase):
|
||||
assert tnp.dtype == np.float32
|
||||
np.testing.assert_allclose(tnp, np.array(data))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM", "broken for LLVM")
|
||||
def test_bf16_ones(self):
|
||||
# TODO: fix this with correct bfloat16 cast
|
||||
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM", "broken for LLVM")
|
||||
def test_bf16_eye(self):
|
||||
# TODO: fix this with correct bfloat16 cast
|
||||
t = Tensor.eye(3, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.eye(3))
|
||||
@@ -256,6 +254,10 @@ class TestBitCast(unittest.TestCase):
|
||||
b = a.bitcast(dtypes.int32)
|
||||
assert b.numpy()[0] == 0x3f800000
|
||||
|
||||
def test_bitcast_const(self):
|
||||
a = Tensor(1, dtype=dtypes.float32).bitcast(dtypes.uint32)
|
||||
np.testing.assert_equal(a.numpy(), np.array(1, dtype=np.float32).view(np.uint32))
|
||||
|
||||
def test_bitcast_upcasted(self):
|
||||
a = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
|
||||
b = a.bitcast(dtypes.float32)
|
||||
|
||||
Reference in New Issue
Block a user