1 line less in cpu, fix torch tests

This commit is contained in:
George Hotz
2022-06-26 18:11:53 -07:00
parent dffde3de5a
commit a04813ffe3
2 changed files with 5 additions and 8 deletions

View File

@@ -10,7 +10,7 @@ class CPUBuffer(np.ndarray):
def flip(x, axis): return np.flip(x, axis)
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
def permute(x, order): return x.transpose(order)
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer)
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x > 0 or y > 0 for x,y in padding) else x
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
@staticmethod
@@ -54,15 +54,13 @@ class CPUBuffer(np.ndarray):
elif op == MovementOps.FLIP: return x.flip(arg)
elif op == MovementOps.SLICE:
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])]
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
elif op == MovementOps.EXPAND: return x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px != 0 or C.py != 0 or C.px_ != 0 or C.py_ != 0:
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),

View File

@@ -23,6 +23,5 @@ class TorchBuffer(torch.Tensor):
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px != C.px_ or C.py != C.py_: padding, x = 0, x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
else: padding = (C.py, C.px)
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=padding)
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx))