From 3527c5a9d2e6de99ccba6d954472d5204ab2998a Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 14 Mar 2024 13:34:14 -0700
Subject: [PATCH] add Tensor.replace (#3738)

* add Tensor.replace

* fix dtypes in that test

* should be replace

* and mixtral
---
 examples/gpt2.py                    |  2 +-
 examples/mixtral.py                 |  2 +-
 examples/stable_diffusion.py        |  2 +-
 test/test_assign.py                 |  1 +
 test/test_jit.py                    |  8 ++++----
 test/test_multitensor.py            |  2 +-
 test/testextra/test_lr_scheduler.py |  2 +-
 tinygrad/nn/state.py                |  2 +-
 tinygrad/tensor.py                  | 20 ++++++++++++++------
 9 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/examples/gpt2.py b/examples/gpt2.py
index 9c6c9b7f62..7255bcf535 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -141,7 +141,7 @@ class GPT2:
 
     if HALF:
       for l in get_state_dict(model).values():
-        l.assign(l.half().realize())
+        l.replace(l.half().realize())
 
     return GPT2(model, tokenizer)
 
diff --git a/examples/mixtral.py b/examples/mixtral.py
index be1d203fe6..7bb5314abf 100644
--- a/examples/mixtral.py
+++ b/examples/mixtral.py
@@ -43,7 +43,7 @@ if __name__ == "__main__":
       device = Device.DEFAULT
     t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
     # NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
-    model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
+    model_state_dict[k].replace(state[k].to("CLANG").contiguous().to(device).half()).realize()
   if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
 
   from sentencepiece import SentencePieceProcessor
diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 0edf43489d..96b9c744ed 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -602,7 +602,7 @@ if __name__ == "__main__":
 
   if args.fp16:
     for l in get_state_dict(model).values():
-      l.assign(l.cast(dtypes.float16).realize())
+      l.replace(l.cast(dtypes.float16).realize())
 
   # run through CLIP to get context
   tokenizer = ClipTokenizer()
diff --git a/test/test_assign.py b/test/test_assign.py
index 24db950a13..a84ff289b9 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -114,6 +114,7 @@ class TestAssign(unittest.TestCase):
 
   # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
 
+  @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
   def test_cast_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
     a.realize()
diff --git a/test/test_jit.py b/test/test_jit.py
index 1c8d2de15c..1aaa711f70 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -237,9 +237,9 @@ class TestJit(unittest.TestCase):
     np.testing.assert_equal([0], cache.good_cache.numpy())
     np.testing.assert_equal([0], cache.bad_cache.numpy())
 
-    zero = Tensor([0])
-    one = Tensor([1])
-    two = Tensor([2])
+    zero = Tensor([0.])
+    one = Tensor([1.])
+    two = Tensor([2.])
 
     # save [1] in the caches
     cache.good(zero, one)
@@ -248,7 +248,7 @@ class TestJit(unittest.TestCase):
     np.testing.assert_equal([1], cache.bad_cache.numpy())
 
     for i in range(5):
-      x = Tensor([i]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
+      x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
       cache.good_jitted(x)
       cache.bad_jitted(x)
 
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 754eba92ee..093d3c8237 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -166,7 +166,7 @@ class TestMultiTensor(unittest.TestCase):
     z = layer(x)
 
     layer_sharded = nn.Embedding(vocab_size, embed_size)
-    layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
+    layer_sharded.weight.replace(layer.weight.shard((d0, d1), axis=1)).realize()
     x_sharded = x.shard((d0, d1), axis=None)
     z_shard = layer_sharded(x_sharded)
 
diff --git a/test/testextra/test_lr_scheduler.py b/test/testextra/test_lr_scheduler.py
index 24d09989e1..9ff38c10c3 100644
--- a/test/testextra/test_lr_scheduler.py
+++ b/test/testextra/test_lr_scheduler.py
@@ -55,7 +55,7 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
 class TestLrScheduler(unittest.TestCase):
   def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
     accs = opts.pop('accs', None)
-    test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
+    test_tensor = Tensor([0.], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
     test_tensor.mean().backward()
     if adam: tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 604848135c..c8856a3496 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -68,7 +68,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      v.assign(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
+      v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]
 
 # torch support!
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e94a898a3e..35106b96b0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -7,7 +7,7 @@ from collections import defaultdict
 import numpy as np
 
 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
-from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
+from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
@@ -137,6 +137,13 @@ class Tensor:
     Tensor.corealize([self])
     return self
 
+  def replace(self, x:Tensor) -> Tensor:
+    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
+    assert not x.requires_grad and getattr(self, '_ctx', None) is None
+    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
+    self.lazydata = x.lazydata
+    return self
+
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
@@ -148,13 +155,14 @@ class Tensor:
     if self.lazydata is x.lazydata: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
+    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad # self requires_grad is okay?
-    if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
-      if isinstance(self.lazydata, MultiLazyBuffer):
-        for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
-      else:
-        if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
+    if isinstance(self.lazydata, MultiLazyBuffer):
+      for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
+    else:
+      if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
     self.lazydata = x.lazydata
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
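
Usage note (not part of the patch): a minimal sketch of the replace/assign split this change introduces, assuming the tinygrad API as of this commit. The tensor name `w` and the shapes are made up for illustration; the behavior follows the diff above: replace swaps the underlying lazydata and only checks shape, so dtype or device may change (the weight-casting loops in gpt2.py and stable_diffusion.py rely on this), while assign now asserts that shape, dtype, and device all match and reuses the realized buffer as the output.

    # rough sketch of replace vs. assign after this patch (hypothetical tensors)
    from tinygrad.tensor import Tensor
    from tinygrad.dtype import dtypes

    w = Tensor.ones(4, 4).realize()              # a realized float32 "weight"

    # replace: swap in a new version of the tensor; only the shape must match,
    # so the dtype can change, as in the fp16 conversion loops above
    w.replace(w.cast(dtypes.float16).realize())  # w is now float16

    # assign: write new values into the existing tensor; after this patch the
    # shape, dtype, and device must all match, and the already-realized buffer
    # is used as the output buffer instead of allocating a new one
    w.assign(Tensor.ones(4, 4, dtype=dtypes.float16)).realize()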