fix up eye and fix gc test

This commit is contained in:
George Hotz
2023-02-27 06:53:18 -08:00
parent 686a74de92
commit f10ccf7ec1
2 changed files with 9 additions and 8 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
import gc
import unittest
import numpy as np
from tinygrad.tensor import Tensor
def tensors_allocated():
@@ -17,14 +18,14 @@ class TestGC(unittest.TestCase):
assert(tensors_allocated() == 0)
def test_gc_complex(self):
a = Tensor.zeros(4, 4, requires_grad=True)
b = Tensor.zeros(4, 4, requires_grad=True)
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
assert(tensors_allocated() == 2)
(a*b).mean().backward()
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)
b = Tensor.zeros(4, 4, requires_grad=True)
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())

View File

@@ -126,7 +126,7 @@ class Tensor:
def empty(cls, *shape, **kwargs): return cls.zeros(*shape, **kwargs)
@classmethod
def eye(cls, dim, **kwargs): return cls(np.eye(dim, dtype=np.float32), **kwargs)
def eye(cls, dim, **kwargs): return cls([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
# TODO: requires cumsum to remove numpy
@classmethod
@@ -222,7 +222,7 @@ class Tensor:
new_shape.append(1 if s is None else slcfix(s.stop, sz, sz) - slcfix(s.start, sz, 0))
new_shape += [self.shape[i] for i in range(len(new_slice), len(self.shape))]
new_slice += [(0,self.shape[i]) for i in range(len(new_slice), len(self.shape))]
return self.slice(arg = new_slice).reshape(new_shape if len(new_shape) else (1,))
return self.slice(new_slice).reshape(new_shape if len(new_shape) else (1,))
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -233,21 +233,21 @@ class Tensor:
slc = [[(0, s) for s in self.shape] for _ in catargs]
for s,k in zip(slc, shape_cumsum):
s[dim] = (-k, shape_cumsum[-1]-k)
return functools.reduce(Tensor.__add__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
return functools.reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):
slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(arg=p) for p in slice_params]
return [self.slice(p) for p in slice_params]
def unsqueeze(self, dim):
if dim < 0: dim = len(self.shape) + dim + 1
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self.slice(arg = [(0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])])
def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
# TODO: this is totally not transpose
def transpose(self, order=(1,0)): return self.permute(order=order)
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))