patch to remove hack from stable_diffusion.py (#1814)

* patch to remove hack from stable_diffusion.py

* sorry linter

* realize after assign?

* float16 broken in llvmlite use float64 for now

* int32

* idiot forgot to change test array dtype
This commit is contained in:
segf00lt
2023-09-08 13:26:50 -03:00
committed by GitHub
parent ebcda8a714
commit 9e8c1dbf34
3 changed files with 13 additions and 4 deletions

View File

@@ -579,9 +579,7 @@ if __name__ == "__main__":
if args.fp16:
for l in get_state_dict(model).values():
fp16_buf = l.cast(dtypes.float16).realize()
l.lazydata.realized = None # TODO: why is this needed in order to trigger the free?
l.assign(fp16_buf)
l.assign(l.cast(dtypes.float16).realize())
# run through CLIP to get context
tokenizer = ClipTokenizer()

View File

@@ -5,6 +5,7 @@ from tinygrad.tensor import Tensor
from tinygrad.lazy import LAZY
from tinygrad.ops import GlobalCounters
from tinygrad.graph import nm
from tinygrad.helpers import dtypes
N = 200 # has to be bigger than the cache to fail
@@ -63,5 +64,15 @@ class TestAssign(unittest.TestCase):
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()
oba1 = a.lazydata.output_buffer
a.assign(a.cast(dtypes.int32).realize())
a.realize()
oba2 = a.lazydata.output_buffer
assert oba1 is None and oba2 is None
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
if __name__ == "__main__":
unittest.main()

View File

@@ -96,7 +96,7 @@ class Tensor:
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
self.lazydata = x.lazydata
return self