mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
write slice for CPU
@@ -113,7 +113,6 @@ Add, Sub, Mul, Pow # binary ops (with broadcasting)
 Sum, Max # reduce ops (with axis argument)
 Dot, Conv2D # matrix multiplication and conv
 Reshape, Transpose, Slice # moving things around ops
-Pad2D, Unpad2D # stupid (refactor to Slice)
 ```
 
 ## ImageNet inference
@@ -124,6 +124,11 @@ class TestOps(unittest.TestCase):
     # NOTE: ANE backwards?
     helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU)
 
+  def test_slice(self):
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2], device=self.device)
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2], device=self.device)
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1], device=self.device)
+
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
 
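For reference, the three test cases only differ in how many leading axes are restricted; the expected shapes follow ordinary NumPy/torch slicing rules, which the tinygrad side has to match. A quick NumPy-only illustration of what the test compares against:

```
import numpy as np

x = np.random.rand(3, 3, 3, 3).astype(np.float32)
print(x[1:2].shape)              # (1, 3, 3, 3): only the first axis is restricted
print(x[1:2, 1:2].shape)         # (1, 1, 3, 3)
print(x[1:2, 1:2, 0:-1].shape)   # (1, 1, 2, 3): a negative stop counts back from the end
```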
@@ -103,29 +103,26 @@ class Dot(Function):
     return grad_input, grad_weight
 register('dot', Dot)
 
-# ************* simple ops *************
+# ************* movement ops *************
 
-# TODO: Combine Pad2D and Unpad2D into something generic
-class Pad2D(Function):
+def inner_slice(x, arg):
+  padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
+  x = np.pad(x, padding)
+  slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
+  return x[tuple([slice(x[0], x[1], None) for x in slicee])]
+
+class Slice(Function):
   @staticmethod
-  def forward(ctx, x, padding=None):
-    return np.pad(x, ((0,0), (0,0), tuple(ctx.padding[2:4]), tuple(ctx.padding[0:2])))
+  def forward(ctx, x, arg=None):
+    ctx.save_for_backward(x.shape)
+    return inner_slice(x, arg)
 
   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output[...,
-      ctx.padding[2]:(None if ctx.padding[3] == 0 else -ctx.padding[3]),
-      ctx.padding[0]:(None if ctx.padding[1] == 0 else -ctx.padding[1])]
-register('pad2d', Pad2D)
-
-class Unpad2D(Function):
-  @staticmethod
-  def forward(ctx, x, padding=None):
-    return Pad2D.backward(ctx, x)
-  @staticmethod
-  def backward(ctx, grad_output):
-    return Pad2D.forward(ctx, grad_output)
-register('unpad2d', Unpad2D)
+    shape, = ctx.saved_tensors
+    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
+    return inner_slice(grad_output, narg)
+register('slice', Slice)
 
 class Reshape(Function):
   @staticmethod
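The forward pass pads first so that out-of-range bounds read as zeros, then slices the request in the padded array's coordinates. The backward pass reuses `inner_slice`: if forward took the range (p[0], p[1]) on an axis of original size shape[i], the gradient (whose axis now has size p[1]-p[0]) is sliced with (-p[0], grad.shape[i] + shape[i] - p[1]), which pads back what was cut off and cuts off what was padded. A rough NumPy-only check of that round trip outside the Function/ctx machinery (the concrete shapes and values below are illustrative, not from the codebase):

```
import numpy as np

# inner_slice as in the new CPU op above
def inner_slice(x, arg):
  # pad so the requested range is in bounds, then slice in padded coordinates
  padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
  x = np.pad(x, padding)
  slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
  return x[tuple(slice(s[0], s[1]) for s in slicee)]

shape, arg = (3, 3), [(-1, 4), (1, 2)]
x = np.arange(9.).reshape(shape)
y = inner_slice(x, arg)                    # forward: (3,3) -> (5,1), zero rows appear top and bottom
g = np.ones_like(y)                        # stand-in for the upstream gradient
narg = [(0-p[0], g.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
gx = inner_slice(g, narg)                  # backward: (5,1) -> (3,3)
print(gx)                                  # ones only in column 1, the part of x that reached y
```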
@@ -301,7 +301,7 @@ class Dot(Function):
     return grad_input, grad_weight
 register('dot', Dot, device=Device.GPU)
 
-# ************* simple ops *************
+# ************* movement ops *************
 
 def get_pad2d_kernel(ctx):
   return clbuild(ctx.cl_ctx, "pad2d", """
@@ -341,17 +341,7 @@ class Pad2D(Function):
       i32(oy), i32(ox), i32(iy), i32(ix)
     )
     return ret
-register('pad2d', Pad2D, device=Device.GPU)
-
-# TODO: this is an exact copy from the CPU code
-class Unpad2D(Function):
-  @staticmethod
-  def forward(ctx, x, padding=None):
-    return Pad2D.backward(ctx, x)
-  @staticmethod
-  def backward(ctx, grad_output):
-    return Pad2D.forward(ctx, grad_output)
-register('unpad2d', Unpad2D, device=Device.GPU)
+#register('pad2d', Pad2D, device=Device.GPU)
 
 class Reshape(Function):
   @staticmethod
@@ -196,6 +196,17 @@ class Tensor:
 
   # ***** non first class ops *****
 
+  def __getitem__(self, val):
+    arg = []
+    for i,s in enumerate(val if type(val) in [list, tuple] else ([] if val is None else [val])):
+      arg.append((s.start if s.start is not None else 0,
+        (s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
+      assert s.step is None or s.step == 1
+    return self.slice(arg = arg+[(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
+
+  def pad2d(self, padding):
+    return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
+
   def matmul(self, w):
     return self.dot(w)
 
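To make the translation concrete: `__getitem__` turns each Python slice into a (start, stop) pair, resolving a negative stop against the shape and filling unindexed trailing axes with their full range; `pad2d` then becomes a slice whose bounds simply run past the edges, leaning on the zero-padding behaviour of the slice op. A small sketch of that conversion against a bare shape tuple (no Tensor involved; `to_slice_arg` is a made-up helper name):

```
# Sketch of the index -> slice-arg conversion done in Tensor.__getitem__,
# written against a plain shape tuple instead of a Tensor.
def to_slice_arg(shape, val):
  val = val if isinstance(val, (list, tuple)) else [val]
  arg = []
  for i, s in enumerate(val):
    start = s.start if s.start is not None else 0
    stop = (s.stop if s.stop >= 0 else shape[i] + s.stop) if s.stop is not None else shape[i]
    assert s.step is None or s.step == 1   # strided slicing is not supported
    arg.append((start, stop))
  # unindexed trailing axes keep their full range
  return arg + [(0, shape[i]) for i in range(len(arg), len(shape))]

shape = (1, 4, 5, 6)                                   # N, C, H, W
print(to_slice_arg(shape, (slice(None), slice(None), slice(1, 2), slice(0, -1))))
# [(0, 1), (0, 4), (1, 2), (0, 5)]
# pad2d(padding=(1,2,3,4)) produces bounds that overshoot: H -> (-3, 5+4), W -> (-1, 6+2)
print(to_slice_arg(shape, (slice(None), slice(None), slice(-3, 5+4), slice(-1, 6+2))))
# [(0, 1), (0, 4), (-3, 9), (-1, 8)]
```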
@@ -247,7 +258,8 @@ class Tensor:
     return self.relu() + (-1.0*self).relu()
 
   def _pool2d(self, py, px):
-    xup = self.unpad2d(padding=(0, self.shape[3]%px, 0, self.shape[2]%py))
+    xup = self.slice(arg=[(0,self.shape[0]), (0,self.shape[1]),
+      (0,self.shape[2]-self.shape[2]%py), (0, self.shape[3]-self.shape[3]%px)])
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
 
   def avg_pool2d(self, kernel_size=(2,2)):
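The `_pool2d` change replaces `unpad2d` with a plain slice that trims H and W down to multiples of the pool size; the following reshape then exposes each py-by-px window as its own pair of axes, so a later reduce over those axes gives average or max pooling. A NumPy-only sketch of the same shape bookkeeping, assuming 2x2 pooling on a made-up 5x7 input:

```
import numpy as np

N, C, H, W, py, px = 1, 1, 5, 7, 2, 2
x = np.arange(N*C*H*W, dtype=np.float32).reshape(N, C, H, W)

# trim H and W to multiples of the pool size (what the new slice arg does)
xup = x[:, :, :H - H % py, :W - W % px]                   # (1, 1, 4, 6)
# expose each pooling window as its own pair of axes
xup = xup.reshape(N, C, xup.shape[2]//py, py, xup.shape[3]//px, px)
print(xup.shape)                                          # (1, 1, 2, 2, 3, 2)
print(xup.mean(axis=(3, 5)).shape)                        # (1, 1, 2, 3): 2x2 average pooling
```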