From 2e71ae33f608be625df1126d68b716bc512567a1 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 17 Jun 2021 17:01:21 -0700 Subject: [PATCH] max op works --- README.md | 1 - extra/cherry.py | 34 +++++++++++++++++++--------------- extra/ops_cherry.py | 11 ++++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 98a27146e4..31ec64f2d2 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,6 @@ PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload. ``` * ~~Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`~~ -* Fix max op * Switch convolution backward pass to CHERRY instead of the numpy placeholder * Confirm EfficientNet backward pass fully uses CHERRY instructions * Benchmark that and transformers diff --git a/extra/cherry.py b/extra/cherry.py index c2a6a001eb..5156ce818c 100755 --- a/extra/cherry.py +++ b/extra/cherry.py @@ -148,12 +148,12 @@ def riski_pow(): regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS] @count -def riski_reduce_sum(out=0, cnt=SZ): - regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0) +def riski_reduce_sum(cnt=SZ): + regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0) @count -def riski_reduce_max(out=0, cnt=SZ): - regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0) +def riski_reduce_max(cnt=SZ): + regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0) # TODO: make accumulate a bit in the instruction available to all binops = {BinaryOps.ADD: riski_add, @@ -232,7 +232,7 @@ def cherry_dmaw(address, shp): # *** CHERRY code to be compiled *** -def cherry_reduceop(inp, op, axis): +def cherry_reduceop(inp, op, axis, keepdims=False): dimlist, redlist = [], [] if type(axis) == int: axis = [axis] @@ -265,11 +265,13 @@ def cherry_reduceop(inp, op, axis): else: dimlist.append(inp.shape[i]) redlist.append(is_reduce_axis) - nosize = [] - for i in range(osize.shape[0]): - if i not in axis: - nosize.append(osize[i]) - osize = nosize + + if not keepdims: + nosize = [] + for i in range(osize.shape[0]): + if i not in axis: + nosize.append(osize[i]) + osize = nosize osize = tuple(osize) print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist) @@ -280,20 +282,21 @@ def cherry_reduceop(inp, op, axis): # redlist is always [False, True, False, True, ...., True, False] # special case if redlist ends with True - if redlist[-1] == True: + if len(redlist) > 0 and redlist[-1] == True: print("special case redlist[-1] == True") outside = int(np.prod(dimlist[:-1])) for l in range(0, outside, SZ): reduce_size = min(SZ, outside-l) j = 0 while j < dimlist[-1]: + len_y = min(SZ if j == 0 else SZ-1, dimlist[-1]-j) riski_load(Reg.MATMUL_INPUT, SLOT(inslot) + l*dimlist[-1] + j, stride_y=1, stride_x=dimlist[-1], - len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j), + len_y=len_y, len_x=reduce_size, zero=j==0, skip_first=j!=0) - reduceops[op]() + reduceops[op](len_y+(j!=0)) riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row j += SZ if j == 0 else SZ-1 riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size) @@ -311,14 +314,15 @@ def cherry_reduceop(inp, op, axis): reduce_size = min(SZ, dimlist[-1]-k) j = 0 while j < dimlist[-2]: + len_y = min(SZ if j == 0 else SZ-1, dimlist[-2]-j) riski_load(Reg.MATMUL_INPUT, SLOT(inslot) + l*dimlist[-2]*dimlist[-1] + j*dimlist[-1] + k, stride_y=dimlist[-1], stride_x=1, - len_y=min(SZ if j == 0 else SZ-1, dimlist[-2]-j), + len_y=len_y, len_x=reduce_size, zero=j==0, skip_first=j!=0) #cherry_regdump() - reduceops[op]() + reduceops[op](len_y+(j!=0)) riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row j += SZ if j == 0 else SZ-1 riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size) diff --git a/extra/ops_cherry.py b/extra/ops_cherry.py index 0a5e957198..7a06eed4be 100644 --- a/extra/ops_cherry.py +++ b/extra/ops_cherry.py @@ -45,11 +45,11 @@ class Sum(Function): shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))] return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD) -""" class Max(Function): def forward(ctx, inp, axis=None): if isinstance(axis, int): axis = [axis] - ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True) + #ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True) + ret = cherry_reduceop(inp, ReduceOps.MAX, None if axis is None else tuple(axis), keepdims=True) ctx.save_for_backward(inp, axis, ret) if axis is not None: ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis]) @@ -59,9 +59,10 @@ class Max(Function): input, axis, ret = ctx.saved_tensors shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))] ret2 = (input==ret.reshape(shape)) - div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True) - return ret2*grad_output.reshape(shape)/div -""" + #div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True) + #return ret2*grad_output.reshape(shape)/div + div = cherry_reduceop(ret2, ReduceOps.SUM, axis=None if axis is None else tuple(axis), keepdims=True) + return cherry_binop(cherry_binop(ret2, grad_output.reshape(shape), BinaryOps.MUL), div, BinaryOps.DIV) # ************* binary ops *************