max op works

This commit is contained in:
George Hotz
2021-06-17 17:01:21 -07:00
parent 9e12c1bbba
commit 2e71ae33f6
3 changed files with 25 additions and 21 deletions

View File

@@ -160,7 +160,6 @@ PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.
```
* ~~Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`~~
* Fix max op
* Switch convolution backward pass to CHERRY instead of the numpy placeholder
* Confirm EfficientNet backward pass fully uses CHERRY instructions
* Benchmark that and transformers

View File

@@ -148,12 +148,12 @@ def riski_pow():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_reduce_sum(out=0, cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
def riski_reduce_sum(cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
@count
def riski_reduce_max(out=0, cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
def riski_reduce_max(cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
# TODO: make accumulate a bit in the instruction available to all
binops = {BinaryOps.ADD: riski_add,
@@ -232,7 +232,7 @@ def cherry_dmaw(address, shp):
# *** CHERRY code to be compiled ***
def cherry_reduceop(inp, op, axis):
def cherry_reduceop(inp, op, axis, keepdims=False):
dimlist, redlist = [], []
if type(axis) == int:
axis = [axis]
@@ -265,11 +265,13 @@ def cherry_reduceop(inp, op, axis):
else:
dimlist.append(inp.shape[i])
redlist.append(is_reduce_axis)
nosize = []
for i in range(osize.shape[0]):
if i not in axis:
nosize.append(osize[i])
osize = nosize
if not keepdims:
nosize = []
for i in range(osize.shape[0]):
if i not in axis:
nosize.append(osize[i])
osize = nosize
osize = tuple(osize)
print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
@@ -280,20 +282,21 @@ def cherry_reduceop(inp, op, axis):
# redlist usually alternates: [False, True, False, True, ..., True, False]
# special case: redlist ends with True (the last entry of dimlist is a reduce axis)
if redlist[-1] == True:
if len(redlist) > 0 and redlist[-1] == True:
print("special case redlist[-1] == True")
outside = int(np.prod(dimlist[:-1]))
for l in range(0, outside, SZ):
reduce_size = min(SZ, outside-l)
j = 0
while j < dimlist[-1]:
len_y = min(SZ if j == 0 else SZ-1, dimlist[-1]-j)
riski_load(Reg.MATMUL_INPUT,
SLOT(inslot) + l*dimlist[-1] + j,
stride_y=1, stride_x=dimlist[-1],
len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j),
len_y=len_y,
len_x=reduce_size,
zero=j==0, skip_first=j!=0)
reduceops[op]()
reduceops[op](len_y+(j!=0))
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
j += SZ if j == 0 else SZ-1
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size)
@@ -311,14 +314,15 @@ def cherry_reduceop(inp, op, axis):
reduce_size = min(SZ, dimlist[-1]-k)
j = 0
while j < dimlist[-2]:
len_y = min(SZ if j == 0 else SZ-1, dimlist[-2]-j)
riski_load(Reg.MATMUL_INPUT,
SLOT(inslot) + l*dimlist[-2]*dimlist[-1] + j*dimlist[-1] + k,
stride_y=dimlist[-1], stride_x=1,
len_y=min(SZ if j == 0 else SZ-1, dimlist[-2]-j),
len_y=len_y,
len_x=reduce_size,
zero=j==0, skip_first=j!=0)
#cherry_regdump()
reduceops[op]()
reduceops[op](len_y+(j!=0))
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
j += SZ if j == 0 else SZ-1
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)

View File

@@ -45,11 +45,11 @@ class Sum(Function):
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD)
"""
class Max(Function):
def forward(ctx, inp, axis=None):
if isinstance(axis, int): axis = [axis]
ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
#ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
ret = cherry_reduceop(inp, ReduceOps.MAX, None if axis is None else tuple(axis), keepdims=True)
ctx.save_for_backward(inp, axis, ret)
if axis is not None:
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
@@ -59,9 +59,10 @@ class Max(Function):
input, axis, ret = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
ret2 = (input==ret.reshape(shape))
div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
return ret2*grad_output.reshape(shape)/div
"""
#div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
#return ret2*grad_output.reshape(shape)/div
div = cherry_reduceop(ret2, ReduceOps.SUM, axis=None if axis is None else tuple(axis), keepdims=True)
return cherry_binop(cherry_binop(ret2, grad_output.reshape(shape), BinaryOps.MUL), div, BinaryOps.DIV)
# ************* binary ops *************