Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 06:58:11 -05:00).
Commit: "add support for add" — this commit is contained in:
@@ -78,6 +78,15 @@ class Mul(Function):
|
||||
return y*grad_output, x*grad_output
|
||||
register('mul', Mul)
|
||||
|
||||
class Add(Function):
    """Elementwise addition op.

    Forward computes ``x + y``; backward routes the incoming gradient
    to both operands unchanged, since d(x+y)/dx = d(x+y)/dy = 1.
    """

    @staticmethod
    def forward(ctx, x, y):
        # No tensors need saving on ctx: the gradient of addition
        # does not depend on the inputs.
        return x + y

    @staticmethod
    def backward(ctx, grad_output):
        # The same upstream gradient flows to each addend.
        g = grad_output
        return g, g
register('add', Add)
|
||||
|
||||
class ReLU(Function):
|
||||
@staticmethod
|
||||
|
||||
6
test.py
6
test.py
@@ -14,7 +14,8 @@ def test_tinygrad():
|
||||
outr = out.relu()
|
||||
outl = outr.logsoftmax()
|
||||
outm = outl.mul(m)
|
||||
outx = outm.sum()
|
||||
outa = outm.add(m)
|
||||
outx = outa.sum()
|
||||
outx.backward()
|
||||
return outx.data, x.grad, W.grad
|
||||
|
||||
@@ -26,7 +27,8 @@ def test_pytorch():
|
||||
outr = out.relu()
|
||||
outl = torch.nn.functional.log_softmax(outr, dim=1)
|
||||
outm = outl.mul(m)
|
||||
outx = outm.sum()
|
||||
outa = outm.add(m)
|
||||
outx = outa.sum()
|
||||
outx.backward()
|
||||
return outx.detach().numpy(), x.grad, W.grad
|
||||
|
||||
|
||||
Reference in New Issue
Block a user