diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 0ba429e6b8..b9db70ff38 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -18,7 +18,7 @@ class CPUBuffer(np.ndarray):
   def flip(x, axis): return np.flip(x, axis)
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
-  def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
+  def pad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
   def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
   def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[y*x.dtype.itemsize for y in stride]).view(CPUBuffer)
   def contiguous(x): return x.ravel().reshape(x.shape)
@@ -41,20 +41,12 @@ class CPUBuffer(np.ndarray):
     return x.amax(axis, keepdims=True)
 
   def movement_op(x, op, arg=None):
-    if op == MovementOps.RESHAPE:
-      return x.reshape(arg)
-    elif op == MovementOps.PERMUTE:
-      return x.permute(arg)
-    elif op == MovementOps.FLIP:
-      return x.flip(arg)
-    elif op == MovementOps.PAD:
-      return x.custompad(arg)
-    elif op == MovementOps.SHRINK:
+    if op == MovementOps.SHRINK:
       return x[tuple(slice(p[0], p[1], None) for p in arg)]
-    elif op == MovementOps.EXPAND:
-      return x.expand(arg)
     elif op == MovementOps.STRIDED:
       return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
+    else:
+      return getattr(x, op.name.lower())(arg)
 
   PREPAD = True
   def processing_op(x,op,w,C):
diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 3c80ff97ea..c6319a5a55 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -4,7 +4,7 @@ from tinygrad.ops import ProcessingOps
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 class TorchBuffer(torch.Tensor):
-  def custompad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
+  def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
     return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
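
For context on the simplification: the collapsed movement_op relies on each MovementOps member name, lowercased, matching a method on the buffer class, which is why custompad has to be renamed to pad. A minimal self-contained sketch of that dispatch pattern follows; the ToyBuffer class and the trimmed-down MovementOps enum below are illustrative stand-ins, not code from the repo.

# Illustrative sketch only: ToyBuffer and this MovementOps enum are assumed names, not repo code.
from enum import Enum, auto
import numpy as np

class MovementOps(Enum):
  RESHAPE = auto(); PERMUTE = auto(); FLIP = auto(); PAD = auto(); EXPAND = auto()

class ToyBuffer:
  def __init__(self, data): self.data = np.asarray(data)
  # method names deliberately match the lowercased enum member names
  def reshape(self, arg): return ToyBuffer(self.data.reshape(arg))
  def permute(self, arg): return ToyBuffer(self.data.transpose(arg))
  def flip(self, arg): return ToyBuffer(np.flip(self.data, arg))
  def pad(self, arg): return ToyBuffer(np.pad(self.data, arg))
  def expand(self, arg): return ToyBuffer(np.broadcast_to(self.data, arg))
  def movement_op(self, op, arg=None):
    # MovementOps.PAD -> "pad" -> ToyBuffer.pad; a method named custompad would never be found
    return getattr(self, op.name.lower())(arg)

x = ToyBuffer(np.arange(6).reshape(2, 3))
y = x.movement_op(MovementOps.PAD, ((1, 1), (0, 0)))
print(y.data.shape)  # (4, 3): one row of zeros added on each side

On the torch side, the renamed pad keeps the same per-dimension (before, after) pairs as its argument but reverses and flattens them, because torch.nn.functional.pad expects a flat list ordered from the last dimension backwards (e.g. ((1, 1), (2, 2)) becomes [2, 2, 1, 1]).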