mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
1 line less in cpu, fix torch tests
This commit is contained in:
@@ -10,7 +10,7 @@ class CPUBuffer(np.ndarray):
|
||||
def flip(x, axis): return np.flip(x, axis)
|
||||
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
|
||||
def permute(x, order): return x.transpose(order)
|
||||
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer)
|
||||
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x > 0 or y > 0 for x,y in padding) else x
|
||||
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
|
||||
|
||||
@staticmethod
|
||||
@@ -54,15 +54,13 @@ class CPUBuffer(np.ndarray):
|
||||
elif op == MovementOps.FLIP: return x.flip(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
|
||||
return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])]
|
||||
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
|
||||
elif op == MovementOps.EXPAND: return x.expand(arg)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
if C.px != 0 or C.py != 0 or C.px_ != 0 or C.py_ != 0:
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
|
||||
tx = np.lib.stride_tricks.as_strided(gx,
|
||||
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
|
||||
|
||||
@@ -23,6 +23,5 @@ class TorchBuffer(torch.Tensor):
|
||||
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
if C.px != C.px_ or C.py != C.py_: padding, x = 0, x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
else: padding = (C.py, C.px)
|
||||
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=padding)
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx))
|
||||
|
||||
Reference in New Issue
Block a user