Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-23 05:48:08 -05:00
A conv is a reduce op (#356)
* universal strided conv
* more correct
* hmm, CPU works
* cleaner cl code output
* make noconv a flag
* cleanup __getitem__
* refactor broadcasting
* put that back
* unneeded reshape in getitem
* fix strided for torch
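The idea in the title, that a convolution is just a strided view of the input, an elementwise multiply against the weight, and a sum-reduce, can be sketched outside the codebase. The snippet below is an illustrative NumPy sketch, not tinygrad code: conv2d_as_reduce is a made-up name, sliding_window_view stands in for the STRIDED movement op, and groups, padding and dilation are left out:

    import numpy as np

    def conv2d_as_reduce(x, w, stride=1):
      # x: (bs, cin, iy, ix), w: (cout, cin, H, W) -- plain, ungrouped, unpadded conv
      bs, cin, iy, ix = x.shape
      cout, _, H, W = w.shape
      # strided view of the input: one (H, W) patch per output pixel and input channel
      windows = np.lib.stride_tricks.sliding_window_view(x, (H, W), axis=(2, 3))  # (bs, cin, oy, ox, H, W)
      windows = windows[:, :, ::stride, ::stride]
      # broadcasted multiply, then a sum-reduce over cin, H, W -- that is the whole conv
      prod = windows[:, None] * w[None, :, :, None, None]  # (bs, cout, cin, oy, ox, H, W)
      return prod.sum(axis=(2, 5, 6))                      # (bs, cout, oy, ox)
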
@@ -72,10 +72,3 @@ def postprocessing_op(ret, C, C_initial):
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
return ret

def processed_conv(x, w, C):
x,w,Cn = preprocessing_op(x,w,C)
# precompute the weight
w.realize().image
ret = x.processing_op(ProcessingOps.CONV, w, Cn)
return postprocessing_op(ret, Cn, C)

@@ -19,6 +19,8 @@ class CPUBuffer(np.ndarray):
def permute(x, order): return x.transpose(order)
def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[x*4 for x in stride]).view(CPUBuffer)
def contiguous(x): return x.ravel().reshape(x.shape)

@staticmethod
def fromCPU(x): return x.view(CPUBuffer)

@@ -42,6 +44,7 @@ class CPUBuffer(np.ndarray):
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
elif op == MovementOps.EXPAND: return x.expand(arg)
elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])

def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"

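An aside on the new CPU STRIDED op above: np.lib.stride_tricks.as_strided takes strides in bytes, which is where the x*4 comes from. It assumes a contiguous float32 buffer (itemsize 4), presumably also why contiguous() is called first. A minimal sketch of the same arithmetic:

    import numpy as np

    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    elem_strides = (4, 1)  # row-major element strides of a (3, 4) array
    # multiplying element strides by itemsize (4 for float32) gives the byte strides as_strided wants
    v = np.lib.stride_tricks.as_strided(x, shape=(3, 4), strides=[s * x.itemsize for s in elem_strides])
    assert (v == x).all()
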
@@ -127,7 +127,7 @@ class GPUBuffer:
if C is not None: # this is a conv
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
global_size = [C.bs*C.cout, C.oy, C.ox]
global_size = [C.bs*C.cout, C.oy, C.ox] # [nGk, h, w]
assert ret.shape == C.out_shape, "output shape is wrong (NOTE: you can't reduce and conv together)"

# now input and weight can be anywhere in bufs

@@ -157,13 +157,13 @@ class GPUBuffer:
kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
conv_prg = CLProgram(kernel_name, f"""{chr(13).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
{''.join([ls for ls, _ in loop[::-1]])}
{''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
{' '.join([ls for ls, _ in loop[::-1]])}
{chr(13).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
acc = {code};
{''.join([le for _, le in loop])}
{' '.join([le for _, le in loop])}
output[gid] = acc;
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])

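One of the commit-message bullets is "cleaner cl code output", which is what the ''.join to chr(13).join changes above are about: the generated lines get joined with a control character (chr(13) is a carriage return, chr(10) a line feed) rather than concatenated. A small illustration of the pattern, written with chr(10) since f-string expressions could not contain a backslash escape before Python 3.12; this is a hypothetical generated-kernel snippet, not tinygrad code:

    lines = ["float a = x[gid];", "float b = a * a;"]
    src = f"__kernel void k(__global float *x) {{ {chr(10).join(lines)} }}"
    print(src)  # the two statements now land on separate lines
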
@@ -156,18 +156,10 @@ class Flip(Function):
# ************* processing ops *************

class Conv2D(Function):
def _conv(ctx, x, w, C):
# TODO: this does NOT belong here
if x.device == "OPENCL":
from accel.opencl.preprocessing import processed_conv # type: ignore
return processed_conv(x, w, C)
else:
return x.processing_op(ProcessingOps.CONV, w, C)

def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
ctx.save_for_backward(x,w)
return ctx._conv(x, w, ctx.C)
return x.processing_op(ProcessingOps.CONV, w, ctx.C)

def backward(ctx, grad_output):
x, w = ctx.saved_tensors

@@ -176,7 +168,7 @@ class Conv2D(Function):

if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
xt = grad_output
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides. (but only when we contiguous it)
xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))

@@ -184,13 +176,13 @@ class Conv2D(Function):
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
Cdx = get_conv_args(xt.shape, wt.shape, out_shape=x.shape, dilation=(C.dy, C.dx), padding=(py, px), groups=C.groups)
dx = ctx._conv(xt, wt, Cdx)
dx = xt.processing_op(ProcessingOps.CONV, wt, Cdx)

if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)).movement_op(MovementOps.PERMUTE, (2, 1, 0, 3, 4))
xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, out_shape=(w.shape[1], w.shape[0], w.shape[2], w.shape[3]), padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
dw = xdw.processing_op(ProcessingOps.CONV, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))

return dx, dw

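The "unstride" step in the backward pass above spreads grad_output out by the stride, inserting sy-1 and sx-1 zeros after each element before the transposed conv runs. The same reshape / zero-pad / reshape trick as a standalone NumPy sketch (the function name is mine):

    import numpy as np

    def unstride(g, sy, sx):
      # g: (bs, c, oy, ox) -> (bs, c, oy*sy, ox*sx) with zeros between the original elements
      bs, c, oy, ox = g.shape
      g = g.reshape(bs, c, oy, 1, ox, 1)
      # grow the two singleton axes to (sy, sx); the SLICE op above does this by zero-padding
      g = np.pad(g, ((0,0), (0,0), (0,0), (0,sy-1), (0,0), (0,sx-1)))
      return g.reshape(bs, c, oy*sy, ox*sx)

    g = np.arange(4.).reshape(1, 1, 2, 2)
    print(unstride(g, 2, 2)[0, 0])
    # [[0. 0. 1. 0.]
    #  [0. 0. 0. 0.]
    #  [2. 0. 3. 0.]
    #  [0. 0. 0. 0.]]
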
@@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU"])

@@ -23,6 +23,7 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
DEBUG = int(os.getenv("DEBUG", "0"))
GRAPH = int(os.getenv("GRAPH", "0"))
OPT = int(os.getenv("OPT", "1"))
NOCONV = int(os.getenv("NOCONV", "0"))

# TODO: movement ops that only change shape are really nops. treat them as such
MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1

@@ -79,7 +80,8 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
G.add_edge(nm(x), nm(ret), label=sop)
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
if getattr(ret, "st", None) is not None and not ret.st.contiguous: # checked twice to make type checker happy

if getattr(ret, "st", None) is not None and not ret.st.contiguous:
G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
dashed = True
else:

@@ -245,7 +247,8 @@ class LazyBuffer:
ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps,
LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))

if REMOVE_MOVEMENT_NOPS and x.realized is None and ret.st.contiguous:
# NOTE: if ret is in the cache, it can already be realized
if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
root = get_lazybuffers(ret.op)[0]
if ret.st.shape == root.shape and root.st.contiguous:
return root

@@ -253,7 +256,27 @@ class LazyBuffer:
return ret

def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
if NOCONV:
# universal conv, just mul and reduce
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
# TODO: is there any way to replace strided with other movement ops?
x = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
(C.rcout, 0), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
#print(x.st.views, w.st.views)
return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
.movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
elif x.device == "OPENCL":
# TODO: these can be properties on the device buffer
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
x,w,Cn = preprocessing_op(x, w, C)
ret = LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
return postprocessing_op(ret, Cn, C)
else:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))

def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
out_device, out_shape = srcs[0].device, srcs[0].shape

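Note the (C.rcout, 0) pair in the STRIDED argument above: a stride of 0 repeats the same data along a dimension, which is how the view broadcasts the input across the rcout output channels without copying. The same effect in plain NumPy:

    import numpy as np

    x = np.arange(3, dtype=np.float32)
    # stride 0 revisits the same memory, so the view is 4 identical rows [0. 1. 2.]
    v = np.lib.stride_tricks.as_strided(x, shape=(4, 3), strides=(0, x.itemsize))
    print(v)
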
@@ -95,6 +95,11 @@ class ShapeTracker:
exec(self.expr(), None, locals)
return locals["idx"] if locals["valid"] else -1

def strided(self, *arg):
view = View([x[0] for x in arg], [x[1] for x in arg])
if self.contiguous: self.views[-1] = view
else: self.views.append(view)

def reshape(self, *new_shape):
assert all([isinstance(x, int) for x in new_shape])
assert prod(self.shape) == prod(new_shape)

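The new ShapeTracker.strided above just records a View built from (size, stride) pairs, and the expression such a view encodes boils down to a dot product of the index with the strides. A tiny standalone sketch of that arithmetic (names are mine; offsets and validity masks are ignored):

    def flat_index(idx, view):
      # view: list of (size, stride) pairs, idx: one coordinate per dimension
      assert all(0 <= i < size for i, (size, _) in zip(idx, view))
      return sum(i * stride for i, (_, stride) in zip(idx, view))

    # a (3, 4) row-major view has strides (4, 1); element (2, 1) sits at flat offset 9
    print(flat_index((2, 1), [(3, 4), (4, 1)]))  # 9
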
@@ -120,8 +120,7 @@ class Tensor:
self.grad = Tensor.ones(*self.shape, device=self.device, requires_grad=False)

for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
continue
if not any(x.requires_grad for x in t0._ctx.parents): continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None

@@ -136,20 +135,12 @@ class Tensor:

def __getitem__(self, val):
arg = []
new_shape = []
if val is not None:
for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
if isinstance(s, int):
arg.append((s, s + 1))
else:
arg.append((s.start if s.start is not None else 0,
(s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
new_shape.append(arg[-1][1] - arg[-1][0])
assert s.step is None or s.step == 1
new_shape += self.shape[len(arg):]
if len(new_shape) == 0: new_shape = (1,) # tinygrad doesn't support len 0 shapes
ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret
for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
if isinstance(s, int): s = slice(s, s+1, None)
arg.append((s.start if s.start is not None else 0,
(s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
assert s.step is None or s.step == 1
return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])

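The __getitem__ rewrite above folds int indices into one-element slices, normalizes every slice to a (start, stop) pair, keeps any trailing unindexed dimensions whole, and drops the old reshape at the end. The same normalization as a standalone sketch (the function name is mine):

    def normalize_indices(val, shape):
      # ints become one-element slices, negative stops wrap around,
      # and unindexed trailing dims are kept whole
      arg = []
      for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
        if isinstance(s, int): s = slice(s, s+1, None)
        assert s.step is None or s.step == 1
        start = s.start if s.start is not None else 0
        stop = (s.stop if s.stop >= 0 else shape[i]+s.stop) if s.stop is not None else shape[i]
        arg.append((start, stop))
      return arg + [(0, shape[i]) for i in range(len(arg), len(shape))]

    print(normalize_indices((1, slice(0, -1)), (4, 5, 6)))  # [(1, 2), (0, 4), (0, 6)]
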
# TODO: there has to be a cleaner way to write this
def cat(self, *args, dim=0):

@@ -209,12 +200,12 @@ class Tensor:
def sum(self, axis=None, keepdim=False):
axis, out_shape = self._canonicalize_reduce_axis(axis)
ret = self._sum(axis=axis)
return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
return ret if keepdim else ret.reshape(shape=out_shape)

def max(self, axis=None, keepdim=False):
axis, out_shape = self._canonicalize_reduce_axis(axis)
ret = self._max(axis=axis)
return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
return ret if keepdim else ret.reshape(shape=out_shape)

def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)

@@ -276,17 +267,10 @@ class Tensor:
@staticmethod
def broadcasted(fxn, x, y):
tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
if not isinstance(x, Tensor): x = Tensor([x], device=tt.device, requires_grad=False)
if not isinstance(y, Tensor): y = Tensor([y], device=tt.device, requires_grad=False)

n_dims = max(len(x.shape), len(y.shape))
if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))

shape_ret = tuple([max(sx, sy) for sx,sy in zip(x.shape, y.shape)])
if x.shape != shape_ret: x = x.expand(shape_ret)
if y.shape != shape_ret: y = y.expand(shape_ret)
return fxn(x, y)
x,y = [Tensor([t], device=tt.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in [x,y]]
x,y = [t.reshape(list(t.shape) + [1]*(max(len(x.shape), len(y.shape))-len(t.shape))) for t in [x,y]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn(x.expand(shape_ret), y.expand(shape_ret))

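One detail of the broadcasted rewrite above: the shorter shape is right-padded with 1s (they are appended) before both tensors are expanded to the elementwise maximum shape, which differs from NumPy's left-padding convention. Just the shape arithmetic, as a sketch (the function name is mine):

    def broadcast_shape(xs, ys):
      # right-pad the shorter shape with 1s, then take the elementwise max
      n = max(len(xs), len(ys))
      xs, ys = list(xs) + [1]*(n-len(xs)), list(ys) + [1]*(n-len(ys))
      return tuple(max(a, b) for a, b in zip(xs, ys))

    print(broadcast_shape((3, 1, 5), (3, 4)))  # (3, 4, 5)
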
# TODO: are these the only ones that can take number arguments?
def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)

@@ -301,8 +285,9 @@ class Tensor:
# ***** functional nn ops *****

# TODO: fix the kwargs problem, then remove these
def reshape(self, shape): return self._reshape(shape=shape)
def expand(self, shape): return self._expand(shape=shape)
# NOTE: perhaps don't, since they create NOOPs if the shape already matches
def reshape(self, shape): return self._reshape(shape=shape) if tuple(self.shape) != tuple(shape) else self
def expand(self, shape): return self._expand(shape=shape) if tuple(self.shape) != tuple(shape) else self

def linear(self, weight:Tensor, bias:Tensor):
shp = [1] * (len(self.shape)-1) + [-1]