Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00.

Commit: max op works
````diff
@@ -160,7 +160,6 @@ PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.
 ```
 
 * ~~Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`~~
-* Fix max op
 * Switch convolution backward pass to CHERRY instead of the numpy placeholder
 * Confirm EfficientNet backward pass fully uses CHERRY instructions
 * Benchmark that and transformers
````
```diff
@@ -148,12 +148,12 @@ def riski_pow():
   regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
 
 @count
-def riski_reduce_sum(out=0, cnt=SZ):
-  regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
+def riski_reduce_sum(cnt=SZ):
+  regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
 
 @count
-def riski_reduce_max(out=0, cnt=SZ):
-  regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
+def riski_reduce_max(cnt=SZ):
+  regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
 
 # TODO: make accumulate a bit in the instruction available to all
 binops = {BinaryOps.ADD: riski_add,
```
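The reduce instructions lose their `out` register index and instead take a row count, always writing the result into row 0 of the output register. A minimal NumPy sketch of the new semantics, using a toy 4×4 register file in place of the real `SZ`×`SZ` one (the `regfile`/`Reg` stand-ins here are illustrative, not the actual `extra/risk.py` implementation):

```python
import numpy as np

SZ = 4  # toy register width; the real SZ comes from extra/risk.py

class Reg:
  MATMUL_INPUT, MATMUL_OUTPUT = 0, 1

regfile = {Reg.MATMUL_INPUT: np.zeros((SZ, SZ)),
           Reg.MATMUL_OUTPUT: np.zeros((SZ, SZ))}

def riski_reduce_sum(cnt=SZ):
  # new form: reduce only the first cnt rows of the input into output row 0
  regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)

def riski_reduce_max(cnt=SZ):
  regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)

regfile[Reg.MATMUL_INPUT][:] = np.arange(SZ*SZ).reshape(SZ, SZ)
riski_reduce_sum(cnt=2)               # only rows 0 and 1 participate
print(regfile[Reg.MATMUL_OUTPUT][0])  # [ 4.  6.  8. 10.]
```

Writing into row 0 is what lets `cherry_reduceop` below chain partial reductions: `riski_mov` copies the running result back into the input register before the next strip of rows is loaded.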
```diff
@@ -232,7 +232,7 @@ def cherry_dmaw(address, shp):
 
 # *** CHERRY code to be compiled ***
 
-def cherry_reduceop(inp, op, axis):
+def cherry_reduceop(inp, op, axis, keepdims=False):
   dimlist, redlist = [], []
   if type(axis) == int:
     axis = [axis]
```
```diff
@@ -265,6 +265,8 @@ def cherry_reduceop(inp, op, axis):
     else:
       dimlist.append(inp.shape[i])
       redlist.append(is_reduce_axis)
 
+  if not keepdims:
+    nosize = []
   for i in range(osize.shape[0]):
     if i not in axis:
```
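The new `keepdims` flag follows NumPy's convention: reduced axes are kept as size-1 dims instead of being dropped, which is what `Max.forward` below relies on to get a result that broadcasts against the input. For reference:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
print(np.amax(x, axis=(1,), keepdims=True).shape)   # (2, 1, 4): axis 1 kept as size 1
print(np.amax(x, axis=(1,), keepdims=False).shape)  # (2, 4): axis 1 dropped
```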
```diff
@@ -280,20 +282,21 @@ def cherry_reduceop(inp, op, axis):
   # redlist is always [False, True, False, True, ...., True, False]
 
   # special case if redlist ends with True
-  if redlist[-1] == True:
+  if len(redlist) > 0 and redlist[-1] == True:
     print("special case redlist[-1] == True")
     outside = int(np.prod(dimlist[:-1]))
     for l in range(0, outside, SZ):
       reduce_size = min(SZ, outside-l)
       j = 0
       while j < dimlist[-1]:
+        len_y = min(SZ if j == 0 else SZ-1, dimlist[-1]-j)
         riski_load(Reg.MATMUL_INPUT,
           SLOT(inslot) + l*dimlist[-1] + j,
           stride_y=1, stride_x=dimlist[-1],
-          len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j),
+          len_y=len_y,
           len_x=reduce_size,
           zero=j==0, skip_first=j!=0)
-        reduceops[op]()
+        reduceops[op](len_y+(j!=0))
         riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
         j += SZ if j == 0 else SZ-1
       riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size)
```
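The reduction is strip-mined: the first pass fills all `SZ` rows and reduces them into row 0; every later pass keeps the running partial in row 0 (`riski_mov`, `skip_first`) and loads only `SZ-1` fresh rows under it, so the reduce must cover `len_y + 1` rows; hence the new `reduceops[op](len_y+(j!=0))`. Reducing the whole register regardless of how many rows are valid is harmless for sum (stale rows can be zeroed) but wrong for max, e.g. with all-negative inputs, which is presumably the bug this commit fixes. A pure-Python sketch of the accumulation pattern, with a toy one-column register:

```python
import numpy as np

SZ = 4
data = np.arange(10.0)  # more than SZ values, so the loop takes several passes

reg = np.zeros(SZ)      # toy register: row 0 carries the running partial
j = 0
while j < len(data):
  len_y = min(SZ if j == 0 else SZ-1, len(data)-j)
  if j == 0:
    reg[0:len_y] = data[j:j+len_y]     # first pass: fill from row 0 (zero=True)
  else:
    reg[1:1+len_y] = data[j:j+len_y]   # later passes: skip_first, partial stays in row 0
  rows = len_y + (j != 0)              # reduce the partial together with the fresh rows
  reg[0] = reg[0:rows].sum()           # reduceops[op](len_y+(j!=0)), here for ReduceOps.SUM
  j += SZ if j == 0 else SZ-1
print(reg[0], data.sum())              # 45.0 45.0
```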
```diff
@@ -311,14 +314,15 @@ def cherry_reduceop(inp, op, axis):
       reduce_size = min(SZ, dimlist[-1]-k)
       j = 0
       while j < dimlist[-2]:
+        len_y = min(SZ if j == 0 else SZ-1, dimlist[-2]-j)
         riski_load(Reg.MATMUL_INPUT,
           SLOT(inslot) + l*dimlist[-2]*dimlist[-1] + j*dimlist[-1] + k,
           stride_y=dimlist[-1], stride_x=1,
-          len_y=min(SZ if j == 0 else SZ-1, dimlist[-2]-j),
+          len_y=len_y,
           len_x=reduce_size,
           zero=j==0, skip_first=j!=0)
         #cherry_regdump()
-        reduceops[op]()
+        reduceops[op](len_y+(j!=0))
         riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
         j += SZ if j == 0 else SZ-1
       riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
```
```diff
@@ -45,11 +45,11 @@ class Sum(Function):
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
     return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD)
 
-"""
 class Max(Function):
   def forward(ctx, inp, axis=None):
     if isinstance(axis, int): axis = [axis]
-    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    #ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ret = cherry_reduceop(inp, ReduceOps.MAX, None if axis is None else tuple(axis), keepdims=True)
     ctx.save_for_backward(inp, axis, ret)
     if axis is not None:
       ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
```
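Note the shape dance in `forward`: the CHERRY reduce is asked for `keepdims=True` so the saved `ret` broadcasts against `inp` in `backward`, and only the returned tensor has the reduced axes squeezed out. In NumPy terms:

```python
import numpy as np

inp = np.array([[1., 5., 3.],
                [4., 2., 6.]])
axis = (1,)
ret = np.amax(inp, axis=axis, keepdims=True)  # shape (2, 1), saved for backward
out = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
print(ret.shape, out.shape)                   # (2, 1) (2,)
```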
```diff
@@ -59,9 +59,10 @@ class Max(Function):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
     ret2 = (input==ret.reshape(shape))
-    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
-    return ret2*grad_output.reshape(shape)/div
-"""
+    #div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    #return ret2*grad_output.reshape(shape)/div
+    div = cherry_reduceop(ret2, ReduceOps.SUM, axis=None if axis is None else tuple(axis), keepdims=True)
+    return cherry_binop(cherry_binop(ret2, grad_output.reshape(shape), BinaryOps.MUL), div, BinaryOps.DIV)
 
 # ************* binary ops *************
 
```
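Unrolled into plain NumPy, the backward rule these two `cherry_*` calls implement: gradient flows only to positions that achieved the max, split evenly among ties, which is what `div` counts:

```python
import numpy as np

inp = np.array([[3., 7., 7.],
                [9., 1., 2.]])
axis = (1,)
grad_output = np.array([1., 1.])
shape = [1 if i in axis else inp.shape[i] for i in range(len(inp.shape))]

ret = np.amax(inp, axis=axis, keepdims=True)  # forward result, dims kept
ret2 = (inp == ret.reshape(shape))            # mask of argmax positions
div = ret2.sum(axis=axis, keepdims=True)      # number of tied maxima per slice
print(ret2 * grad_output.reshape(shape) / div)
# [[0.  0.5 0.5]   <- gradient split between the two tied 7s
#  [1.  0.  0. ]]
```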