mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
patch to remove hack from stable_diffusion.py (#1814)
* patch to remove hack from stable_diffusion.py * sorry linter * realize after assign? * float16 broken in llvmlite use float64 for now * int32 * idiot forgot to change test array dtype
This commit is contained in:
@@ -5,6 +5,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LAZY
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.graph import nm
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
@@ -63,5 +64,15 @@ class TestAssign(unittest.TestCase):
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
def test_cast_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
a.realize()
|
||||
oba1 = a.lazydata.output_buffer
|
||||
a.assign(a.cast(dtypes.int32).realize())
|
||||
a.realize()
|
||||
oba2 = a.lazydata.output_buffer
|
||||
assert oba1 is None and oba2 is None
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user