Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
add Tensor.replace (#3738)

* add Tensor.replace
* fix dtypes in that test
* should be replace
* and mixtral
@@ -141,7 +141,7 @@ class GPT2:

     if HALF:
       for l in get_state_dict(model).values():
-        l.assign(l.half().realize())
+        l.replace(l.half().realize())

     return GPT2(model, tokenizer)

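Side note (not part of the commit): a minimal sketch of why this call site switches to replace. half() produces a tensor with a different dtype, and assign (tightened later in this diff) asserts dtype equality, so the parameter has to adopt the new lazybuffer via replace. SmallNet and net are illustrative names, not from the repo.

from tinygrad import Tensor, nn, dtypes
from tinygrad.nn.state import get_state_dict

class SmallNet:  # hypothetical model, stands in for GPT2 above
  def __init__(self): self.l1 = nn.Linear(4, 4)
  def __call__(self, x: Tensor) -> Tensor: return self.l1(x)

net = SmallNet()
for p in get_state_dict(net).values():
  p.replace(p.half().realize())   # half() changes the dtype, so replace (not assign) must adopt it
assert all(p.dtype == dtypes.float16 for p in get_state_dict(net).values())
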
@@ -43,7 +43,7 @@ if __name__ == "__main__":
       device = Device.DEFAULT
     t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
     # NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
-    model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
+    model_state_dict[k].replace(state[k].to("CLANG").contiguous().to(device).half()).realize()
   if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")

   from sentencepiece import SentencePieceProcessor
@@ -602,7 +602,7 @@ if __name__ == "__main__":

   if args.fp16:
     for l in get_state_dict(model).values():
-      l.assign(l.cast(dtypes.float16).realize())
+      l.replace(l.cast(dtypes.float16).realize())

   # run through CLIP to get context
   tokenizer = ClipTokenizer()
@@ -114,6 +114,7 @@ class TestAssign(unittest.TestCase):

   # TODO: is there a way to sneak in a permute such that it returns the wrong answer?

+  @unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
   def test_cast_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
     a.realize()
@@ -237,9 +237,9 @@ class TestJit(unittest.TestCase):
     np.testing.assert_equal([0], cache.good_cache.numpy())
     np.testing.assert_equal([0], cache.bad_cache.numpy())

-    zero = Tensor([0])
-    one = Tensor([1])
-    two = Tensor([2])
+    zero = Tensor([0.])
+    one = Tensor([1.])
+    two = Tensor([2.])

     # save [1] in the caches
     cache.good(zero, one)
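Side note (not part of the commit): this is the "fix dtypes in that test" item from the message. A minimal sketch, assuming tinygrad's default dtypes at this point (float32 for float literals, an integer dtype such as int32 for int literals); the cache tensor below is illustrative, not the test's actual class.

from tinygrad import Tensor, dtypes

assert Tensor([0.]).dtype == dtypes.float32   # float literal -> default float dtype
assert Tensor([0]).dtype != dtypes.float32    # int literal -> a default int dtype (assumed int32)

# the JIT test's caches hold float tensors; feeding them int-dtype inputs would now trip
# assign's dtype assert, hence Tensor([i*1.]) instead of Tensor([i]) in the next hunk.
cache = Tensor([0.]).contiguous().realize()
cache.assign(Tensor([1.])).realize()          # ok: same shape, dtype, and device
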
@@ -248,7 +248,7 @@ class TestJit(unittest.TestCase):
     np.testing.assert_equal([1], cache.bad_cache.numpy())

     for i in range(5):
-      x = Tensor([i]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
+      x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
       cache.good_jitted(x)
       cache.bad_jitted(x)

@@ -166,7 +166,7 @@ class TestMultiTensor(unittest.TestCase):
     z = layer(x)

     layer_sharded = nn.Embedding(vocab_size, embed_size)
-    layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
+    layer_sharded.weight.replace(layer.weight.shard((d0, d1), axis=1)).realize()
     x_sharded = x.shard((d0, d1), axis=None)
     z_shard = layer_sharded(x_sharded)

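Side note (not part of the commit): here the weight adopts a multi-device sharded copy, so its device changes from a single device to a tuple; assign now asserts device equality, which is why this call becomes replace. A minimal sketch under the assumption that two devices (d0, d1) are usable; the sizes are arbitrary.

from tinygrad import Tensor, nn, Device

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"   # assumed two-device setup
layer = nn.Embedding(10, 8)
# shard() keeps the logical shape but spreads the data over (d0, d1), so the device differs
# from the original single-device weight; replace() adopts it, assign() would refuse.
layer.weight.replace(layer.weight.shard((d0, d1), axis=1)).realize()
print(layer.weight.device)   # a tuple of the two devices
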
@@ -55,7 +55,7 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
 class TestLrScheduler(unittest.TestCase):
   def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
     accs = opts.pop('accs', None)
-    test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
+    test_tensor = Tensor([0.], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
     test_tensor.mean().backward()
     if adam:
       tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
@@ -68,7 +68,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      v.assign(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
+      v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]

 # torch support!
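Side note (not part of the commit): with replace, load_state_dict simply rebinds each parameter to the (possibly moved or sharded) checkpoint tensor instead of scheduling a copy into the old buffer. A minimal caller-side sketch; the checkpoint file name is hypothetical.

from tinygrad import nn
from tinygrad.nn.state import safe_load, load_state_dict

model = nn.Linear(8, 8)
state_dict = safe_load("model.safetensors")   # hypothetical path; tensors may start on DISK
load_state_dict(model, state_dict)            # each param v now does v.replace(...).realize()
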
@@ -7,7 +7,7 @@ from collections import defaultdict
 import numpy as np

 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
-from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
+from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
@@ -137,6 +137,13 @@ class Tensor:
     Tensor.corealize([self])
     return self

+  def replace(self, x:Tensor) -> Tensor:
+    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
+    assert not x.requires_grad and getattr(self, '_ctx', None) is None
+    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
+    self.lazydata = x.lazydata
+    return self
+
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
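Side note (not part of the commit): the new method's contract, restated as a small usage sketch. Shapes must match and neither tensor may be tied to autograd, but dtype and device may differ; only the lazydata pointer moves, no copy is scheduled into the old buffer.

from tinygrad import Tensor, dtypes

t = Tensor.ones(2, 2).realize()                   # float32 on the default device
t.replace(Tensor.ones(2, 2).cast(dtypes.int32))   # ok: same shape, different dtype
assert t.dtype == dtypes.int32
# t.replace(Tensor.ones(3, 3))                    # would fail: "replace shape mismatch"
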
@@ -148,13 +155,14 @@ class Tensor:
     if self.lazydata is x.lazydata: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
+    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad # self requires_grad is okay?
-    if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
-      if isinstance(self.lazydata, MultiLazyBuffer):
-        for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
-      else:
-        if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
+    if isinstance(self.lazydata, MultiLazyBuffer):
+      for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
+    else:
+      if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
     self.lazydata = x.lazydata
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
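Side note (not part of the commit): assign drops the DISALLOW_ASSIGN escape hatch and the dtype guard around the output-buffer path; it now asserts device and dtype up front and always tries to reuse the realized buffer as the output buffer. In short, assign stays the in-place update and replace is the rebinding operation. A minimal sketch contrasting the two:

from tinygrad import Tensor

w = Tensor.zeros(4).contiguous().realize()
w.assign(w + 1).realize()               # in-place style update: same shape/dtype/device, reuses the realized buffer
w.replace((w * 2).half().realize())     # rebinding: the dtype changes, which assign would now reject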