add Tensor.replace (#3738)

* add Tensor.replace

* fix dtypes in that test

* should be replace

* and mixtral
This commit is contained in:
George Hotz
2024-03-14 13:34:14 -07:00
committed by GitHub
parent 0ead0bdb65
commit 3527c5a9d2
9 changed files with 25 additions and 16 deletions

View File

@@ -141,7 +141,7 @@ class GPT2:
if HALF:
for l in get_state_dict(model).values():
l.assign(l.half().realize())
l.replace(l.half().realize())
return GPT2(model, tokenizer)

View File

@@ -43,7 +43,7 @@ if __name__ == "__main__":
device = Device.DEFAULT
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
# NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
model_state_dict[k].replace(state[k].to("CLANG").contiguous().to(device).half()).realize()
if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
from sentencepiece import SentencePieceProcessor

View File

@@ -602,7 +602,7 @@ if __name__ == "__main__":
if args.fp16:
for l in get_state_dict(model).values():
l.assign(l.cast(dtypes.float16).realize())
l.replace(l.cast(dtypes.float16).realize())
# run through CLIP to get context
tokenizer = ClipTokenizer()

View File

@@ -114,6 +114,7 @@ class TestAssign(unittest.TestCase):
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()

View File

@@ -237,9 +237,9 @@ class TestJit(unittest.TestCase):
np.testing.assert_equal([0], cache.good_cache.numpy())
np.testing.assert_equal([0], cache.bad_cache.numpy())
zero = Tensor([0])
one = Tensor([1])
two = Tensor([2])
zero = Tensor([0.])
one = Tensor([1.])
two = Tensor([2.])
# save [1] in the caches
cache.good(zero, one)
@@ -248,7 +248,7 @@ class TestJit(unittest.TestCase):
np.testing.assert_equal([1], cache.bad_cache.numpy())
for i in range(5):
x = Tensor([i]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
cache.good_jitted(x)
cache.bad_jitted(x)

View File

@@ -166,7 +166,7 @@ class TestMultiTensor(unittest.TestCase):
z = layer(x)
layer_sharded = nn.Embedding(vocab_size, embed_size)
layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
layer_sharded.weight.replace(layer.weight.shard((d0, d1), axis=1)).realize()
x_sharded = x.shard((d0, d1), axis=None)
z_shard = layer_sharded(x_sharded)

View File

@@ -55,7 +55,7 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
class TestLrScheduler(unittest.TestCase):
def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
accs = opts.pop('accs', None)
test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
test_tensor = Tensor([0.], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
test_tensor.mean().backward()
if adam:
tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)

View File

@@ -68,7 +68,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
if k not in state_dict and not strict:
if DEBUG >= 1: print(f"WARNING: not loading {k}")
continue
v.assign(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]
# torch support!

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
@@ -137,6 +137,13 @@ class Tensor:
Tensor.corealize([self])
return self
def replace(self, x:Tensor) -> Tensor:
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert not x.requires_grad and getattr(self, '_ctx', None) is None
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
self.lazydata = x.lazydata
return self
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
@@ -148,13 +155,14 @@ class Tensor:
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
assert not x.requires_grad # self requires_grad is okay?
if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
if isinstance(self.lazydata, MultiLazyBuffer):
for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
else:
if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
if isinstance(self.lazydata, MultiLazyBuffer):
for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
else:
if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
self.lazydata = x.lazydata
return self
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)