mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
max to behave on ties like torch (#229)
* checkpoint * fixing pow * undo pow * backward max on GPU and CPU rewrite * indentation * changing seed for curiosity * max replaced equality * undo seed * rebase * fixed tests * merge error
This commit is contained in:
@@ -6,9 +6,12 @@ import timeit
|
||||
import functools
|
||||
from tinygrad.tensor import Tensor, GPU, ANE, Device
|
||||
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
|
||||
torch.manual_seed(0)
|
||||
ts = [torch.rand(x, requires_grad=True) for x in shps]
|
||||
if shps is None:
|
||||
ts = [torch.tensor(x, requires_grad=True) for x in vals]
|
||||
else:
|
||||
ts = [torch.rand(x, requires_grad=True) for x in shps]
|
||||
tst = [Tensor(x.detach().numpy()) for x in ts]
|
||||
if device==Device.GPU:
|
||||
tst = [x.gpu() for x in tst]
|
||||
@@ -85,8 +88,10 @@ class TestOps(unittest.TestCase):
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
|
||||
helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
|
||||
vals=[
|
||||
[[1.0,1.0,0.0,1.0]],
|
||||
])
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
|
||||
def test_mean_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user