commit ff3fdc58e5
parent 2f91c012eb
Author: George Hotz
Date:   2021-06-16 09:59:48 -07:00

    risk -> cherry

5 changed files with 59 additions and 55 deletions

View File

@@ -156,11 +156,11 @@ python3 -m pytest
 ### TODO (updated)
 ```bash
-PYTHONPATH="." DEBUG=1 RISK=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
+PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
 ```
-* Add reduce ops to RISK, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`
-* Switch convolution backward pass to RISK instead of the numpy placeholder
-* Confirm EfficientNet backward pass fully uses RISK instructions
+* Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_cherry.py` and `extra/cherry.py`
+* Switch convolution backward pass to CHERRY instead of the numpy placeholder
+* Confirm EfficientNet backward pass fully uses CHERRY instructions
 * Benchmark that and transformers
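The first TODO asks for full forward-pass support; a quick way to sanity-check the already-working unary path against numpy (an editor's sketch, assuming `extra/cherry.py` exports the same names shown in the diff below, e.g. `cherry_unop` and `UnaryOps`):

```python
import numpy as np
from extra.cherry import *  # renamed from extra/risk.py in this commit

x = np.random.uniform(-1, 1, size=(64, 64)).astype(np.float32)
# ReLU on the Cherry simulator vs. the numpy reference
np.testing.assert_allclose(np.maximum(x, 0), cherry_unop(x, UnaryOps.RELU), rtol=1e-5)
```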

View File

@@ -86,7 +86,7 @@ def count(func):
 import atexit
 @atexit.register
-def risk_print_counts():
+def cherry_print_counts():
   print(cnts)
   print(tcnts)
   print(utils)
@@ -95,12 +95,12 @@ def risk_print_counts():
print("%.2f GOPS %d maxdma" % ((tcnts['riski_matmul']*SZ*SZ*SZ*2 + tcnts['riski_mulacc']*SZ*SZ*2)*1e-9, maxdma))
print("ran in %.2f us with util %.2f%% total %.2f us" % (sum(cnts.values())*1e-3, util_n*100/(util_d+1), sum(tcnts.values())*1e-3))
def risk_reset_counts():
def cherry_reset_counts():
global cnts, utils
cnts = defaultdict(int)
utils = defaultdict(int)
def risk_regdump():
def cherry_regdump():
print("\n***** regdump *****")
print(regfile[Reg.MATMUL_INPUT])
print(regfile[Reg.MATMUL_WEIGHTS])
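For reference, the GOPS figure printed above counts two operations (a multiply and an add) per MAC: each `riski_matmul` issue is an SZ×SZ×SZ matrix multiply and each `riski_mulacc` an SZ×SZ elementwise multiply-accumulate, so:

```latex
\mathrm{GOPS} \approx \bigl(2\,SZ^{3}\,N_{\texttt{riski\_matmul}} + 2\,SZ^{2}\,N_{\texttt{riski\_mulacc}}\bigr) \times 10^{-9}
```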
@@ -192,8 +192,10 @@ def riski_store(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
       sram[address + y*stride_y + x*stride_x] = d[y, x]
 """
 # *** DMA engine ***
 @count
-def riski_dmar(address, arr):
+def cherry_dmar(address, arr):
   global maxdma
   arr = arr.reshape(-1)
   assert(arr.shape[0] <= SLOTSIZE)
@@ -202,22 +204,22 @@ def riski_dmar(address, arr):
   sram[address:address+arr.shape[0]] = arr
 @count
-def riski_dmaw(address, shp):
+def cherry_dmaw(address, shp):
   print("DMAW %d elements" % np.prod(shp))
   return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
-# *** RISK-5 code to be compiled ***
+# *** CHERRY code to be compiled ***
-def risk_unop(x, op):
-  riski_dmar(SLOT(0), x)
+def cherry_unop(x, op):
+  cherry_dmar(SLOT(0), x)
   cnt = np.prod(x.shape)
   for i in range(0, np.prod(x.shape), SZ*SZ):
     riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
     riski_unop(op)
     riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
-  return riski_dmaw(SLOT(2), x.shape)
+  return cherry_dmaw(SLOT(2), x.shape)
-def risk_binop(x, y, op):
+def cherry_binop(x, y, op):
   n_dims = max(len(x.shape), len(y.shape))
   shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
   shape_x[:len(x.shape)] = np.array(x.shape, dtype=np.int32)
@@ -238,8 +240,8 @@ def risk_binop(x, y, op):
   print(dimlist, complist)
-  riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), y)
+  cherry_dmar(SLOT(0), x)
+  cherry_dmar(SLOT(1), y)
   if len(dimlist) <= 1:
     if len(complist) == 0:
       complist = [(True, True)]
@@ -292,15 +294,15 @@ def risk_binop(x, y, op):
         stride_y=dimlist[-1], stride_x=1,
         len_y=min(SZ, dimlist[-2]-j), len_x=min(SZ, dimlist[-1]-k))
-  return riski_dmaw(SLOT(2), shape_ret)
+  return cherry_dmaw(SLOT(2), shape_ret)
-def risk_matmul(x, w, transpose_x=False, transpose_w=False):
+def cherry_matmul(x, w, transpose_x=False, transpose_w=False):
   # copy matrices into SRAM
   # x is M x K
   # w is K x N
   # out is M x N
-  riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), w)
+  cherry_dmar(SLOT(0), x)
+  cherry_dmar(SLOT(1), w)
   if transpose_x:
     K,M = x.shape[-2], x.shape[-1]
@@ -332,42 +334,42 @@ def risk_matmul(x, w, transpose_x=False, transpose_w=False):
       riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+c*M*N + m*N+n, N, 1, min(SZ, M-m), min(SZ, N-n))
   # copy back from SRAM
-  return riski_dmaw(SLOT(2), (*x.shape[0:-2],M,N))
+  return cherry_dmaw(SLOT(2), (*x.shape[0:-2],M,N))
 import unittest
 class TestRisk(unittest.TestCase):
   def test_matmul_even(self):
     x = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
     w = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
   def test_matmul_small(self):
     x = np.array([[1,2,3],[4,5,6],[7,8,9]])
     w = np.array([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]])
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
   def test_matmul_uneven(self):
     x = np.random.uniform(size=(47, 79)).astype(np.float32)
     w = np.random.uniform(size=(79, 42)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
   def test_matmul_transpose(self):
     x = np.random.uniform(size=(33, 33)).astype(np.float32)
     w = np.random.uniform(size=(33, 33)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
-    np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, True), rtol=1e-5)
-    np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, False, True), rtol=1e-5)
-    np.testing.assert_allclose(x.T @ w.T, risk_matmul(x, w, True, True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w, cherry_matmul(x, w, True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w.T, cherry_matmul(x, w, False, True), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w.T, cherry_matmul(x, w, True, True), rtol=1e-5)
   def test_matmul_transpose_uneven_w(self):
     x = np.random.uniform(size=(47, 79)).astype(np.float32)
     w = np.random.uniform(size=(42, 79)).astype(np.float32)
-    np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, transpose_w=True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w.T, cherry_matmul(x, w, transpose_w=True), rtol=1e-5)
   def test_matmul_transpose_uneven_x(self):
     x = np.random.uniform(size=(79, 47)).astype(np.float32)
     w = np.random.uniform(size=(79, 42)).astype(np.float32)
-    np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, transpose_x=True), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w, cherry_matmul(x, w, transpose_x=True), rtol=1e-5)
 if __name__ == "__main__":
   np.random.seed(1337)

View File

@@ -1,36 +1,36 @@
 import numpy as np
 from tinygrad.tensor import Function
-from extra.risk import *
+from extra.cherry import *
 # ************* unary ops *************
 class ReLU(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return risk_unop(input, UnaryOps.RELU)
+    return cherry_unop(input, UnaryOps.RELU)
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return risk_binop(grad_output, risk_unop(input, UnaryOps.GT0), BinaryOps.MUL)
+    return cherry_binop(grad_output, cherry_unop(input, UnaryOps.GT0), BinaryOps.MUL)
 class Log(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return risk_unop(input, UnaryOps.LOG)
+    return cherry_unop(input, UnaryOps.LOG)
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return risk_binop(grad_output, input, BinaryOps.DIV)
+    return cherry_binop(grad_output, input, BinaryOps.DIV)
 class Exp(Function):
   def forward(ctx, input):
-    ret = risk_unop(input, UnaryOps.EXP)
+    ret = cherry_unop(input, UnaryOps.EXP)
     ctx.save_for_backward(ret)
     return ret
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return risk_binop(grad_output, ret, BinaryOps.MUL)
+    return cherry_binop(grad_output, ret, BinaryOps.MUL)
 # ************* binary ops *************
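An editor's note on the unary backward passes above: they encode the standard scalar derivatives, with `UnaryOps.GT0` supplying the ReLU mask:

```latex
\mathrm{relu}'(x) = \mathbf{1}[x > 0], \qquad (\log x)' = \frac{1}{x}, \qquad (e^{x})' = e^{x}
```

hence ReLU multiplies the incoming gradient by the GT0 mask (MUL), Log divides it by the saved input (DIV), and Exp multiplies it by the saved output (MUL).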
@@ -42,7 +42,7 @@ def unbroadcast(out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return risk_binop(x, y, BinaryOps.ADD)
+    return cherry_binop(x, y, BinaryOps.ADD)
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -51,7 +51,7 @@ class Add(Function):
 class Sub(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return risk_binop(x, y, BinaryOps.SUB)
+    return cherry_binop(x, y, BinaryOps.SUB)
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -60,7 +60,7 @@ class Sub(Function):
 class Mul(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return risk_binop(x, y, BinaryOps.MUL)
+    return cherry_binop(x, y, BinaryOps.MUL)
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
@@ -69,7 +69,7 @@ class Mul(Function):
 class Pow(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return risk_binop(x, y, BinaryOps.POW)
+    return cherry_binop(x, y, BinaryOps.POW)
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
@@ -81,12 +81,12 @@ class Pow(Function):
 class Matmul(Function):
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
-    return risk_matmul(input, weight)
+    return cherry_matmul(input, weight)
   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
-    grad_input = risk_matmul(grad_output, weight, transpose_w=True)
-    grad_weight = risk_matmul(input, grad_output, transpose_x=True)
+    grad_input = cherry_matmul(grad_output, weight, transpose_w=True)
+    grad_weight = cherry_matmul(input, grad_output, transpose_x=True)
     return grad_input, grad_weight
 class Conv2D(Function):
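For reference, the transpose flags in `Matmul.backward` above implement the standard gradients of the matrix product Y = XW:

```latex
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\,W^{\top}, \qquad
\frac{\partial L}{\partial W} = X^{\top}\,\frac{\partial L}{\partial Y}
```

which is exactly `cherry_matmul(grad_output, weight, transpose_w=True)` and `cherry_matmul(input, grad_output, transpose_x=True)`, with the transposes folded into the matmul rather than applied to the arrays.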
@@ -125,10 +125,10 @@ class Conv2D(Function):
     return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
     """
-    riski_dmar(SLOT(0), x) # bs, groups, cin, x.shape[2], x.shape[3]
-    riski_dmar(SLOT(1), w) # groups, rcout, cin, H, W
+    cherry_dmar(SLOT(0), x) # bs, groups, cin, x.shape[2], x.shape[3]
+    cherry_dmar(SLOT(1), w) # groups, rcout, cin, H, W
-    risk_reset_counts()
+    cherry_reset_counts()
     print(bs, ctx.groups, rcout, oy, ox, cin, H, W)
     for B in range(0, bs):
@@ -217,10 +217,10 @@ class Conv2D(Function):
           riski_store(Reg.MATMUL_OUTPUT,
             SLOT(2) + B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X,
             1, oy*ox, min(SZ, ox-X), min(SZ, rcout-c))
-    risk_print_counts()
+    cherry_print_counts()
     #print(x.shape, w.shape, "->", ret.shape)
-    return riski_dmaw(SLOT(2), (bs, cout, oy, ox))
+    return cherry_dmaw(SLOT(2), (bs, cout, oy, ox))
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape

View File

@@ -1,15 +1,17 @@
 #!/usr/bin/env python3
+import sys
 import time
 import pyftdi.serialext
-port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=115200)
+#port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=115200)
+port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=1000000)
 print(port)
 while 1:
   #port.write(b'a')
   data = port.read(1)
-  print(data)
-  time.sleep(0.01)
+  sys.stdout.write(data.decode('utf-8'))
+  #time.sleep(0.01)
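One caveat in the new output path: `port.read(1)` returns a single byte, and `bytes.decode('utf-8')` raises `UnicodeDecodeError` when that byte is part of a multi-byte sequence. A more tolerant loop (an editor's sketch, not part of the commit) would be:

```python
#!/usr/bin/env python3
import sys
import pyftdi.serialext

port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=1000000)
while 1:
  data = port.read(1)
  # errors='replace' emits U+FFFD instead of raising mid-sequence
  sys.stdout.write(data.decode('utf-8', errors='replace'))
  sys.stdout.flush()  # stdout is buffered; flush so each byte appears immediately
```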

View File

@@ -355,9 +355,9 @@ def _register_ops(namespace, device=Device.CPU):
 from tinygrad import ops_cpu
 _register_ops(ops_cpu)
-if os.getenv("RISK", None) is not None:
-  from extra import ops_risk
-  _register_ops(ops_risk)
+if os.getenv("CHERRY", None) is not None:
+  from extra import ops_cherry
+  _register_ops(ops_cherry)
 try:
   import pyopencl as cl
   # TODO: move this import to require_init_gpu?
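Because `_register_ops(ops_cherry)` is called without a `device` argument, it re-registers the default CPU ops (note `device=Device.CPU` in the hunk header), so `CHERRY=1` reroutes ordinary tensor math through the simulator. A minimal check (an editor's sketch, assuming this era's eager `Tensor` API with `.dot`):

```python
import os
os.environ["CHERRY"] = "1"   # must be set before tinygrad.tensor is imported
import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(np.random.uniform(size=(4, 4)).astype(np.float32))
b = Tensor(np.random.uniform(size=(4, 4)).astype(np.float32))
c = a.dot(b)                 # dispatches to the registered Matmul, i.e. cherry_matmul
```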