mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 14:43:57 -05:00)
nah, no sign, it's not what you want. use relu
@@ -105,17 +105,17 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### Adding an accelerator
 
-You need to support 15 first class ops:
+You need to support 14 first class ops:
 
 ```
-Relu, Log, Exp, Sign # unary ops
+Relu, Log, Exp # unary ops
 Sum, Max # reduce ops (with axis argument)
 Add, Sub, Mul, Pow # binary ops (with broadcasting)
 Reshape, Transpose, Slice # movement ops
 Matmul, Conv2D # processing ops
 ```
 
-While more ops may be added, I think these base 15 are stable.
+While more ops may be added, I think this base is stable.
 
 ## ImageNet inference
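For context, each of these first class ops is implemented as a `Function` with `forward` and `backward` staticmethods, the same pattern as the CPU and GPU ops edited below. As a rough, illustrative sketch of that shape (the `Ctx` stub here is a stand-in, not tinygrad's real context object), a numpy-backed Relu looks roughly like this:

```python
import numpy as np

# Rough sketch of one first class op in the forward/backward Function style
# visible in the CPU ops hunk below. Ctx is a minimal stand-in, not tinygrad's
# real context object.
class Ctx:
  def save_for_backward(self, *x): self.saved_tensors = x

class Relu:
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)       # remember the input for the backward pass
    return np.maximum(input, 0)        # relu(x) = max(x, 0)

  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * (input >= 0)  # subgradient: 1 where input >= 0, else 0

ctx = Ctx()
out = Relu.forward(ctx, np.array([-2.0, 0.5, 3.0]))
grad = Relu.backward(ctx, np.ones(3))  # -> [0., 1., 1.]
```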
@@ -6,12 +6,12 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, forward_only=False, vals=None):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_atol=1e-6, grad_rtol=1e-3, forward_only=False, vals=None):
   torch.manual_seed(0)
   if shps is None:
     ts = [torch.tensor(x, requires_grad=True) for x in vals]
   else:
-    ts = [torch.rand(x, requires_grad=True) for x in shps]
+    ts = [torch.tensor((np.random.random(size=x).astype(np.float32)-0.5)*20, requires_grad=True) for x in shps]
 
   tst = [Tensor(x.detach().numpy()) for x in ts]
   out = torch_fxn(*ts)
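Two related changes here: the test inputs move from `torch.rand` (uniform in [0, 1)) to float32 values roughly uniform in [-10, 10), and the default tolerances are loosened from exact matches to atol=1e-6 / rtol=1e-3. With larger and partly negative inputs, float32 results can legitimately drift a little more between torch and tinygrad, and (assuming an `assert_allclose`-style comparison inside the helper) the check is |actual - desired| <= atol + rtol * |desired|. A small sketch of what the new inputs look like:

```python
import numpy as np

# The new test inputs: float32, roughly uniform in [-10, 10) instead of [0, 1).
x = (np.random.random(size=(45, 65)).astype(np.float32) - 0.5) * 20
print(x.min(), x.max())   # close to -10 and +10

# With values this large, a zero absolute tolerance is too strict for float32.
# assert_allclose-style check: |actual - desired| <= atol + rtol * |desired|
np.testing.assert_allclose(10.000001, 10.0, atol=1e-6, rtol=1e-3)  # passes
```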
@@ -79,8 +79,8 @@ class TestOps(unittest.TestCase):
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot)
   def test_multidot(self):
-    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
-    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
+    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
+    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
   def test_sum(self):
     helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
@@ -159,7 +159,7 @@ class TestOps(unittest.TestCase):
       with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
         helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
           lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
+          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_strided_conv2d(self):
     bs = 4
@@ -168,18 +168,19 @@ class TestOps(unittest.TestCase):
     with self.subTest(stride := 2):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
         lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu())
+        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), atol=1e-4)
     with self.subTest(stride := (2,1)):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
        lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu())
+        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=1e-4)
 
   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
         helper_test_op([(32,2,110,28)],
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
+          # TODO: why is this tolerance so high?
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz), grad_atol=1e-4)
 
   def test_avgpool2d(self):
     shape = (32,2,111,28)
@@ -37,15 +37,6 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return grad_output * ret
 
-class Sign(Function):
-  @staticmethod
-  def forward(ctx, input):
-    return np.sign(input)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    return np.zeros_like(grad_output)
-
 # ************* reduce ops *************
 
 class Sum(Function):
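For reference, the deleted CPU Sign computed `np.sign` in forward and returned an all-zeros gradient in backward, and the GPU version removed in the next hunk spells the same function as `(a > 0) - (a < 0)`. A quick, illustrative sanity check of both facts:

```python
import numpy as np

# Illustrative check: the OpenCL expression "(a > 0) - (a < 0)" used by the GPU
# Sign below is the same three-valued function as np.sign, and the backward pass
# of the deleted op was identically zero.
a = np.array([-3.5, 0.0, 2.0], dtype=np.float32)
assert np.array_equal((a > 0).astype(np.float32) - (a < 0).astype(np.float32), np.sign(a))

grad_output = np.ones_like(a)
print(np.zeros_like(grad_output))  # the gradient Sign.backward returned: all zeros
```

With relu able to express relu6 and abs (see the tensor.py hunks below), Sign no longer pulls its weight as a first class op, which is what drops the count from 15 to 14.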
@@ -64,18 +64,9 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
 
-class Sign(Function):
-  @staticmethod
-  def forward(ctx, input):
-    return unary_op(ctx, '(a > 0) - (a < 0)', input)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    return buffer_new(ctx, grad_output.shape, zero=True)
-
 # ************* reduce ops *************
 
-def reduce_op(ctx, code, code2, inp, axis=None):
+def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
   if axis is None:
     # full reduce
     osize = [1]*len(inp.shape)
@@ -92,7 +83,7 @@ def reduce_op(ctx, code, code2, inp, axis=None):
     __global const int *shape_x, __global const int *shape_ret) {
     int gid = get_global_id(0);
 
-    float out = 0.0;
+    float out = """+start+""";
     for (int x = 0; x < sz; x++) {
       int idx = 0; // compute index into a_g
       int tprod = prod;
@@ -140,7 +131,7 @@ class Max(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
     axis = [axis] if type(axis) == int else axis
-    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
+    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis, start="-INFINITY")
     ctx.save_for_backward(input, axis, ret)
     if axis is not None:
       ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
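The reason Max needs the new `start` argument: the GPU reduce kernel seeds its accumulator with `start` and then folds every element through the reduction expression. Seeding a max reduction with 0.0 silently clamps the result at zero, which only went unnoticed while the tests fed non-negative inputs; with the new [-10, 10) test data it would be wrong. A small Python sketch of the same accumulator pattern (the `fold` helper is just an illustration of what the generated OpenCL loop does):

```python
import math

# Sketch of the reduce kernel's accumulator loop: out = start, then fold each
# element with the reduction expression ("out = max(a,out)" for Max).
def fold(xs, start, op):
  out = start
  for a in xs:
    out = op(a, out)
  return out

xs = [-7.0, -2.5, -9.1]                       # all-negative input, like the new test data
print(fold(xs, 0.0, max))                     # 0.0  -> wrong: the old start clamps at zero
print(fold(xs, -math.inf, max))               # -2.5 -> correct with start="-INFINITY"
print(fold(xs, 0.0, lambda a, out: a + out))  # sum reductions still want start 0.0
```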
@@ -232,10 +232,10 @@ class Tensor:
     return self * self.sigmoid()
 
   def relu6(self):
-    return self.relu() * (6-self).sign()
+    return self.relu() - (self-6).relu()
 
   def hardswish(self):
-    return self * (self+3).relu6()/6
+    return self * (self+3).relu6() * (1/6)
 
   def tanh(self):
     return 2.0 * ((2.0 * self).sigmoid()) - 1.0
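This is the substance of the commit: relu6 no longer relies on sign. The old formula `relu(x) * sign(6 - x)` is actually wrong once x reaches 6 (it returns 0 at x = 6 and -x beyond it), while `relu(x) - relu(x - 6)` equals min(max(x, 0), 6) everywhere. A quick numerical check of the identity (plain numpy here, not tinygrad tensors):

```python
import numpy as np

relu = lambda x: np.maximum(x, 0)

x = np.array([-4.0, 3.0, 6.0, 8.0])
old_relu6 = relu(x) * np.sign(6 - x)          # sign-based version: [0, 3, 0, -8]
new_relu6 = relu(x) - relu(x - 6)             # relu-based version: [0, 3, 6, 6]
reference = np.minimum(np.maximum(x, 0), 6)   # the textbook relu6

print(old_relu6, new_relu6, reference)
assert np.array_equal(new_relu6, reference)   # new formula matches min(max(x,0),6)
```

The hardswish tweak in the same hunk swaps `/6` for `* (1/6)`; presumably this is because Div is not one of the 14 first class ops listed in the README above, while multiplying by the precomputed Python constant 1/6 only needs Mul.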
@@ -274,7 +274,7 @@ class Tensor:
     return self * (self.softplus().tanh()) # x*tanh(softplus(x))
 
   def abs(self):
-    return self * self.sign()
+    return self.relu() + (-1.0*self).relu()
 
   def _pool2d(self, py, px):
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px]
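Same idea for abs: |x| = relu(x) + relu(-x), so it too can be written purely in terms of the remaining first class ops. A one-line check (again plain numpy, just to illustrate the identity):

```python
import numpy as np

relu = lambda x: np.maximum(x, 0)

x = np.array([-5.0, -0.5, 0.0, 2.0, 7.5])
assert np.array_equal(relu(x) + relu(-1.0 * x), np.abs(x))  # |x| = relu(x) + relu(-x)
```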