From c0050fab8ff0bc667e40da11980f4ac4c21affda Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 28 Oct 2022 09:29:12 -0700 Subject: [PATCH] clean up movement_op in cpu and torch --- tinygrad/llops/ops_cpu.py | 18 +++++------------- tinygrad/llops/ops_torch.py | 1 + 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 6e82ebf568..c82b098eca 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -11,7 +11,8 @@ class CPUBuffer(np.ndarray, GenericExecAST): BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], - ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:] + ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], + MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)] } def relu(x): return np.maximum(x, 0) @@ -24,8 +25,7 @@ class CPUBuffer(np.ndarray, GenericExecAST): def permute(x, order): return x.transpose(order) def pad(x, padding): return np.pad(x, padding).view(CPUBuffer) def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer) - def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[y*x.dtype.itemsize for y in stride]).view(CPUBuffer) - def contiguous(x): return x.ravel().reshape(x.shape) + def strided(x, arg): return np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg]).view(CPUBuffer) @staticmethod def fromCPU(x): return x.view(CPUBuffer) @@ -34,21 +34,13 @@ class CPUBuffer(np.ndarray, GenericExecAST): def unary_op(x, op): return CPUBuffer.fxn_for_op[op](x) def binary_op(x, op, y): return CPUBuffer.fxn_for_op[op](x, y) def reduce_op(x, op, new_shape): return CPUBuffer.fxn_for_op[op](x, new_shape) + def movement_op(x, op, arg=None): return CPUBuffer.fxn_for_op[op](x, arg) if op in CPUBuffer.fxn_for_op else getattr(x, op.name.lower())(arg) - def movement_op(x, op, arg=None): - if op == MovementOps.SHRINK: - return x[tuple(slice(p[0], p[1], None) for p in arg)] - elif op == MovementOps.STRIDED: - return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg]) - else: - return getattr(x, op.name.lower())(arg) - - PREPAD = True def processing_op(x,op,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" tx = x.movement_op(MovementOps.STRIDED, ( (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx), (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx))) tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) - out = np.einsum("nGhwCHW, GkCHW -> nGkhw", tx.contiguous(), tw.contiguous()) + out = np.einsum("nGhwCHW, GkCHW -> nGkhw", tx.ravel().reshape(tx.shape), tw.ravel().reshape(tw.shape)) return out.reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer) \ No newline at end of file diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index da2602ff6e..4040d29628 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -5,6 +5,7 @@ from tinygrad.ops import ProcessingOps, GenericExecAST device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class TorchBuffer(torch.Tensor, GenericExecAST): def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]) + def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]) @staticmethod def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)