unify torch Slice

This commit is contained in:
George Hotz
2021-11-30 16:25:37 -05:00
parent e59381d0da
commit 785fe8ead7
2 changed files with 7 additions and 23 deletions

View File

@@ -16,6 +16,8 @@ class CPUBuffer(np.ndarray):
return x.transpose(order)
def type(x, tt):
return x.astype(tt)
def custompad(x, padding):
return np.pad(x, padding)
def toCPU(x):
return x
@staticmethod
@@ -143,7 +145,7 @@ class Transpose(Function):
def inner_slice(x, arg):
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
x = np.pad(x, padding)
x = x.custompad(padding)
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x[tuple([slice(x[0], x[1], None) for x in slicee])]

View File

@@ -3,6 +3,8 @@ import numpy as np
from ..tensor import Function
class TorchBuffer(torch.Tensor):
def custompad(x, padding):
return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
@staticmethod
def fromCPU(data):
return TorchBuffer(torch.from_numpy(data).requires_grad_(False))
@@ -11,29 +13,9 @@ class TorchBuffer(torch.Tensor):
def getdtype(self):
return np.float32
# ************* unary+binary+reduce ops *************
# ************* unary+binary+reduce+movement ops *************
from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
# ************* movement ops *************
from tinygrad.ops.ops_cpu import Reshape, Transpose
def inner_slice(x, arg):
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
x = torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x[tuple([slice(x[0], x[1], None) for x in slicee])]
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape)
return inner_slice(x, arg)
def backward(ctx, grad_output):
shape, = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
return inner_slice(grad_output, narg)
from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max, Reshape, Transpose, Slice
# ************* processing ops *************