mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
patch to remove hack from stable_diffusion.py (#1814)
* patch to remove hack from stable_diffusion.py * sorry linter * realize after assign? * float16 broken in llvmlite use float64 for now * int32 * idiot forgot to change test array dtype
This commit is contained in:
@@ -579,9 +579,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.fp16:
|
||||
for l in get_state_dict(model).values():
|
||||
fp16_buf = l.cast(dtypes.float16).realize()
|
||||
l.lazydata.realized = None # TODO: why is this needed in order to trigger the free?
|
||||
l.assign(fp16_buf)
|
||||
l.assign(l.cast(dtypes.float16).realize())
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = ClipTokenizer()
|
||||
|
||||
@@ -5,6 +5,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LAZY
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.graph import nm
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
@@ -63,5 +64,15 @@ class TestAssign(unittest.TestCase):
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
def test_cast_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
a.realize()
|
||||
oba1 = a.lazydata.output_buffer
|
||||
a.assign(a.cast(dtypes.int32).realize())
|
||||
a.realize()
|
||||
oba2 = a.lazydata.output_buffer
|
||||
assert oba1 is None and oba2 is None
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -96,7 +96,7 @@ class Tensor:
|
||||
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
||||
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
|
||||
if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
|
||||
self.lazydata = x.lazydata
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user