ban __bool__ on Tensor (#3632)

* ban __bool__ on Tensor

avoid misuse

* test case

* fix tests

* fix more tests
This commit is contained in:
chenyu
2024-03-06 17:12:35 -05:00
committed by GitHub
parent 81baf3eed3
commit 8f10bfa2ff
7 changed files with 19 additions and 7 deletions

View File

@@ -124,7 +124,7 @@ class TestSafetensors(unittest.TestCase):
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor.rand((10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)
assert ones == list(safe_load(path).values())[0]
np.testing.assert_equal(ones.numpy(), list(safe_load(path).values())[0].numpy())
def test_load_supported_types(self):
import torch