risk -> cherry
@@ -156,11 +156,11 @@ python3 -m pytest
 ### TODO (updated)
 
 ```bash
-PYTHONPATH="." DEBUG=1 RISK=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
+PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
 ```
 
-* Add reduce ops to RISK, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`
-* Switch convolution backward pass to RISK instead of the numpy placeholder
-* Confirm EfficientNet backward pass fully uses RISK instructions
+* Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_cherry.py` and `extra/cherry.py`
+* Switch convolution backward pass to CHERRY instead of the numpy placeholder
+* Confirm EfficientNet backward pass fully uses CHERRY instructions
 * Benchmark that and transformers
 
@@ -86,7 +86,7 @@ def count(func):
 
 import atexit
 @atexit.register
-def risk_print_counts():
+def cherry_print_counts():
   print(cnts)
   print(tcnts)
   print(utils)
@@ -95,12 +95,12 @@ def risk_print_counts():
   print("%.2f GOPS %d maxdma" % ((tcnts['riski_matmul']*SZ*SZ*SZ*2 + tcnts['riski_mulacc']*SZ*SZ*2)*1e-9, maxdma))
   print("ran in %.2f us with util %.2f%% total %.2f us" % (sum(cnts.values())*1e-3, util_n*100/(util_d+1), sum(tcnts.values())*1e-3))
 
-def risk_reset_counts():
+def cherry_reset_counts():
   global cnts, utils
   cnts = defaultdict(int)
   utils = defaultdict(int)
 
-def risk_regdump():
+def cherry_regdump():
   print("\n***** regdump *****")
   print(regfile[Reg.MATMUL_INPUT])
   print(regfile[Reg.MATMUL_WEIGHTS])
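The GOPS line above counts two FLOPs per multiply-accumulate: one `riski_matmul` of two SZ x SZ tiles costs 2*SZ**3 ops, one `riski_mulacc` over a tile costs 2*SZ**2. A worked example of that accounting (SZ=16 and the counter values are assumptions here; the real constant is defined alongside these counters in extra/cherry.py):

```python
# Worked example of the counter math (SZ=16 assumed, counts hypothetical)
SZ = 16
tcnts = {'riski_matmul': 1000, 'riski_mulacc': 500}
ops = tcnts['riski_matmul']*SZ*SZ*SZ*2 + tcnts['riski_mulacc']*SZ*SZ*2
print("%.4f GOPS" % (ops*1e-9))  # 1000*8192 + 500*512 = 8448000 -> 0.0084
```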
@@ -192,8 +192,10 @@ def riski_store(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
       sram[address + y*stride_y + x*stride_x] = d[y, x]
   """
 
+# *** DMA engine ***
+
 @count
-def riski_dmar(address, arr):
+def cherry_dmar(address, arr):
   global maxdma
   arr = arr.reshape(-1)
   assert(arr.shape[0] <= SLOTSIZE)
@@ -202,22 +204,22 @@ def riski_dmar(address, arr):
   sram[address:address+arr.shape[0]] = arr
 
 @count
-def riski_dmaw(address, shp):
+def cherry_dmaw(address, shp):
   print("DMAW %d elements" % np.prod(shp))
   return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
 
-# *** RISK-5 code to be compiled ***
+# *** CHERRY code to be compiled ***
 
-def risk_unop(x, op):
-  riski_dmar(SLOT(0), x)
+def cherry_unop(x, op):
+  cherry_dmar(SLOT(0), x)
   cnt = np.prod(x.shape)
   for i in range(0, np.prod(x.shape), SZ*SZ):
     riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
     riski_unop(op)
     riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
-  return riski_dmaw(SLOT(2), x.shape)
+  return cherry_dmaw(SLOT(2), x.shape)
 
-def risk_binop(x, y, op):
+def cherry_binop(x, y, op):
   n_dims = max(len(x.shape), len(y.shape))
   shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
   shape_x[:len(x.shape)] = np.array(x.shape, dtype=np.int32)
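The cherry_unop pattern above is: DMA the whole tensor into SRAM slot 0, apply the op one SZ x SZ register tile at a time, DMA slot 2 back out. A minimal smoke test of that round trip, assuming extra/cherry.py exposes cherry_unop and UnaryOps.RELU as the imports later in this commit suggest:

```python
import numpy as np
from extra.cherry import cherry_unop, UnaryOps  # names per this commit's rename

x = np.random.uniform(-1, 1, size=(64, 64)).astype(np.float32)
out = cherry_unop(x, UnaryOps.RELU)  # DMA in -> tile-wise RELU -> DMA out
np.testing.assert_allclose(out, np.maximum(x, 0), rtol=1e-6)
```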
@@ -238,8 +240,8 @@ def risk_binop(x, y, op):
 
   print(dimlist, complist)
 
-  riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), y)
+  cherry_dmar(SLOT(0), x)
+  cherry_dmar(SLOT(1), y)
   if len(dimlist) <= 1:
     if len(complist) == 0:
       complist = [(True, True)]
@@ -292,15 +294,15 @@ def risk_binop(x, y, op):
           stride_y=dimlist[-1], stride_x=1,
           len_y=min(SZ, dimlist[-2]-j), len_x=min(SZ, dimlist[-1]-k))
 
-  return riski_dmaw(SLOT(2), shape_ret)
+  return cherry_dmaw(SLOT(2), shape_ret)
 
-def risk_matmul(x, w, transpose_x=False, transpose_w=False):
+def cherry_matmul(x, w, transpose_x=False, transpose_w=False):
   # copy matrices into SRAM
   # x is M x K
   # w is K x N
   # out is M x N
-  riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), w)
+  cherry_dmar(SLOT(0), x)
+  cherry_dmar(SLOT(1), w)
 
   if transpose_x:
     K,M = x.shape[-2], x.shape[-1]
@@ -332,42 +334,42 @@ def risk_matmul(x, w, transpose_x=False, transpose_w=False):
         riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+c*M*N + m*N+n, N, 1, min(SZ, M-m), min(SZ, N-n))
 
   # copy back from SRAM
-  return riski_dmaw(SLOT(2), (*x.shape[0:-2],M,N))
+  return cherry_dmaw(SLOT(2), (*x.shape[0:-2],M,N))
 
 import unittest
 class TestRisk(unittest.TestCase):
   def test_matmul_even(self):
     x = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
     w = np.random.uniform(size=(SZ*8, SZ*8)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
 
   def test_matmul_small(self):
     x = np.array([[1,2,3],[4,5,6],[7,8,9]])
     w = np.array([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]])
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
 
   def test_matmul_uneven(self):
     x = np.random.uniform(size=(47, 79)).astype(np.float32)
     w = np.random.uniform(size=(79, 42)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
 
   def test_matmul_transpose(self):
     x = np.random.uniform(size=(33, 33)).astype(np.float32)
     w = np.random.uniform(size=(33, 33)).astype(np.float32)
-    np.testing.assert_allclose(x @ w, risk_matmul(x, w), rtol=1e-5)
-    np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, True), rtol=1e-5)
-    np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, False, True), rtol=1e-5)
-    np.testing.assert_allclose(x.T @ w.T, risk_matmul(x, w, True, True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w, cherry_matmul(x, w), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w, cherry_matmul(x, w, True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w.T, cherry_matmul(x, w, False, True), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w.T, cherry_matmul(x, w, True, True), rtol=1e-5)
 
   def test_matmul_transpose_uneven_w(self):
     x = np.random.uniform(size=(47, 79)).astype(np.float32)
     w = np.random.uniform(size=(42, 79)).astype(np.float32)
-    np.testing.assert_allclose(x @ w.T, risk_matmul(x, w, transpose_w=True), rtol=1e-5)
+    np.testing.assert_allclose(x @ w.T, cherry_matmul(x, w, transpose_w=True), rtol=1e-5)
 
   def test_matmul_transpose_uneven_x(self):
     x = np.random.uniform(size=(79, 47)).astype(np.float32)
     w = np.random.uniform(size=(79, 42)).astype(np.float32)
-    np.testing.assert_allclose(x.T @ w, risk_matmul(x, w, transpose_x=True), rtol=1e-5)
+    np.testing.assert_allclose(x.T @ w, cherry_matmul(x, w, transpose_x=True), rtol=1e-5)
 
 if __name__ == "__main__":
   np.random.seed(1337)
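For orientation, the riski_store call in the hunk above writes each SZ x SZ output tile at its row-major offset m*N + n with stride_y=N, clipping tiles at the matrix edges. A sketch of that addressing (SZ=16 is an assumption; the real value lives in extra/cherry.py):

```python
# Tile addressing sketch for an M x N output (illustrative values)
M, N, SZ = 47, 42, 16
for m in range(0, M, SZ):
  for n in range(0, N, SZ):
    offset = m*N + n                            # row-major start of tile (m, n)
    len_y, len_x = min(SZ, M-m), min(SZ, N-n)   # clip tiles at the edges
    # riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+offset, N, 1, len_y, len_x)
```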
@@ -1,36 +1,36 @@
 import numpy as np
 from tinygrad.tensor import Function
-from extra.risk import *
+from extra.cherry import *
 
 # ************* unary ops *************
 
 class ReLU(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return risk_unop(input, UnaryOps.RELU)
+    return cherry_unop(input, UnaryOps.RELU)
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return risk_binop(grad_output, risk_unop(input, UnaryOps.GT0), BinaryOps.MUL)
+    return cherry_binop(grad_output, cherry_unop(input, UnaryOps.GT0), BinaryOps.MUL)
 
 class Log(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return risk_unop(input, UnaryOps.LOG)
+    return cherry_unop(input, UnaryOps.LOG)
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return risk_binop(grad_output, input, BinaryOps.DIV)
+    return cherry_binop(grad_output, input, BinaryOps.DIV)
 
 class Exp(Function):
   def forward(ctx, input):
-    ret = risk_unop(input, UnaryOps.EXP)
+    ret = cherry_unop(input, UnaryOps.EXP)
     ctx.save_for_backward(ret)
     return ret
 
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return risk_binop(grad_output, ret, BinaryOps.MUL)
+    return cherry_binop(grad_output, ret, BinaryOps.MUL)
 
 # ************* binary ops *************
 
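For reference, these are the gradients the backward methods above route through cherry_unop/cherry_binop, written as a plain numpy sketch (hypothetical helper names; elementwise semantics assumed for the ops):

```python
import numpy as np

def relu_backward(grad_output, input):  # d/dx max(x,0) = 1[x>0], i.e. UnaryOps.GT0
  return grad_output * (input > 0)

def log_backward(grad_output, input):   # d/dx log(x) = 1/x, hence BinaryOps.DIV
  return grad_output / input

def exp_backward(grad_output, ret):     # d/dx exp(x) = exp(x); forward saved ret
  return grad_output * ret
```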
@@ -42,7 +42,7 @@ def unbroadcast(out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return risk_binop(x, y, BinaryOps.ADD)
+    return cherry_binop(x, y, BinaryOps.ADD)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -51,7 +51,7 @@ class Add(Function):
 class Sub(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return risk_binop(x, y, BinaryOps.SUB)
+    return cherry_binop(x, y, BinaryOps.SUB)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -60,7 +60,7 @@ class Sub(Function):
 class Mul(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return risk_binop(x, y, BinaryOps.MUL)
+    return cherry_binop(x, y, BinaryOps.MUL)
 
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
@@ -69,7 +69,7 @@ class Mul(Function):
 class Pow(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return risk_binop(x, y, BinaryOps.POW)
+    return cherry_binop(x, y, BinaryOps.POW)
 
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
@@ -81,12 +81,12 @@ class Pow(Function):
 class Matmul(Function):
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
-    return risk_matmul(input, weight)
+    return cherry_matmul(input, weight)
 
   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
-    grad_input = risk_matmul(grad_output, weight, transpose_w=True)
-    grad_weight = risk_matmul(input, grad_output, transpose_x=True)
+    grad_input = cherry_matmul(grad_output, weight, transpose_w=True)
+    grad_weight = cherry_matmul(input, grad_output, transpose_x=True)
     return grad_input, grad_weight
 
 class Conv2D(Function):
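The transpose flags in Matmul.backward follow from the chain rule: for out = X @ W, dL/dX = dL/dout @ W.T and dL/dW = X.T @ dL/dout. A numpy check of the identities the two cherry_matmul calls compute (shapes illustrative):

```python
import numpy as np
X, W = np.random.randn(4, 3), np.random.randn(3, 5)
grad_out = np.random.randn(4, 5)  # dL/d(X @ W)
grad_X = grad_out @ W.T           # cherry_matmul(grad_out, W, transpose_w=True)
grad_W = X.T @ grad_out           # cherry_matmul(X, grad_out, transpose_x=True)
assert grad_X.shape == X.shape and grad_W.shape == W.shape
```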
@@ -125,10 +125,10 @@ class Conv2D(Function):
     return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
     """
 
-    riski_dmar(SLOT(0), x) # bs, groups, cin, x.shape[2], x.shape[3]
-    riski_dmar(SLOT(1), w) # groups, rcout, cin, H, W
+    cherry_dmar(SLOT(0), x) # bs, groups, cin, x.shape[2], x.shape[3]
+    cherry_dmar(SLOT(1), w) # groups, rcout, cin, H, W
 
-    risk_reset_counts()
+    cherry_reset_counts()
     print(bs, ctx.groups, rcout, oy, ox, cin, H, W)
 
     for B in range(0, bs):
@@ -217,10 +217,10 @@ class Conv2D(Function):
             riski_store(Reg.MATMUL_OUTPUT,
               SLOT(2) + B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X,
               1, oy*ox, min(SZ, ox-X), min(SZ, rcout-c))
-    risk_print_counts()
+    cherry_print_counts()
 
     #print(x.shape, w.shape, "->", ret.shape)
-    return riski_dmaw(SLOT(2), (bs, cout, oy, ox))
+    return cherry_dmaw(SLOT(2), (bs, cout, oy, ox))
 
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
@@ -1,15 +1,17 @@
 #!/usr/bin/env python3
+import sys
 import time
 import pyftdi.serialext
 
-port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=115200)
+#port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=115200)
+port = pyftdi.serialext.serial_for_url('ftdi://ftdi:2232h/2', baudrate=1000000)
 print(port)
 
 while 1:
   #port.write(b'a')
   data = port.read(1)
-  print(data)
-  time.sleep(0.01)
+  sys.stdout.write(data.decode('utf-8'))
+  #time.sleep(0.01)
 
 
 
@@ -355,9 +355,9 @@ def _register_ops(namespace, device=Device.CPU):
 
 from tinygrad import ops_cpu
 _register_ops(ops_cpu)
-if os.getenv("RISK", None) is not None:
-  from extra import ops_risk
-  _register_ops(ops_risk)
+if os.getenv("CHERRY", None) is not None:
+  from extra import ops_cherry
+  _register_ops(ops_cherry)
 try:
   import pyopencl as cl
   # TODO: move this import to require_init_gpu?