add Tensor.replace (#3738)

* add Tensor.replace

* fix dtypes in that test

* should be replace

* and mixtral
This commit is contained in:
George Hotz
2024-03-14 13:34:14 -07:00
committed by GitHub
parent 0ead0bdb65
commit 3527c5a9d2
9 changed files with 25 additions and 16 deletions

View File

@@ -114,6 +114,7 @@ class TestAssign(unittest.TestCase):
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()