cleanups, simple padding in the processing op

George Hotz
2023-01-25 07:37:52 -08:00
parent 3acf62d489
commit baf64c14ac
4 changed files with 10 additions and 21 deletions

@@ -250,37 +250,26 @@ class LazyBuffer:
     x = self
     if IMAGE >= 1:
+      x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
       w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
-      if C.bs > 1 and C.py > 0:
-        # explicitly add y-padding for batched inputs
-        # N C H W
-        xs = [(0, 0) for _ in x.shape]
-        xs[2] = (C.py, C.py)
-        x = x.movement_op(MovementOps.PAD, xs)
-        C = C._replace(iy=C.iy + C.py*2, py=0)
+      added_output_channels = 0
       # hack for non multiples of 4 on C.cin
       if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
         to_add = 4 - (C.cin % 4)
         w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))])
-        x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
         x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))])
         C = C._replace(cin = C.cin + to_add)
-        x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
       # hack for non multiples of 4 on C.rcout
       if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
         added_output_channels = 4 - (C.rcout % 4)
         w = w.movement_op(MovementOps.PAD, [(0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))])
         C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
-      else:
-        added_output_channels = 0
       # packed
       assert (C.groups*C.cin) % 4 == 0
-      x = x.movement_op(MovementOps.PERMUTE, (0,2,3,1))
+      x = x.movement_op(MovementOps.PERMUTE, (0,3,4,1,2))
       x = x.movement_op(MovementOps.RESHAPE, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))
       assert C.cout % 4 == 0
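
In the hunk above, padding C.cin up to a multiple of 4 and then permuting to (bs, iy, ix, groups, cin) packs every 4 input channels into the last axis, matching the last-axis-of-4 layout the CLImage path (see the GPUBuffer hunk below) allocates. A minimal NumPy sketch of the same transform, with illustrative names rather than tinygrad's MovementOps API:

import numpy as np

def pack_for_image(x, bs, groups, cin, iy, ix):
  # x: (bs, groups, cin, iy, ix), matching the RESHAPE above
  to_add = (4 - cin % 4) % 4
  if to_add and not (cin == 1 and groups % 4 == 0):
    # pad the channel axis (axis 2) up to a multiple of 4, like MovementOps.PAD
    x = np.pad(x, [(0, 0), (0, 0), (0, to_add), (0, 0), (0, 0)])
    cin += to_add
  assert (groups * cin) % 4 == 0
  # PERMUTE (0,3,4,1,2) -> (bs, iy, ix, groups, cin), then pack 4 channels per element
  x = x.transpose(0, 3, 4, 1, 2)
  return x.reshape(bs * iy, ix * groups * cin // 4, 4)

x = np.random.randn(2, 1, 3, 8, 8).astype(np.float32)  # cin=3 is not a multiple of 4
print(pack_for_image(x, 2, 1, 3, 8, 8).shape)  # (16, 8, 4)
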
@@ -348,9 +337,10 @@ class LazyBuffer:
       ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
       return ret
     # TODO: fixup C?
-    if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
+    # add padding if the backend can't handle it
+    if NOCONV or (not getattr(x.dbuffer, "SUPPORTS_PADDING", False) and not (getattr(x.dbuffer, "SUPPORTS_SIMPLE_PADDING", False) and C.px == C.px_ and C.py == C.py_ and C.px >= 0 and C.py >= 0)):
       x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
       C = C._replace(px=0, px_=0, py=0, py_=0)
     if NOCONV or not getattr(x.dbuffer, "processing_op", False):
       # universal conv, just mul and reduce
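For backends without either flag, the slice call above does the padding: a negative start and an oversized stop on the two spatial axes read out-of-range positions as zeros, after which C's pads are reset to 0. A rough NumPy equivalent of that slice (zero_pad_slice is a hypothetical helper, not tinygrad's API):

import numpy as np

def zero_pad_slice(x, py, py_, px, px_):
  # same effect as x.slice(((0, N), (0, C), (-py, H+py_), (-px, W+px_))):
  # the out-of-range region is filled with zeros
  return np.pad(x, [(0, 0), (0, 0), (py, py_), (px, px_)])

x = np.ones((1, 3, 4, 4), dtype=np.float32)
print(zero_pad_slice(x, 1, 1, 2, 2).shape)  # (1, 3, 6, 8)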

@@ -39,6 +39,7 @@ class CPUBuffer(np.ndarray, GenericExecAST):
   def processing_op(x,op,w,C):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
+    assert C.px == 0 and C.px_ == 0 and C.py == 0 and C.py_ == 0, "padding in conv is not supported"
     tx = x.movement_op(MovementOps.STRIDED, (
       (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
       (C.oy, C.sy*x.shape[3]), (C.ox, C.sx), (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
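The STRIDED op builds a zero-copy (bs, groups, oy, ox, cin, H, W) view over the already-padded input, so the convolution itself reduces to a multiply and sum over (cin, H, W). A NumPy sketch of the same trick using element strides instead of the (size, stride) pairs above; this is illustrative, not CPUBuffer's actual reduce path:

import numpy as np
from numpy.lib.stride_tricks import as_strided

def conv2d_strided(x, w, groups=1, sy=1, sx=1, dy=1, dx=1):
  bs, cin_total, iy, ix = x.shape
  rcout, cin, H, W = w.shape[0] // groups, w.shape[1], w.shape[2], w.shape[3]
  oy = (iy - dy * (H - 1) - 1) // sy + 1
  ox = (ix - dx * (W - 1) - 1) // sx + 1
  sb, sc, sh, sw = x.strides
  # (bs, groups, oy, ox, cin, H, W) view, mirroring the STRIDED movement op
  tx = as_strided(x, shape=(bs, groups, oy, ox, cin, H, W),
                  strides=(sb, cin * sc, sy * sh, sx * sw, sc, dy * sh, dx * sw))
  tw = w.reshape(groups, rcout, cin, H, W)
  out = np.einsum("bgyxchw,grchw->bgryx", tx, tw)  # reduce over cin, H, W
  return out.reshape(bs, groups * rcout, oy, ox)

x = np.random.randn(1, 4, 8, 8)
w = np.random.randn(6, 4, 3, 3)
print(conv2d_strided(x, w).shape)  # (1, 6, 6, 6)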

@@ -397,10 +397,6 @@ class GPUBuffer(ExplicitExecAST):
   @property
   def cl(self):
     if self._buf is None:
-      possible_split_shape = [x for x in self._base_shape if x != 1]
-      # TODO: this is broken, and a hack. I suspect the issue is unaligned float4 accesses, would be caught by the Image valid thing if it worked.
-      if IMAGE >= 3 and len(possible_split_shape) == 1 and possible_split_shape[0] % 4 == 0 and self._backing is None and possible_split_shape[0] != 6140:
-        self._base_shape = (1, possible_split_shape[0]//4, 4)
       self._buf = CLImage(self._base_shape) if (len(self._base_shape) == 3 and self._base_shape[2] == 4 and IMAGE >= 2) else CLBuffer(4*prod(self._base_shape))
     if self._backing is not None:
       CL().enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
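The deleted block tried to opportunistically rewrite a flat base shape into (1, n//4, 4) so the buffer would be allocated as a CLImage (one float4 per texel) rather than a plain CLBuffer; the TODO flags it as broken, so it is dropped here. For reference, the shape rewrite it performed was only this (a standalone restatement, not the GPUBuffer API):

def try_image_shape(base_shape, image_level):
  # collapse size-1 dims, e.g. (1, 24, 1) -> [24]; if the single remaining dim
  # is a multiple of 4, split it into float4 rows, e.g. (1, 6, 4)
  split = [d for d in base_shape if d != 1]
  if image_level >= 3 and len(split) == 1 and split[0] % 4 == 0:
    return (1, split[0] // 4, 4)
  return base_shape

print(try_image_shape((1, 24, 1), 3))  # (1, 6, 4)
print(try_image_shape((3, 3), 3))      # (3, 3), unchanged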

@@ -13,6 +13,8 @@ class TorchBuffer(torch.Tensor, GenericExecAST):
   unary_op, binary_op, reduce_op, movement_op = CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op
+  SUPPORTS_SIMPLE_PADDING = True
   def processing_op(x,op,w,C):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
-    return torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx))
+    assert C.px == C.px_ and C.py == C.py_, "asymmetric padding in conv is not supported"
+    return torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
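
Since torch.conv2d takes one symmetric pad per spatial dim, the new assert restricts this path to px == px_ and py == py_; within that restriction, letting torch pad is equivalent to zero-padding first and convolving with padding=0, which is what the slice-based fallback in LazyBuffer otherwise does. A quick check of that equivalence in plain PyTorch, outside tinygrad:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(6, 3, 3, 3)
py, px = 2, 1

a = torch.conv2d(x, w, stride=(1, 1), dilation=(1, 1), padding=(py, px))
# F.pad takes (left, right, top, bottom) for the last two dims
b = torch.conv2d(F.pad(x, (px, px, py, py)), w, stride=(1, 1), dilation=(1, 1), padding=0)
print(torch.allclose(a, b, atol=1e-5))  # True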