diff --git a/README.md b/README.md
index 2b329405a8..bd48493bb3 100644
--- a/README.md
+++ b/README.md
@@ -112,18 +112,18 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### hlops (in tensor.py)
 
-hlops are syntactic sugar around mlops.
+hlops are syntactic sugar around mlops. They support most things torch does.
 
 ### mlops
 
-mlops are mid level ops, there's 13 of them. They understand memory allocation and derivatives
+mlops are mid level ops, there's 14 of them. They understand memory allocation and derivatives
 ```
-Relu, Log, Exp                          # unary ops
-Sum, Max                                # reduce ops (with axis argument)
-Add, Sub, Mul, Pow                      # binary ops (with broadcasting)
-Reshape, Permute, Slice                 # movement ops
-Conv2D(NCHW)                            # processing op (Matmul is also Conv2D)
+Relu, Log, Exp                          # unary ops
+Sum, Max                                # reduce ops (with axis argument)
+Add, Sub, Mul, Pow                      # binary ops (no broadcasting, use expand)
+Reshape, Permute, Slice, Expand         # movement ops
+Conv2D(NCHW)                            # processing op (Matmul is also Conv2D)
 ```
 
 You no longer need to write mlops for a new accelerator
@@ -136,8 +136,8 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 Buffer                                          # class of memory on this device
 unary_op  (RELU, EXP, LOG, NEG, SIGN)           # A -> A
 reduce_op (SUM, MAX)                            # A -> B (smaller size, B has 1 in shape)
-binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ)   # A + B -> C (broadcasting supported)
-movement_op (RESHAPE, PERMUTE, SLICE)           # A -> B (different size)
+binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ)      # A + B -> C (all the same size)
+movement_op (RESHAPE, PERMUTE, SLICE, EXPAND)   # A -> B (different size)
 processing_op (CONV, CONVT, CONVDW)             # A + B -> C
 ```
diff --git a/test/test_ops.py b/test/test_ops.py
index 134fc367e2..e840b2e8c7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -168,6 +168,10 @@ class TestOps(unittest.TestCase):
   def test_detach(self):
     helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
 
+  def test_expand(self):
+    arg = (4,3,2,6)
+    helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
+
   def test_simple_conv2d(self):
     helper_test_op([(1,1,9,9), (1,1,3,3)],
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
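The user-facing behavior is unchanged: hlops still broadcast, but the mechanism is now explicit. A minimal numpy sketch (the function is mine, not repo code) of the reshape -> expand -> same-shape-op pipeline that `Tensor.broadcasted` implements in tensor.py below; note this PR pads missing dims with *trailing* 1s, unlike numpy/torch, which pad on the left:

```python
import numpy as np

def broadcasted_add(x, y):
  n_dims = max(x.ndim, y.ndim)
  x = x.reshape(x.shape + (1,) * (n_dims - x.ndim))  # pad with trailing 1s
  y = y.reshape(y.shape + (1,) * (n_dims - y.ndim))
  shape_ret = tuple(max(a, b) for a, b in zip(x.shape, y.shape))
  x = np.broadcast_to(x, shape_ret)  # this is the new EXPAND movement op
  y = np.broadcast_to(y, shape_ret)
  return x + y                       # the binary llop only ever sees equal shapes

out = broadcasted_add(np.ones((4,3,1,6), dtype=np.float32),
                      np.array([2.0], dtype=np.float32))
assert out.shape == (4,3,1,6) and np.all(out == 3.0)
```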
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 6537b15774..194d7c9558 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -6,27 +6,6 @@ def prod(x): return int(np.prod(x))
 def reduce_shape(shape, axis): return [1 if i in axis else shape[i] for i in range(len(shape))]
 
-def binary_broadcast(x_shape, y_shape, extra=False):
-  n_dims = max(len(x_shape), len(y_shape))
-  shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
-  shape_x[:len(x_shape)] = np.array(x_shape, dtype=np.int32)
-  shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
-  if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
-    raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
-  shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])
-
-  if extra:
-    dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
-    def push(dim, comp):
-      if len(complist) > 0 and complist[-1] == comp:
-        dimlist[-1] *= dim
-      elif comp != (False, False):
-        dimlist.append(dim); complist.append(comp)
-    for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
-      push(np.int32(max(shape_x[i], shape_y[i])), (shape_x[i] > 1, shape_y[i] > 1))
-
-  return (shape_ret, dimlist, complist) if extra else shape_ret
-
 def get_conv_args(x_shape, w_shape, stride, groups):
   # TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
   conv_args = namedtuple('conv_args',
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index a7bb2eafc2..7c20b07836 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -11,6 +11,7 @@ class CPUBuffer(np.ndarray):
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
   def custompad(x, padding): return np.pad(x, padding)
+  def expand(x, new_shape): return np.broadcast_to(x, new_shape)
 
   @staticmethod
   def fromCPU(x): return x
@@ -30,7 +31,6 @@ def binary_op(op, x, y, ret):
   elif op == BinaryOps.MUL: ret[:] = x*y
   elif op == BinaryOps.DIV: ret[:] = y/x
   elif op == BinaryOps.POW: ret[:] = x**y
-  elif op == BinaryOps.A: ret[:] = x
   elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
   else: raise Exception(f"{op} isn't supported")
 
@@ -48,13 +48,14 @@ def reduce_op(op, inp, ret):
   else: raise Exception(f"{op} isn't supported")
 
 def movement_op(op, x, ret, arg=None):
-  if op == MovementOps.RESHAPE: ret[:] = x.reshape(ret.shape)
+  if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
   elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
   elif op == MovementOps.SLICE:
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     x = x.custompad(padding)
     slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
     ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
+  elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
   else: raise Exception(f"{op} isn't supported")
 
 def get_tx(x, C):
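On CPU the new llop is just `np.broadcast_to`, which returns a stride-0 view rather than copying until `ret[:] = ...` materializes it. A quick illustration of its contract (this snippet is mine, not part of the diff):

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(1, 3, 1, 2)
y = np.broadcast_to(x, (4, 3, 1, 2))   # only size-1 axes may grow
assert y.shape == (4, 3, 1, 2)
assert y.strides[0] == 0               # a view: the repeated axis has stride 0

try:                                   # growing a non-1 axis is rejected,
  np.broadcast_to(x, (4, 5, 1, 2))     # matching EXPAND's contract
except ValueError:
  pass
```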
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 2f29b613f0..d70cff3088 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -1,7 +1,7 @@
 import functools
 import numpy as np
 import pyopencl as cl
-from tinygrad.helpers import prod, binary_broadcast, get_conv_args
+from tinygrad.helpers import prod
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 
 cl_ctx, cl_queue = None, None
@@ -66,42 +66,24 @@ def unary_op(op, x, ret):
   unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
   return ret
 
-@functools.lru_cache
-def get_binop_prg(code, complist):
-  ndims = len(complist)
-  args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
-  compute_idx_rets = "".join([f"\n    int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
-
-  idx_exprs = ["0", "0"] # [idx_x, idx_y]
-  for i in range(ndims):
-    for j in range(2):
-      if complist[i][j]:
-        idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
-
-  dtype = ["float", "float", "float"]
-  prg = """__kernel void binop(__global const """+dtype[0]+""" *x_g, __global const """+dtype[1]+""" *y_g, __global """+dtype[2]+""" *res_g"""+args+""") {
-    int gid0 = get_global_id(0);"""+compute_idx_rets+"""
-    """+dtype[0]+""" a = x_g["""+idx_exprs[0]+"""];
-    """+dtype[1]+""" b = y_g["""+idx_exprs[1]+"""];
-    res_g[gid0] = """+code+""";\n}"""
-  return cl.Program(cl_ctx, prg).build(), dtype[2] == "float4"
-
 def binary_op(op, x, y, ret):
   if op == BinaryOps.ADD: code = "a+b"
   elif op == BinaryOps.SUB: code = "a-b"
   elif op == BinaryOps.MUL: code = "a*b"
   elif op == BinaryOps.DIV: code = "b/a"
   elif op == BinaryOps.POW: code = "pow(a,b)"
-  elif op == BinaryOps.A: code = "a"
-  elif op == BinaryOps.CMPEQ: code = "1.0f*(a==b)"
+  elif op == BinaryOps.CMPEQ: code = "(float4)(1.0f*(a.x==b.x), 1.0f*(a.y==b.y), 1.0f*(a.z==b.z), 1.0f*(a.w==b.w))"
   else: raise Exception(f"{op} isn't supported")
-
-  shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape, True)
-  assert tuple(shape_ret) == tuple(ret.shape)
-  prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
-  prg, is_float4 = get_binop_prg(code, tuple(complist))
-  kernel_size = ((roundup(prod_list[0])//4) if is_float4 else prod_list[0]) if len(dimlist) > 0 else 1
-  prg.binop(cl_queue, [kernel_size], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
+  assert x.shape == ret.shape and y.shape == ret.shape
+  binop = clbuild("binop", """
+  __kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
+    int gid = get_global_id(0);
+    float4 a = a_g[gid];
+    float4 b = b_g[gid];
+    res_g[gid] = """+code+""";
+  }""")
+  binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
+  return ret
 
 def reduce_op(op, inp, ret):
   if op == ReduceOps.SUM:
@@ -190,10 +172,39 @@ def inner_slice(x, arg, ret):
     buffer_np(np.array(ret.shape, dtype=np.int32)),
     buffer_np(np.array(shift, dtype=np.int32)))
 
+def expand(x, ret):
+  assert len(x.shape) == len(ret.shape)
+
+  dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
+  def push(dim, comp):
+    if len(complist) > 0 and complist[-1] == comp:
+      dimlist[-1] *= dim
+    elif comp != (False, False):
+      dimlist.append(dim); complist.append(comp)
+  for i,j in zip(x.shape, ret.shape): # group together any adjacent dimensions that we can to simplify broadcasting
+    push(np.int32(max(i,j)), (i > 1, j > 1))
+  prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
+
+  ndims = len(complist)
+  args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
+  compute_idx_rets = "".join([f"\n    int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
+
+  idx_exprs = ["0", "0"] # [idx_x, idx_ret]
+  for i in range(ndims):
+    for j in range(2):
+      if complist[i][j]:
+        idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
+
+  expandop = clbuild("expandop", """__kernel void expandop(__global const float *x_g, __global float *res_g"""+args+""") {
+    int gid0 = get_global_id(0);"""+compute_idx_rets+"""
+    res_g[gid0] = x_g["""+idx_exprs[0]+"""];\n}""")
+  expandop([prod_list[0] if len(dimlist) > 0 else 1], None, x.cl, ret.cl, *dimlist, *(prod_list[1:]))
+
 def movement_op(op, x, ret, arg=None):
   if op == MovementOps.RESHAPE: reshape(x, ret)
   elif op == MovementOps.PERMUTE: perm_axis(x, arg, ret)
   elif op == MovementOps.SLICE: inner_slice(x, arg, ret)
+  elif op == MovementOps.EXPAND: expand(x, ret)
 
 def conv(x,w,ret,C):
   # input = (bs, groups, cin, iy, ix)
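The index arithmetic the generated `expandop` kernel performs can be restated in plain Python. This sketch is my own and skips the adjacent-dimension grouping (which only fuses loops, it doesn't change the result): recover per-dimension coordinates from the flat output id with divide/modulo, then fold only the non-broadcast dimensions back into a flat input index, the same recurrence that builds `idx_exprs[0]`:

```python
import numpy as np

def expand_indices(x_shape, ret_shape):
  dims = [max(i, j) for i, j in zip(x_shape, ret_shape)]
  # p[d] = product of dims after d (the kernel's prod_list[1:], plus a final 1)
  p = [int(np.prod(dims[d+1:])) for d in range(len(dims))]
  src = []
  for gid in range(int(np.prod(dims))):
    idx_x = 0
    for d, x_dim in enumerate(x_shape):
      idx_ret = (gid // p[d]) % dims[d]   # compute_idx_rets
      if x_dim > 1:                       # complist[d][0]: x really has this dim
        idx_x = idx_ret + x_dim * idx_x   # idx_exprs[0] recurrence
    src.append(idx_x)                     # broadcast dims never advance idx_x
  return np.array(src)

x = np.arange(3, dtype=np.float32)        # flat data of a (1,3) buffer
src = expand_indices((1, 3), (2, 3))      # expand to (2,3)
assert np.array_equal(x[src].reshape(2, 3), np.broadcast_to(x.reshape(1, 3), (2, 3)))
```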
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index fd267d048e..17993c8d0f 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -44,9 +44,7 @@ class Sum(Function):
 
   def backward(ctx, grad_output):
     shape_input, = ctx.saved_tensors
-    # NOTE: the b Buffer isn't used, since this is just for broadcast
-    ret = ctx.buffer(shape_input)
-    return ctx.binary_op(BinaryOps.A, grad_output, ret)
+    return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)
 
 class Max(Function):
   def forward(ctx, input, axis=None):
@@ -56,36 +54,35 @@ def backward(ctx, grad_output):
     input, ret = ctx.saved_tensors
-    ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
-    div = ctx.reduce_op(ReduceOps.SUM, ret2, grad_output.shape)
-    ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
-    return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
+
+    # 1s in locations where the max was chosen (ties mean more than one location)
+    max_is_1s = ctx.binary_op(BinaryOps.CMPEQ, input, ctx.movement_op(MovementOps.EXPAND, ret, input.shape))
+
+    # sum of locations, averaged
+    div = ctx.reduce_op(ReduceOps.SUM, max_is_1s, grad_output.shape)
+    div = ctx.movement_op(MovementOps.EXPAND, div, input.shape)
+    max_is_amount = ctx.binary_op(BinaryOps.DIV, div, max_is_1s)
+
+    grad_output_expanded = ctx.movement_op(MovementOps.EXPAND, grad_output, input.shape)
+    return ctx.binary_op(BinaryOps.MUL, max_is_amount, grad_output_expanded)
 
 # ************* binary ops *************
 
-def unbroadcast(ctx, out, in_sh):
-  return ctx.reduce_op(ReduceOps.SUM, out, in_sh)
-
 class Add(Function):
   def forward(ctx, x, y):
-    ctx.save_for_backward(x.shape, y.shape)
     return ctx.binary_op(BinaryOps.ADD, x, y)
 
   def backward(ctx, grad_output):
-    shape_x, shape_y = ctx.saved_tensors
-    return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
-           unbroadcast(ctx, grad_output, shape_y) if ctx.needs_input_grad[1] else None
+    return grad_output if ctx.needs_input_grad[0] else None, \
+           grad_output if ctx.needs_input_grad[1] else None
 
 class Sub(Function):
   def forward(ctx, x, y):
-    ctx.save_for_backward(x.shape, y.shape)
     return ctx.binary_op(BinaryOps.SUB, x, y)
 
   def backward(ctx, grad_output):
-    shape_x, shape_y = ctx.saved_tensors
-    neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
-    return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
-           unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
+    return grad_output if ctx.needs_input_grad[0] else None, \
+           ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None
 
 class Mul(Function):
   def forward(ctx, x, y):
@@ -94,8 +91,8 @@ def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
+    grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
+    grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 class Pow(Function):
@@ -106,15 +103,28 @@ def backward(ctx, grad_output):
     x,y,powxy = ctx.saved_tensors
-    tmp = ctx.binary_op(BinaryOps.DIV, x, powxy)      # pow(x,y)/x
-    tmp = ctx.binary_op(BinaryOps.MUL, y, tmp)        # y * pow(x,y)/x
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
-    tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy)  # log(x) * pow(x,y)
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
+    grad_x, grad_y = None, None
+    if ctx.needs_input_grad[0]:
+      tmp = ctx.binary_op(BinaryOps.DIV, x, powxy)    # pow(x,y)/x
+      tmp = ctx.binary_op(BinaryOps.MUL, y, tmp)      # y * pow(x,y)/x
+      grad_x = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
+    if ctx.needs_input_grad[1]:
+      tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
+      grad_y = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
     return grad_x, grad_y
 
 # ************* movement ops *************
 
+# NOTE: this is sum in reverse
+class Expand(Function):
+  def forward(ctx, x, shape):
+    ctx.save_for_backward(x.shape)
+    return ctx.movement_op(MovementOps.EXPAND, x, shape)
+
+  def backward(ctx, grad_output):
+    in_shape, = ctx.saved_tensors
+    return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
+
 class Reshape(Function):
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
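Two of the rewritten backward passes are easy to sanity-check numerically: Expand's gradient is a sum over the expanded axes (hence the "sum in reverse" note), and Max's gradient is split evenly among tied maxima via the CMPEQ/SUM/DIV chain. A numpy check of both claims (mine, not from the test suite):

```python
import numpy as np

# Expand (2,1) -> (2,3): each input element feeds 3 outputs, so its gradient
# is the sum of the 3 upstream gradients.
grad_out = np.ones((2, 3), dtype=np.float32)
assert np.array_equal(grad_out.sum(axis=1, keepdims=True), np.full((2, 1), 3.0))

# Max with a tie: both maxima share the gradient equally.
x = np.array([1.0, 3.0, 3.0], dtype=np.float32)
max_is_1s = (x == x.max()).astype(np.float32)   # CMPEQ against the expanded max
grad_x = (max_is_1s / max_is_1s.sum()) * 1.0    # DIV by the count, MUL by grad
assert np.allclose(grad_x, [0.0, 0.5, 0.5])
```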
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 1312ed4e77..79a3cee697 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,9 +1,9 @@
 # TODO: move Device to here and proxy buffer call
 from enum import Enum
 UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
-BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
+BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
-MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
+MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND"])
 ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
 
 import os
@@ -45,7 +45,6 @@ def log_op(op, ret, inp):
     G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
     G.nodes[nm(ret)]['style'] = 'filled'
 
-from tinygrad.helpers import binary_broadcast
 class Ops:
   def unary_op(ctx, op:UnaryOps, x):
     ret = ctx.buffer(x.shape)
@@ -60,13 +59,14 @@ class Ops:
     return ret
 
-  def binary_op(ctx, op:ReduceOps, x, y):
-    ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
+  def binary_op(ctx, op:BinaryOps, x, y):
+    assert x.shape == y.shape
+    ret = ctx.buffer(x.shape)
     ctx.op.binary_op(op, x, y, ret)
-    log_op(op, ret, [x] if op == BinaryOps.A else [x, y])
+    log_op(op, ret, [x, y])
     return ret
 
   def movement_op(ctx, op:MovementOps, x, arg=None):
-    if op == MovementOps.RESHAPE: new_shape = arg
+    if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: new_shape = arg
     if op == MovementOps.PERMUTE: new_shape = [x.shape[i] for i in arg]
     if op == MovementOps.SLICE: new_shape = [y-x for x,y in arg]
     ret = ctx.buffer(new_shape)
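The shape bookkeeping `movement_op` now does is compact enough to restate standalone. This sketch (a hypothetical helper of mine, with strings standing in for the enum) shows why EXPAND can share the RESHAPE branch: both take the target shape directly as `arg`:

```python
def movement_shape(op, in_shape, arg):
  if op in ("RESHAPE", "EXPAND"): return list(arg)       # arg is the new shape
  if op == "PERMUTE": return [in_shape[i] for i in arg]  # arg is an axis order
  if op == "SLICE":   return [e - s for s, e in arg]     # arg is (start, end) per axis

assert movement_shape("EXPAND",  (4,3,1,6), (4,3,2,6)) == [4,3,2,6]
assert movement_shape("PERMUTE", (4,3,1,6), (0,1,3,2)) == [4,3,6,1]
assert movement_shape("SLICE",   (4,3,1,6), ((0,4),(0,3),(0,1),(1,5))) == [4,3,1,4]
```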
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 21094c9e69..e1dd3dc51a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -349,10 +349,34 @@ class Tensor:
     ret = x._conv2d(weight, stride=stride, groups=groups)
     return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
 
+  # ***** broadcasted binary ops *****
+
+  @staticmethod
+  def broadcasted(fxn, x, y):
+    tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0]  # this is the prototype tensor
+    if not isinstance(x, Tensor): x = Tensor(np.array([x], dtype=tt.dtype), device=tt.device, requires_grad=False)
+    if not isinstance(y, Tensor): y = Tensor(np.array([y], dtype=tt.dtype), device=tt.device, requires_grad=False)
+
+    n_dims = max(len(x.shape), len(y.shape))
+    if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
+    if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))
+
+    shape_ret = tuple([int(x) for x in np.maximum(x.shape, y.shape)])
+    if x.shape != shape_ret: x = x.expand(shape_ret)
+    if y.shape != shape_ret: y = y.expand(shape_ret)
+    return fxn(x, y)
+
+  # TODO: are these the only ones that can take number arguments?
+  def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
+  def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
+  def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
+  def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
+
   # ***** functional nn ops *****
 
-  def reshape(self, shape):
-    return self._reshape(shape=shape)
+  # TODO: fix the kwargs problem
+  def reshape(self, shape): return self._reshape(shape=shape)
+  def expand(self, shape): return self._expand(shape=shape)
 
   def linear(self, weight, bias):
     shp = [1] * (len(self.shape)-1) + [-1]
@@ -391,13 +415,8 @@ class Function(Ops):
 
   @classmethod
   def apply(cls, *x, **kwargs):
-    tt = [arg for arg in x if isinstance(arg, Tensor)][0]  # this is the prototype tensor
-
-    # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
-
-    ctx = cls(tt.device, *x)
+    assert all([isinstance(arg, Tensor) for arg in x]), "all arguments must be Tensors"
+    ctx = cls(x[0].device, *x)
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
       ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
                    device=ctx.device, requires_grad=ctx.requires_grad)
@@ -419,6 +438,7 @@ for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), i
   if name[0] != "_" and name != "Function" and not name.endswith("Ops"):
     register(name.lower(), cls)
 
 # register the operators
+# TODO: add div
 def register_op(name, fxn):
   setattr(Tensor, f"__{name}__", fxn)
   setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
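End to end, the hlop surface keeps torch-style broadcasting while every llop underneath sees fixed shapes. An illustrative usage sketch (mine; it assumes the default CPU device and that the operator registration at the bottom of tensor.py binds `__add__` to the broadcasted add):

```python
import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(np.ones((4,3,1,6), dtype=np.float32))
b = a.expand((4,3,2,6))      # explicit, differentiable movement op
c = a + 2                    # numbers get wrapped by Tensor.broadcasted
d = a.mul(Tensor(np.ones((4,1,2,6), dtype=np.float32)))  # size-1 axes expand to match
assert b.shape == (4,3,2,6) and c.shape == a.shape and d.shape == (4,3,2,6)
```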