mirror of https://github.com/tinygrad/tinygrad.git
simpler movement op
@@ -18,7 +18,7 @@ class CPUBuffer(np.ndarray):
   def flip(x, axis): return np.flip(x, axis)
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
-  def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
+  def pad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
   def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
   def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[y*x.dtype.itemsize for y in stride]).view(CPUBuffer)
   def contiguous(x): return x.ravel().reshape(x.shape)
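For context (not part of the diff): the padding argument here is a sequence of per-axis (before, after) pairs, which is exactly what np.pad expects. A minimal sketch of that convention, using a plain numpy array rather than CPUBuffer:

import numpy as np

x = np.arange(6).reshape(2, 3)
padding = ((1, 1), (0, 2))   # (before, after) for each axis
y = np.pad(x, padding)       # zero-pads to shape (2+1+1, 3+0+2)
assert y.shape == (4, 5)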
@@ -41,20 +41,12 @@ class CPUBuffer(np.ndarray):
     return x.amax(axis, keepdims=True)
 
   def movement_op(x, op, arg=None):
-    if op == MovementOps.RESHAPE:
-      return x.reshape(arg)
-    elif op == MovementOps.PERMUTE:
-      return x.permute(arg)
-    elif op == MovementOps.FLIP:
-      return x.flip(arg)
-    elif op == MovementOps.PAD:
-      return x.custompad(arg)
-    elif op == MovementOps.SHRINK:
+    if op == MovementOps.SHRINK:
       return x[tuple(slice(p[0], p[1], None) for p in arg)]
-    elif op == MovementOps.EXPAND:
-      return x.expand(arg)
     elif op == MovementOps.STRIDED:
       return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
+    else:
+      return getattr(x, op.name.lower())(arg)
 
   PREPAD = True
   def processing_op(x,op,w,C):
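The collapsed movement_op assumes that each remaining MovementOps member name, lowercased, matches a CPUBuffer method (reshape, permute, flip, pad, expand), which is why custompad becomes pad above. A rough standalone sketch of that getattr dispatch, using a toy enum and a bare numpy array in place of CPUBuffer:

from enum import Enum
import numpy as np

# illustrative stand-in for tinygrad's MovementOps
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "PAD", "EXPAND", "FLIP", "SHRINK", "STRIDED"])

x = np.arange(12)
op, arg = MovementOps.RESHAPE, (3, 4)
# op.name.lower() == "reshape", so this resolves to x.reshape((3, 4))
y = getattr(x, op.name.lower())(arg)
assert y.shape == (3, 4)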
@@ -4,7 +4,7 @@ from tinygrad.ops import ProcessingOps
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class TorchBuffer(torch.Tensor):
-  def custompad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
+  def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
 
   @staticmethod
   def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
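For context (an assumption about the reversal, not stated in the diff): torch.nn.functional.pad takes a single flat sequence ordered from the last dimension backwards, so the numpy-style per-axis (before, after) pairs are reversed and flattened. A small sketch, assuming the same padding layout as the CPU backend:

import torch
import torch.nn.functional as F

x = torch.zeros(2, 3)
padding = ((1, 1), (0, 2))                                       # numpy-style (before, after) per axis
flat = [item for sublist in padding[::-1] for item in sublist]   # [0, 2, 1, 1], last axis first
y = F.pad(x, flat)
assert y.shape == (4, 5)                                         # same shape np.pad would produce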