use Tensor.replace [pr] (#8455)

George Hotz 2024-12-30 23:20:46 -05:00 (committed by GitHub)
parent 19a54ae0b4
commit e276b6eecd
3 changed files with 5 additions and 6 deletions

View File

@@ -15,7 +15,7 @@ from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCom
 def derandomize_model(model):
   for p in get_parameters(model):
-    p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
+    p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
     p.realize()
 def assert_jit_cache_len(fxn, expected_len):
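
For context, a minimal sketch (not part of this diff) of why replace can stand in for the raw .lazydata assignment above: Tensor.replace swaps a tensor's underlying data in place while keeping the Python object identity, so every existing reference to the parameter sees the new contents.

from tinygrad import Tensor

t = Tensor.ones(4, 4)
t.replace(Tensor.zeros(4, 4))  # same Python object, new underlying data
t.realize()                    # t now materializes as zeros; id(t) is unchanged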

View File

@@ -1,6 +1,6 @@
 # test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
-import unittest, random, copy, warnings
+import unittest, random, warnings
 import numpy as np
 from tinygrad import Tensor, dtypes, Device, TinyJit
@@ -27,8 +27,8 @@ def set_(reference: Tensor, shape, strides, offset):
   assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
   return strided
-def clone(original:Tensor): return copy.copy(original)
-def copy_(src:Tensor, other:Tensor) -> Tensor: return copy.copy(src)
+def clone(original:Tensor): return original.clone()
+def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone()
 # this is fine for tested usecases since as geohotstan understands,
 # data_ptr is used to compare if operations needed between tensors is the same
 def data_ptr(tensor:Tensor): return tensor.lazydata
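
These test helpers now map PyTorch's clone/copy_ semantics onto Tensor.clone() instead of a shallow copy.copy. A rough usage sketch of the standard tinygrad API (not code from this diff):

from tinygrad import Tensor

a = Tensor([1, 2, 3])
b = a.clone()  # a new Tensor backed by its own copy of the data
assert b is not a
assert b.tolist() == a.tolist()  # same values, independent object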

View File

@@ -353,8 +353,7 @@ class Tensor(SimpleMathTrait):
     Moves the tensor to the given device in place.
     """
     real = self.to(device)
-    # TODO: is this assign?
-    if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
+    if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
     return self.replace(real)
   def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
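
A hedged usage sketch of the in-place move this hunk touches (device name and shape are illustrative, not from the diff): to_ keeps the tensor's Python identity while rebinding its data, and after this change the gradient follows via replace as well.

from tinygrad import Tensor

x = Tensor.randn(2, 2, requires_grad=True)
x.sum().backward()  # populates x.grad
x.to_("CPU")        # x and x.grad now live on CPU; existing references to x stay valid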