Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-07 03:00:26 -04:00.
add PAD movementop
This commit is contained in:
@@ -41,6 +41,7 @@ class CPUBuffer(np.ndarray):
|
||||
if op == MovementOps.RESHAPE: return x.reshape(arg)
|
||||
elif op == MovementOps.PERMUTE: return x.permute(arg)
|
||||
elif op == MovementOps.FLIP: return x.flip(arg)
|
||||
elif op == MovementOps.PAD: return x.custompad(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
|
||||
|
||||
@@ -88,7 +88,6 @@ class GPUBuffer:
|
||||
return self._buf.cl
|
||||
|
||||
# Debug representation: shows only the buffer's shape, not its contents.
def __repr__(self): return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
def shapeTrackerView(x, st:ShapeTracker):
  # Build a new buffer of the same concrete type as x, driven by a copy of the
  # given ShapeTracker, passing x as hostbuf so the view shares x's backing data.
  buffer_cls = type(x)
  return buffer_cls(ShapeTracker(st), hostbuf=x)
|
||||
|
||||
@staticmethod
def fromCPU(x):
  # Reinterpret as a plain ndarray, convert to float32, and flatten before
  # handing the data to a fresh GPUBuffer as its backing store.
  backing = x.view(np.ndarray).astype(np.float32).ravel()
  return GPUBuffer(x.shape, backing=backing)
|
||||
|
||||
@@ -170,7 +170,7 @@ class Conv2D(Function):
|
||||
xt = grad_output
|
||||
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides. (but only when we contiguous it)
|
||||
xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
|
||||
xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
|
||||
xt = xt.movement_op(MovementOps.PAD, ((0,0), (0,0), (0,0), (0,C.sy-1), (0,0), (0,C.sx-1)))
|
||||
xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
|
||||
wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
|
||||
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
|
||||
|
||||
@@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
|
||||
# Op vocabulary for the lazy graph. NOTE: the diff artifact that left two
# competing MovementOps definitions (with and without PAD) is resolved here --
# only the PAD-bearing definition survives; the earlier one was dead on arrival
# since the second assignment immediately shadowed it.
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED", "PAD"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU"])
|
||||
|
||||
@@ -120,12 +120,8 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
return real_src.reduce_op(self.op.op, self.op.arg), [real_src], ReduceOps
|
||||
|
||||
def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
  """Realize a chain of movement ops on top of its single source buffer.

  Returns (realized buffer, [real source buffers], op type), matching the
  contract of the other _realize_* helpers.

  Fix: the original body had leftover unreachable statements after the
  if/else (both branches return), a residue of the pre-refactor code path;
  they are removed here.
  """
  real_src = get_lazybuffers(self.op)[0].realize(self.device)
  if getattr(real_src, "shapeTrackerView", None) is not None:
    # fast path: the device buffer understands ShapeTrackers, so the whole
    # movement chain collapses into a single view of the realized source
    return real_src.shapeTrackerView(self.st), [real_src], MovementOps
  else:
    # slow path, creates middle buffers: replay each movement op one at a
    # time, innermost-first (hence the [::-1] over the op chain)
    return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src], MovementOps
|
||||
|
||||
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
|
||||
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:None for x in get_lazybuffers(self.op)}
|
||||
|
||||
@@ -123,6 +123,11 @@ class ShapeTracker:
|
||||
|
||||
# *** under this line are not invertible ***
|
||||
|
||||
# TODO: take this functionality out of slice
def pad(self, *arg):
  # One (begin, end) pair of non-negative pad amounts per dimension.
  assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
  # Padding is expressed as an out-of-range slice: a negative start adds
  # `b` elements before the axis, a stop past the end adds `e` after it.
  padded_bounds = [(-b, s+e) for s,(b,e) in zip(self.shape, arg)]
  return self.slice(*padded_bounds)
|
||||
|
||||
def slice(self, *arg):
|
||||
assert len(arg) == len(self.shape)
|
||||
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
|
||||
Reference in New Issue
Block a user