mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add Tensor.replace (#3738)
* add Tensor.replace * fix dtypes in that test * should be replace * and mixtral
This commit is contained in:
@@ -141,7 +141,7 @@ class GPT2:
|
||||
|
||||
if HALF:
|
||||
for l in get_state_dict(model).values():
|
||||
l.assign(l.half().realize())
|
||||
l.replace(l.half().realize())
|
||||
|
||||
return GPT2(model, tokenizer)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ if __name__ == "__main__":
|
||||
device = Device.DEFAULT
|
||||
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
|
||||
# NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
|
||||
model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
|
||||
model_state_dict[k].replace(state[k].to("CLANG").contiguous().to(device).half()).realize()
|
||||
if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
@@ -602,7 +602,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.fp16:
|
||||
for l in get_state_dict(model).values():
|
||||
l.assign(l.cast(dtypes.float16).realize())
|
||||
l.replace(l.cast(dtypes.float16).realize())
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = ClipTokenizer()
|
||||
|
||||
@@ -114,6 +114,7 @@ class TestAssign(unittest.TestCase):
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
|
||||
def test_cast_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
a.realize()
|
||||
|
||||
@@ -237,9 +237,9 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_equal([0], cache.good_cache.numpy())
|
||||
np.testing.assert_equal([0], cache.bad_cache.numpy())
|
||||
|
||||
zero = Tensor([0])
|
||||
one = Tensor([1])
|
||||
two = Tensor([2])
|
||||
zero = Tensor([0.])
|
||||
one = Tensor([1.])
|
||||
two = Tensor([2.])
|
||||
|
||||
# save [1] in the caches
|
||||
cache.good(zero, one)
|
||||
@@ -248,7 +248,7 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_equal([1], cache.bad_cache.numpy())
|
||||
|
||||
for i in range(5):
|
||||
x = Tensor([i]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
|
||||
x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
|
||||
cache.good_jitted(x)
|
||||
cache.bad_jitted(x)
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
z = layer(x)
|
||||
|
||||
layer_sharded = nn.Embedding(vocab_size, embed_size)
|
||||
layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
|
||||
layer_sharded.weight.replace(layer.weight.shard((d0, d1), axis=1)).realize()
|
||||
x_sharded = x.shard((d0, d1), axis=None)
|
||||
z_shard = layer_sharded(x_sharded)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
|
||||
class TestLrScheduler(unittest.TestCase):
|
||||
def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
|
||||
accs = opts.pop('accs', None)
|
||||
test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
|
||||
test_tensor = Tensor([0.], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
|
||||
test_tensor.mean().backward()
|
||||
if adam:
|
||||
tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
|
||||
|
||||
@@ -68,7 +68,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
|
||||
if k not in state_dict and not strict:
|
||||
if DEBUG >= 1: print(f"WARNING: not loading {k}")
|
||||
continue
|
||||
v.assign(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
|
||||
v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
|
||||
if consume: del state_dict[k]
|
||||
|
||||
# torch support!
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
|
||||
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
|
||||
from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import LoadOps
|
||||
@@ -137,6 +137,13 @@ class Tensor:
|
||||
Tensor.corealize([self])
|
||||
return self
|
||||
|
||||
def replace(self, x:Tensor) -> Tensor:
|
||||
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
|
||||
assert not x.requires_grad and getattr(self, '_ctx', None) is None
|
||||
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
|
||||
self.lazydata = x.lazydata
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
@@ -148,13 +155,14 @@ class Tensor:
|
||||
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
|
||||
# NOTE: we allow cross device assign
|
||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
|
||||
if isinstance(self.lazydata, MultiLazyBuffer):
|
||||
for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
|
||||
else:
|
||||
if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
|
||||
if isinstance(self.lazydata, MultiLazyBuffer):
|
||||
for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
|
||||
else:
|
||||
if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
|
||||
self.lazydata = x.lazydata
|
||||
return self
|
||||
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user