unify torch Slice
This commit adds a backend-specific custompad method to both CPUBuffer and TorchBuffer, so the torch backend can drop its private copies of inner_slice and Slice and import them from ops_cpu.
tinygrad/ops/ops_cpu.py:

@@ -16,6 +16,8 @@ class CPUBuffer(np.ndarray):
     return x.transpose(order)
   def type(x, tt):
     return x.astype(tt)
+  def custompad(x, padding):
+    return np.pad(x, padding)
   def toCPU(x):
     return x
   @staticmethod
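On the CPU side custompad is a direct passthrough: np.pad already takes one (before, after) pair per axis. A minimal sketch of those semantics (plain numpy standing in for CPUBuffer; not part of the commit):

import numpy as np

x = np.ones((2, 2), dtype=np.float32)
# one row of zeros before axis 0, two columns of zeros after axis 1
padded = np.pad(x, [(1, 0), (0, 2)])
print(padded.shape)  # (3, 4)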
@@ -143,7 +145,7 @@ class Transpose(Function):
 
 def inner_slice(x, arg):
   padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
-  x = np.pad(x, padding)
+  x = x.custompad(padding)
   slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
   return x[tuple([slice(x[0], x[1], None) for x in slicee])]
 
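inner_slice reads arg as per-axis (start, end) bounds that may extend past the buffer: the out-of-range portion is zero-padded first, then the slice is shifted into the padded buffer. A worked sketch of the index arithmetic (not in the commit), on a 1-D buffer:

import numpy as np

x = np.array([1., 2., 3., 4.], dtype=np.float32)
arg = [(-1, 3)]  # start one element before the buffer, stop one short of its end
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]     # [(1, 0)]
x = np.pad(x, padding)                                                             # [0. 1. 2. 3. 4.]
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]  # [(0, 4)]
print(x[tuple(slice(s[0], s[1]) for s in slicee)])                                 # [0. 1. 2. 3.]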
tinygrad/ops/ops_torch.py:

@@ -3,6 +3,8 @@ import numpy as np
 from ..tensor import Function
 
 class TorchBuffer(torch.Tensor):
+  def custompad(x, padding):
+    return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
     return TorchBuffer(torch.from_numpy(data).requires_grad_(False))
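The comprehension in TorchBuffer.custompad is purely a format conversion: numpy wants (before, after) pairs starting from the first axis, while torch.nn.functional.pad wants a flat list starting from the last axis. A small sketch (not part of the commit):

import torch

padding = [(1, 0), (0, 2)]  # numpy-style: axis 0 first
flat = [item for sublist in padding[::-1] for item in sublist]
print(flat)  # [0, 2, 1, 0] -- last axis first, as torch.nn.functional.pad expects
x = torch.ones(2, 2)
print(torch.nn.functional.pad(x, flat).shape)  # torch.Size([3, 4])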
@@ -11,29 +13,9 @@ class TorchBuffer(torch.Tensor):
   def getdtype(self):
     return np.float32
 
-# ************* unary+binary+reduce ops *************
+# ************* unary+binary+reduce+movement ops *************
 
-from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
-
-# ************* movement ops *************
-
-from tinygrad.ops.ops_cpu import Reshape, Transpose
-
-def inner_slice(x, arg):
-  padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
-  x = torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
-  slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
-  return x[tuple([slice(x[0], x[1], None) for x in slicee])]
-
-class Slice(Function):
-  def forward(ctx, x, arg=None):
-    ctx.save_for_backward(x.shape)
-    return inner_slice(x, arg)
-
-  def backward(ctx, grad_output):
-    shape, = ctx.saved_tensors
-    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
-    return inner_slice(grad_output, narg)
+from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max, Reshape, Transpose, Slice
 
 # ************* processing ops *************
 
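The shared Slice works on both backends because inner_slice only touches x.shape, x.custompad(...), and plain indexing, all of which CPUBuffer and TorchBuffer provide. Its backward is the inverse slice: narg flips each (start, end) so the gradient is cropped where the forward padded and padded where the forward cropped. A rough sketch of that arithmetic (not in the commit), for one axis of length 4 sliced with (-1, 3):

shape_i, (start, end) = 4, (-1, 3)
grad_len = end - start              # forward output length: 4
nstart = 0 - start                  # 1: drop the grad element that came from padding
nend = grad_len + (shape_i - end)   # 5: zero-pad one element back on the right
# inner_slice(grad_output, [(nstart, nend)]) == inner_slice(grad_output, [(1, 5)])
# maps a length-4 gradient back to the original length-4 input.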