diff --git a/accel/opencl/preprocessing.py b/accel/opencl/preprocessing.py
index a8ffdc9af6..331f00d372 100644
--- a/accel/opencl/preprocessing.py
+++ b/accel/opencl/preprocessing.py
@@ -72,10 +72,3 @@ def postprocessing_op(ret, C, C_initial):
   ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
   ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
   return ret
-
-def processed_conv(x, w, C):
-  x,w,Cn = preprocessing_op(x,w,C)
-  # precompute the weight
-  w.realize().image
-  ret = x.processing_op(ProcessingOps.CONV, w, Cn)
-  return postprocessing_op(ret, Cn, C)
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 6226d88afe..8c0ebb0f38 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -19,6 +19,8 @@ class CPUBuffer(np.ndarray):
   def permute(x, order): return x.transpose(order)
   def custompad(x, padding): return np.pad(x, padding).view(CPUBuffer) if any(x != 0 or y != 0 for x,y in padding) else x
   def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
+  def as_strided(x, size, stride): return np.lib.stride_tricks.as_strided(x, shape=size, strides=[x*4 for x in stride]).view(CPUBuffer)
+  def contiguous(x): return x.ravel().reshape(x.shape)
 
   @staticmethod
   def fromCPU(x): return x.view(CPUBuffer)
@@ -42,6 +44,7 @@ class CPUBuffer(np.ndarray):
       padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
       return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
     elif op == MovementOps.EXPAND: return x.expand(arg)
+    elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
 
   def processing_op(x,op,w,C):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index a1268a201a..0ed6f57055 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -127,7 +127,7 @@ class GPUBuffer:
     if C is not None: # this is a conv
      ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
      params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
-     global_size = [C.bs*C.cout, C.oy, C.ox]
+     global_size = [C.bs*C.cout, C.oy, C.ox] # [nGk, h, w]
      assert ret.shape == C.out_shape, "output shape is wrong (NOTE: you can't reduce and conv together)"
 
      # now input and weight can be anywhere in bufs
@@ -157,13 +157,13 @@ class GPUBuffer:
     kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
-    conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
+    conv_prg = CLProgram(kernel_name, f"""{chr(13).join([x[0] for x in views.values()])}
     __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{
       {ints} float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
-      {''.join([ls for ls, _ in loop[::-1]])}
-      {''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
+      {' '.join([ls for ls, _ in loop[::-1]])}
+{chr(13).join([f'      float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
       acc = {code};
-      {''.join([le for _, le in loop])}
+      {' '.join([le for _, le in loop])}
       output[gid] = acc;
     }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
     conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index b49057d05c..6cda7e318a 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -156,18 +156,10 @@ class Flip(Function):
 # ************* processing ops *************
 
 class Conv2D(Function):
-  def _conv(ctx, x, w, C):
-    # TODO: this does NOT belong here
-    if x.device == "OPENCL":
-      from accel.opencl.preprocessing import processed_conv # type: ignore
-      return processed_conv(x, w, C)
-    else:
-      return x.processing_op(ProcessingOps.CONV, w, C)
-
   def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
     ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
     ctx.save_for_backward(x,w)
-    return ctx._conv(x, w, ctx.C)
+    return x.processing_op(ProcessingOps.CONV, w, ctx.C)
 
   def backward(ctx, grad_output):
     x, w = ctx.saved_tensors
@@ -176,7 +168,7 @@ class Conv2D(Function):
     if ctx.needs_input_grad[0]:
       # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
-      if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
+      if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides. (but only when we contiguous it)
         xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
         xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
         xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
@@ -184,13 +176,13 @@ class Conv2D(Function):
       wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
       py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
       Cdx = get_conv_args(xt.shape, wt.shape, out_shape=x.shape, dilation=(C.dy, C.dx), padding=(py, px), groups=C.groups)
-      dx = ctx._conv(xt, wt, Cdx)
+      dx = xt.processing_op(ProcessingOps.CONV, wt, Cdx)
 
     if ctx.needs_input_grad[1]:
       # compute derivative of weights using ProcessingOps.CONV
       xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)).movement_op(MovementOps.PERMUTE, (2, 1, 0, 3, 4))
       xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
       grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
       Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, out_shape=(w.shape[1], w.shape[0], w.shape[2], w.shape[3]), padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
-      dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
+      dw = xdw.processing_op(ProcessingOps.CONV, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
     return dx, dw
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5570bc3ef0..94aa6ae527 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
 UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
-MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
+MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED"])
 ProcessingOps = Enum("ProcessingOps", ["CONV"])
 LoadOps = Enum("LoadOps", ["FROMCPU"])
 
@@ -23,6 +23,7 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
 DEBUG = int(os.getenv("DEBUG", "0"))
 GRAPH = int(os.getenv("GRAPH", "0"))
 OPT = int(os.getenv("OPT", "1"))
+NOCONV = int(os.getenv("NOCONV", "0"))
 
 # TODO: movement ops that only change shape are really nops. treat them as such
 MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1
@@ -79,7 +80,8 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
       G.add_edge(nm(x), nm(ret), label=sop)
       if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
     if nm(ret) not in G.nodes: G.add_node(nm(ret))
-    if getattr(ret, "st", None) is not None and not ret.st.contiguous: # checked twice to make type checker happy
+
+    if getattr(ret, "st", None) is not None and not ret.st.contiguous:
       G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
       dashed = True
     else:
@@ -245,7 +247,8 @@ class LazyBuffer:
     ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps,
       LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
 
-    if REMOVE_MOVEMENT_NOPS and x.realized is None and ret.st.contiguous:
+    # NOTE: if ret is in the cache, it can already be realized
+    if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
       root = get_lazybuffers(ret.op)[0]
       if ret.st.shape == root.shape and root.st.contiguous:
         return root
@@ -253,7 +256,27 @@ class LazyBuffer:
     return ret
 
   def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
-    return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
+    if NOCONV:
+      # universal conv, just mul and reduce
+      x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
+      # TODO: is there any way to replace strided with other movement ops?
+      x = x.movement_op(MovementOps.STRIDED, (
+        (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
+        (C.rcout, 0), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
+        (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
+      w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
+        .movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
+      #print(x.st.views, w.st.views)
+      return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
+        .movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
+    elif x.device == "OPENCL":
+      # TODO: these can be properties on the device buffer
+      from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
+      x,w,Cn = preprocessing_op(x, w, C)
+      ret = LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
+      return postprocessing_op(ret, Cn, C)
+    else:
+      return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
 
   def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
     out_device, out_shape = srcs[0].device, srcs[0].shape
diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py
index fff3b6f5c3..af958d828b 100644
--- a/tinygrad/shapetracker.py
+++ b/tinygrad/shapetracker.py
@@ -95,6 +95,11 @@ class ShapeTracker:
     exec(self.expr(), None, locals)
     return locals["idx"] if locals["valid"] else -1
 
+  def strided(self, *arg):
+    view = View([x[0] for x in arg], [x[1] for x in arg])
+    if self.contiguous: self.views[-1] = view
+    else: self.views.append(view)
+
   def reshape(self, *new_shape):
     assert all([isinstance(x, int) for x in new_shape])
     assert prod(self.shape) == prod(new_shape)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1e61f1d581..7688c9e0ca 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -120,8 +120,7 @@ class Tensor:
       self.grad = Tensor.ones(*self.shape, device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk()):
-      if not any(x.requires_grad for x in t0._ctx.parents):
-        continue
+      if not any(x.requires_grad for x in t0._ctx.parents): continue
       assert (t0.grad is not None)
       grads = t0._ctx.backward(t0.grad.lazydata)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
@@ -136,20 +135,12 @@ class Tensor:
 
   def __getitem__(self, val):
     arg = []
-    new_shape = []
-    if val is not None:
-      for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
-        if isinstance(s, int):
-          arg.append((s, s + 1))
-        else:
-          arg.append((s.start if s.start is not None else 0,
-            (s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
-          new_shape.append(arg[-1][1] - arg[-1][0])
-        assert s.step is None or s.step == 1
-    new_shape += self.shape[len(arg):]
-    if len(new_shape) == 0: new_shape = (1,) # tinygrad doesn't support len 0 shapes
-    ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
-    return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret
+    for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
+      if isinstance(s, int): s = slice(s, s+1, None)
+      arg.append((s.start if s.start is not None else 0,
+        (s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
+      assert s.step is None or s.step == 1
+    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
 
   # TODO: there has to be a cleaner way to write this
   def cat(self, *args, dim=0):
@@ -209,12 +200,12 @@ class Tensor:
   def sum(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._sum(axis=axis)
-    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
+    return ret if keepdim else ret.reshape(shape=out_shape)
 
   def max(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._max(axis=axis)
-    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
+    return ret if keepdim else ret.reshape(shape=out_shape)
 
   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
@@ -276,17 +267,10 @@ class Tensor:
   @staticmethod
   def broadcasted(fxn, x, y):
     tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
-    if not isinstance(x, Tensor): x = Tensor([x], device=tt.device, requires_grad=False)
-    if not isinstance(y, Tensor): y = Tensor([y], device=tt.device, requires_grad=False)
-
-    n_dims = max(len(x.shape), len(y.shape))
-    if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
-    if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))
-
-    shape_ret = tuple([max(sx, sy) for sx,sy in zip(x.shape, y.shape)])
-    if x.shape != shape_ret: x = x.expand(shape_ret)
-    if y.shape != shape_ret: y = y.expand(shape_ret)
-    return fxn(x, y)
+    x,y = [Tensor([t], device=tt.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in [x,y]]
+    x,y = [t.reshape(list(t.shape) + [1]*(max(len(x.shape), len(y.shape))-len(t.shape))) for t in [x,y]]
+    shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
+    return fxn(x.expand(shape_ret), y.expand(shape_ret))
 
   # TODO: are these the only ones that can take number arguments?
   def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
@@ -301,8 +285,9 @@ class Tensor:
   # ***** functional nn ops *****
 
   # TODO: fix the kwargs problem, then remove these
-  def reshape(self, shape): return self._reshape(shape=shape)
-  def expand(self, shape): return self._expand(shape=shape)
+  # NOTE: perhaps don't, since they create NOOPs if the shape already matches
+  def reshape(self, shape): return self._reshape(shape=shape) if tuple(self.shape) != tuple(shape) else self
+  def expand(self, shape): return self._expand(shape=shape) if tuple(self.shape) != tuple(shape) else self
 
   def linear(self, weight:Tensor, bias:Tensor):
     shp = [1] * (len(self.shape)-1) + [-1]
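
Note on the NOCONV path in tinygrad/ops.py above: the "universal conv" is pad, then a strided 8-D view of the input with shape (bs, groups, rcout, oy, ox, cin, H, W), a broadcast multiply against the reshaped weight, and a sum over (cin, H, W). Below is a minimal standalone NumPy sketch of the same idea, not tinygrad code; the helper name conv_via_strides and the symmetric-padding simplification are illustrative assumptions, while the shape/stride tuple mirrors the MovementOps.STRIDED argument in the patch.

import numpy as np

def conv_via_strides(x, w, stride=1, dilation=1, padding=0, groups=1):
  # x: (bs, groups*cin, iy, ix), w: (groups*rcout, cin, H, W)
  bs, cin_total, _, _ = x.shape
  cout, cin, H, W = w.shape
  rcout = cout // groups
  assert cin_total == groups*cin
  x = np.pad(x, ((0,0), (0,0), (padding,padding), (padding,padding)))  # fresh contiguous buffer (symmetric padding only)
  iy, ix = x.shape[2], x.shape[3]                                      # padded spatial dims
  oy = (iy - dilation*(H-1) - 1)//stride + 1
  ox = (ix - dilation*(W-1) - 1)//stride + 1
  e = x.itemsize  # as_strided takes byte strides; the CPU backend above hardcodes x*4 for float32
  # zero-copy 8-D view; the rcout axis has stride 0, i.e. it is broadcast, matching (C.rcout, 0) in the STRIDED tuple
  gx = np.lib.stride_tricks.as_strided(x,
    shape=(bs, groups, rcout, oy, ox, cin, H, W),
    strides=(groups*cin*iy*ix*e, cin*iy*ix*e, 0, stride*ix*e, stride*e, iy*ix*e, dilation*ix*e, dilation*e))
  gw = w.reshape(1, groups, rcout, 1, 1, cin, H, W)
  # mul + reduce over (cin, H, W), then fold (groups, rcout) back into cout
  return (gx * gw).sum(axis=(5, 6, 7)).reshape(bs, cout, oy, ox)

The elementwise product materializes the full (bs, groups, rcout, oy, ox, cin, H, W) tensor, which is why this path is simple but memory-hungry; its output should line up with a reference conv2d for the same stride/dilation/padding/groups.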