A conv is a reduce op (#356)

* universal strided conv

* more correct

* hmm, CPU works

* cleaner cl code output

* make noconv a flag

* cleanup __getitem__

* refactor broadcasting

* put that back

* unneeded reshape in getitem

* fix strided for torch
This commit is contained in:
George Hotz
2022-07-10 19:58:50 -07:00
committed by GitHub
parent 057e4f5aa5
commit 817b64f5e5
7 changed files with 60 additions and 59 deletions

View File

@@ -72,10 +72,3 @@ def postprocessing_op(ret, C, C_initial):
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
return ret
def processed_conv(x, w, C):
x,w,Cn = preprocessing_op(x,w,C)
# precompute the weight
w.realize().image
ret = x.processing_op(ProcessingOps.CONV, w, Cn)
return postprocessing_op(ret, Cn, C)

View File

@@ -19,6 +19,8 @@ class CPUBuffer(np.ndarray):
def permute(x, order): return x.transpose(order)
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[x*4 for x in stride]).view(CPUBuffer)
def contiguous(x): return x.ravel().reshape(x.shape)
@staticmethod
def fromCPU(x): return x.view(CPUBuffer)
@@ -42,6 +44,7 @@ class CPUBuffer(np.ndarray):
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
elif op == MovementOps.EXPAND: return x.expand(arg)
elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"

View File

@@ -127,7 +127,7 @@ class GPUBuffer:
if C is not None: # this is a conv
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
global_size = [C.bs*C.cout, C.oy, C.ox]
global_size = [C.bs*C.cout, C.oy, C.ox] # [nGk, h, w]
assert ret.shape == C.out_shape, "output shape is wrong (NOTE: you can't reduce and conv together)"
# now input and weight can be anywhere in bufs
@@ -157,13 +157,13 @@ class GPUBuffer:
kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
conv_prg = CLProgram(kernel_name, f"""{chr(13).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
{''.join([ls for ls, _ in loop[::-1]])}
{''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
{' '.join([ls for ls, _ in loop[::-1]])}
{chr(13).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
acc = {code};
{''.join([le for _, le in loop])}
{' '.join([le for _, le in loop])}
output[gid] = acc;
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])

View File

@@ -156,18 +156,10 @@ class Flip(Function):
# ************* processing ops *************
class Conv2D(Function):
def _conv(ctx, x, w, C):
# TODO: this does NOT belong here
if x.device == "OPENCL":
from accel.opencl.preprocessing import processed_conv # type: ignore
return processed_conv(x, w, C)
else:
return x.processing_op(ProcessingOps.CONV, w, C)
def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
ctx.save_for_backward(x,w)
return ctx._conv(x, w, ctx.C)
return x.processing_op(ProcessingOps.CONV, w, ctx.C)
def backward(ctx, grad_output):
x, w = ctx.saved_tensors
@@ -176,7 +168,7 @@ class Conv2D(Function):
if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
xt = grad_output
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides. (but only when we contiguous it)
xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
@@ -184,13 +176,13 @@ class Conv2D(Function):
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
Cdx = get_conv_args(xt.shape, wt.shape, out_shape=x.shape, dilation=(C.dy, C.dx), padding=(py, px), groups=C.groups)
dx = ctx._conv(xt, wt, Cdx)
dx = xt.processing_op(ProcessingOps.CONV, wt, Cdx)
if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)).movement_op(MovementOps.PERMUTE, (2, 1, 0, 3, 4))
xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, out_shape=(w.shape[1], w.shape[0], w.shape[2], w.shape[3]), padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
dw = xdw.processing_op(ProcessingOps.CONV, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
return dx, dw

View File

@@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU"])
@@ -23,6 +23,7 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
DEBUG = int(os.getenv("DEBUG", "0"))
GRAPH = int(os.getenv("GRAPH", "0"))
OPT = int(os.getenv("OPT", "1"))
NOCONV = int(os.getenv("NOCONV", "0"))
# TODO: movement ops that only change shape are really nops. treat them as such
MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1
@@ -79,7 +80,8 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
G.add_edge(nm(x), nm(ret), label=sop)
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
if getattr(ret, "st", None) is not None and not ret.st.contiguous: # checked twice to make type checker happy
if getattr(ret, "st", None) is not None and not ret.st.contiguous:
G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
dashed = True
else:
@@ -245,7 +247,8 @@ class LazyBuffer:
ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps,
LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
if REMOVE_MOVEMENT_NOPS and x.realized is None and ret.st.contiguous:
# NOTE: if ret is in the cache, it can already be realized
if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
root = get_lazybuffers(ret.op)[0]
if ret.st.shape == root.shape and root.st.contiguous:
return root
@@ -253,7 +256,27 @@ class LazyBuffer:
return ret
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
if NOCONV:
# universal conv, just mul and reduce
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
# TODO: is there any way to replace strided with other movement ops?
x = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
(C.rcout, 0), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
#print(x.st.views, w.st.views)
return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
.movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
elif x.device == "OPENCL":
# TODO: these can be properties on the device buffer
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
x,w,Cn = preprocessing_op(x, w, C)
ret = LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
return postprocessing_op(ret, Cn, C)
else:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
out_device, out_shape = srcs[0].device, srcs[0].shape

View File

@@ -95,6 +95,11 @@ class ShapeTracker:
exec(self.expr(), None, locals)
return locals["idx"] if locals["valid"] else -1
def strided(self, *arg):
view = View([x[0] for x in arg], [x[1] for x in arg])
if self.contiguous: self.views[-1] = view
else: self.views.append(view)
def reshape(self, *new_shape):
assert all([isinstance(x, int) for x in new_shape])
assert prod(self.shape) == prod(new_shape)

View File

@@ -120,8 +120,7 @@ class Tensor:
self.grad = Tensor.ones(*self.shape, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
continue
if not any(x.requires_grad for x in t0._ctx.parents): continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
@@ -136,20 +135,12 @@ class Tensor:
def __getitem__(self, val):
arg = []
new_shape = []
if val is not None:
for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
if isinstance(s, int):
arg.append((s, s + 1))
else:
arg.append((s.start if s.start is not None else 0,
(s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
new_shape.append(arg[-1][1] - arg[-1][0])
assert s.step is None or s.step == 1
new_shape += self.shape[len(arg):]
if len(new_shape) == 0: new_shape = (1,) # tinygrad doesn't support len 0 shapes
ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret
for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
if isinstance(s, int): s = slice(s, s+1, None)
arg.append((s.start if s.start is not None else 0,
(s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
assert s.step is None or s.step == 1
return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
# TODO: there has to be a cleaner way to write this
def cat(self, *args, dim=0):
@@ -209,12 +200,12 @@ class Tensor:
def sum(self, axis=None, keepdim=False):
axis, out_shape = self._canonicalize_reduce_axis(axis)
ret = self._sum(axis=axis)
return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
return ret if keepdim else ret.reshape(shape=out_shape)
def max(self, axis=None, keepdim=False):
axis, out_shape = self._canonicalize_reduce_axis(axis)
ret = self._max(axis=axis)
return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
return ret if keepdim else ret.reshape(shape=out_shape)
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
@@ -276,17 +267,10 @@ class Tensor:
@staticmethod
def broadcasted(fxn, x, y):
tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
if not isinstance(x, Tensor): x = Tensor([x], device=tt.device, requires_grad=False)
if not isinstance(y, Tensor): y = Tensor([y], device=tt.device, requires_grad=False)
n_dims = max(len(x.shape), len(y.shape))
if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))
shape_ret = tuple([max(sx, sy) for sx,sy in zip(x.shape, y.shape)])
if x.shape != shape_ret: x = x.expand(shape_ret)
if y.shape != shape_ret: y = y.expand(shape_ret)
return fxn(x, y)
x,y = [Tensor([t], device=tt.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in [x,y]]
x,y = [t.reshape(list(t.shape) + [1]*(max(len(x.shape), len(y.shape))-len(t.shape))) for t in [x,y]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn(x.expand(shape_ret), y.expand(shape_ret))
# TODO: are these the only ones that can take number arguments?
def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
@@ -301,8 +285,9 @@ class Tensor:
# ***** functional nn ops *****
# TODO: fix the kwargs problem, then remove these
def reshape(self, shape): return self._reshape(shape=shape)
def expand(self, shape): return self._expand(shape=shape)
# NOTE: perhaps don't, since they create NOOPs if the shape already matches
def reshape(self, shape): return self._reshape(shape=shape) if tuple(self.shape) != tuple(shape) else self
def expand(self, shape): return self._expand(shape=shape) if tuple(self.shape) != tuple(shape) else self
def linear(self, weight:Tensor, bias:Tensor):
shp = [1] * (len(self.shape)-1) + [-1]