mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
use Tensor.replace [pr] (#8455)
This commit is contained in:
@@ -15,7 +15,7 @@ from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCom
|
||||
|
||||
def derandomize_model(model):
|
||||
for p in get_parameters(model):
|
||||
p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
|
||||
p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
|
||||
p.realize()
|
||||
|
||||
def assert_jit_cache_len(fxn, expected_len):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
|
||||
|
||||
import unittest, random, copy, warnings
|
||||
import unittest, random, warnings
|
||||
import numpy as np
|
||||
|
||||
from tinygrad import Tensor, dtypes, Device, TinyJit
|
||||
@@ -27,8 +27,8 @@ def set_(reference: Tensor, shape, strides, offset):
|
||||
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
|
||||
return strided
|
||||
|
||||
def clone(original:Tensor): return copy.copy(original)
|
||||
def copy_(src:Tensor, other:Tensor) -> Tensor: return copy.copy(src)
|
||||
def clone(original:Tensor): return original.clone()
|
||||
def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone()
|
||||
# this is fine for tested usecases since as geohotstan understands,
|
||||
# data_ptr is used to compare if operations needed between tensors is the same
|
||||
def data_ptr(tensor:Tensor): return tensor.lazydata
|
||||
|
||||
@@ -353,8 +353,7 @@ class Tensor(SimpleMathTrait):
|
||||
Moves the tensor to the given device in place.
|
||||
"""
|
||||
real = self.to(device)
|
||||
# TODO: is this assign?
|
||||
if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
|
||||
if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
|
||||
return self.replace(real)
|
||||
|
||||
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user