simpler movement op

George Hotz
2022-09-06 17:27:33 -07:00
parent 896f9f74a9
commit 5a76e652b8
2 changed files with 5 additions and 13 deletions


@@ -18,7 +18,7 @@ class CPUBuffer(np.ndarray):
   def flip(x, axis): return np.flip(x, axis)
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
-  def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
+  def pad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
   def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
   def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[y*x.dtype.itemsize for y in stride]).view(CPUBuffer)
   def contiguous(x): return x.ravel().reshape(x.shape)
@@ -41,20 +41,12 @@ class CPUBuffer(np.ndarray):
     return x.amax(axis, keepdims=True)
   def movement_op(x, op, arg=None):
-    if op == MovementOps.RESHAPE:
-      return x.reshape(arg)
-    elif op == MovementOps.PERMUTE:
-      return x.permute(arg)
-    elif op == MovementOps.FLIP:
-      return x.flip(arg)
-    elif op == MovementOps.PAD:
-      return x.custompad(arg)
-    elif op == MovementOps.SHRINK:
+    if op == MovementOps.SHRINK:
       return x[tuple(slice(p[0], p[1], None) for p in arg)]
-    elif op == MovementOps.EXPAND:
-      return x.expand(arg)
     elif op == MovementOps.STRIDED:
       return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
+    else:
+      return getattr(x, op.name.lower())(arg)
   PREPAD = True
   def processing_op(x,op,w,C):
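
The rename from custompad to pad is what makes the new fallback work: once each buffer method is named after its MovementOps enum member, one getattr call replaces the whole if/elif chain. Below is a minimal, self-contained sketch of that dispatch pattern; Buf here is a hypothetical stand-in for CPUBuffer, not the repository's code.

# Minimal sketch of name-based op dispatch, assuming a toy MovementOps enum
# and a hypothetical Buf subclass standing in for CPUBuffer.
from enum import Enum
import numpy as np

class MovementOps(Enum):
  RESHAPE = 0; PERMUTE = 1; PAD = 2

class Buf(np.ndarray):
  # method names deliberately match the enum member names (lowercased)
  def permute(x, order): return x.transpose(order)
  def pad(x, padding): return np.pad(x, padding).view(Buf)
  def movement_op(x, op, arg=None):
    # MovementOps.PAD.name.lower() == "pad", so the enum name picks the method
    return getattr(x, op.name.lower())(arg)

x = np.arange(6).reshape(2, 3).view(Buf)
print(x.movement_op(MovementOps.PAD, ((1, 1), (0, 0))).shape)  # (4, 3)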


@@ -4,7 +4,7 @@ from tinygrad.ops import ProcessingOps
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class TorchBuffer(torch.Tensor):
-  def custompad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
+  def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
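
The torch backend gets the same rename so the getattr dispatch resolves MovementOps.PAD on TorchBuffer too. The list comprehension in pad converts numpy-style per-dimension (before, after) pairs into the flat, last-dimension-first format that torch.nn.functional.pad expects; a standalone illustration of just that conversion (not the repository's code):

# Sketch of the padding-format conversion, assuming numpy-style
# ((before, after), ...) pairs as input.
import torch
import torch.nn.functional as F

padding = ((1, 1), (2, 2))  # one (before, after) pair per dimension
flat = [item for sublist in padding[::-1] for item in sublist]
# F.pad wants a flat list starting from the last dimension: [2, 2, 1, 1]

x = torch.zeros(2, 3)
print(F.pad(x, flat).shape)  # torch.Size([4, 7])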