diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index f2d83f5743..a8aefd23cf 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -10,7 +10,7 @@ class CPUBuffer(np.ndarray): def flip(x, axis): return np.flip(x, axis) def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs) def permute(x, order): return x.transpose(order) - def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) + def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x > 0 or y > 0 for x,y in padding) else x def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer) @staticmethod @@ -54,15 +54,13 @@ class CPUBuffer(np.ndarray): elif op == MovementOps.FLIP: return x.flip(arg) elif op == MovementOps.SLICE: padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)] - slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)] - return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])] + return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))] elif op == MovementOps.EXPAND: return x.expand(arg) else: raise Exception(f"{op} isn't supported") def processing_op(x,op,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" - if C.px != 0 or C.py != 0 or C.px_ != 0 or C.py_ != 0: - x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_))) + x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_))) gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3]) tx = np.lib.stride_tricks.as_strided(gx, shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W), diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 21571cc3ce..49338ad97d 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -23,6 +23,5 @@ class TorchBuffer(torch.Tensor): def processing_op(x,op,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" - if C.px != C.px_ or C.py != C.py_: padding, x = 0, x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_))) - else: padding = (C.py, C.px) - return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=padding) + x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_))) + return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx))