fix some long lines in tests (#3006)
* fix some long lines in tests

* better
@@ -1,9 +1,6 @@
-# ruff: noqa: E501
 import torch
-import time
-import math
+import time, math, unittest
 import numpy as np
-import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
 from tinygrad import Device, dtypes
@@ -14,7 +11,8 @@ if CI:

 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 PRINT_TENSORS = getenv("PRINT_TENSORS", 0)
-def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
+                   forward_only=False, vals=None, a=-0.5, b=3):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
   ts, tst = prepare_test_op(a, b, shps, vals, forward_only)

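(For context: helper_test_op runs the same op through torch and tinygrad on randomly prepared tensors of the given shapes — see prepare_test_op below — and compares the outputs, and the gradients unless forward_only is set, within the given tolerances. A minimal sketch of a call as it appears throughout this file:

  # compare torch and tinygrad relu on one random (45,65) tensor, forward and backward
  helper_test_op([(45,65)], lambda x: torch.nn.functional.relu(x), lambda x: Tensor.relu(x))
)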
@@ -54,7 +52,9 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
     for i, (t, tt) in enumerate(zip(ts, tst)):
       compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)

-  if not CI: print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
+  if not CI:
+    print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
+      (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")

 def prepare_test_op(a, b, shps, vals, forward_only=False):
   torch.manual_seed(0)
@@ -165,11 +165,13 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True)

   def test_sum_pad_collapse(self):
-    helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1), lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1),
+                       lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True)

   # this is more complex and won't fold for a while
   def test_sum_cat_collapse(self):
-    helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1), lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True)
+    helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1),
+                       lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True)

   def test_max_dont_collapse(self):
     helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True)
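(These collapse tests exercise a numeric identity the compiler can fold: summing along an axis that was zero-padded gives the same result as summing the unpadded tensor, since the padding contributes zeros. A quick numpy-only sketch of the identity:

  import numpy as np
  a = np.ones((4,4))
  padded = np.pad(a, ((0,0),(0,2)))  # zero-pad the last axis, as in test_sum_pad_collapse
  assert np.array_equal(padded.sum(axis=1), a.sum(axis=1))  # zeros don't change the row sums
)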
@@ -449,7 +451,8 @@ class TestOps(unittest.TestCase):

   def test_multinomial(self):
     # NOTE: this is random, so it has a very large atol
-    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)
+    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1),
+                              lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)

   def test_small_cumsum(self):
     helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
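(With 1000 weights, the sampled index lies in [0, 999], so atol=1000. means the value comparison can never fail: the test effectively checks shape, dtype, and that sampling runs at all. A hypothetical, stronger statistical check — not what this test does — would compare empirical frequencies instead of single draws:

  import numpy as np, torch
  weights = torch.rand(10).clip(0, 1)
  samples = torch.multinomial(weights, num_samples=100000, replacement=True)
  freqs = np.bincount(samples.numpy(), minlength=10) / 100000
  expected = (weights / weights.sum()).numpy()
  assert np.allclose(freqs, expected, atol=0.01)  # loose statistical tolerance
)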
@@ -508,9 +511,11 @@ class TestOps(unittest.TestCase):
     # batch matrix multiplication, result & input permuted
     helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
     # tensor contraction
-    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b), lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
+    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
+                                                lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
     # tensor contraction, input permuted
-    helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b), lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
+    helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
+                                                lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
     # bilinear transformation
     helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))

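(In the contraction 'pqrs,tuqvr->pstuv', the indices q and r appear in both operands but not in the output, so they are summed away; everything else is kept in the output order pstuv. A small numpy check of that reading, using the same shapes as the test:

  import numpy as np
  a = np.random.rand(3,5,8,10)      # indices p,q,r,s
  b = np.random.rand(11,13,5,16,8)  # indices t,u,q,v,r
  out = np.einsum('pqrs,tuqvr->pstuv', a, b)
  # the same contraction over the shared indices q and r, written with tensordot
  ref = np.tensordot(a, b, axes=([1,2],[2,4]))  # -> p,s,t,u,v
  assert out.shape == (3,10,11,13,16) and np.allclose(out, ref)
)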
@@ -611,7 +616,8 @@
     helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0))
   def test_std_keepdim(self):
     helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
-    helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
+    helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0),
+                                   lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
   def test_log_softmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
@@ -766,8 +772,8 @@
   def test_pad(self):
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2))))
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
-    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf")))
-    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf")))
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(((3,4), (1,2)), value=math.inf))
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=-math.inf), lambda x: x.pad(((3,4), (1,2)), value=-math.inf))
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1))
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1))

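(Note the two pad conventions the test is translating between: torch's flat tuple lists (before, after) pairs starting from the last dimension, while tinygrad's x.pad takes one (before, after) pair per dimension in normal dimension order. So torch pad (1,2,3,4) on a 2-D tensor matches tinygrad ((3,4),(1,2)). A numpy sketch of the same padding:

  import numpy as np
  x = np.arange(9.).reshape(3,3)
  # torch.nn.functional.pad(x, (1,2,3,4)): last dim gets (1,2), first dim gets (3,4)
  y = np.pad(x, ((3,4),(1,2)))  # numpy, like tinygrad, pads in dimension order
  assert y.shape == (3+3+4, 3+1+2)
)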
@@ -790,8 +796,10 @@
     helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,4], lambda x: x.pad(((1,2),(3,4)), value=value)[:,4])
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,3,0,0), value=value)[:,4:6], lambda x: x.pad(((0,0),(0,3)), value=value)[:,4:6])
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,1,3,2), value=value)[0:2,:], lambda x: x.pad(((3,2),(0,1)), value=value)[0:2,:])
-    helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[0:2,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[0:2,:,:])
-    helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[2:4,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:])
+    helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[0:2,:,:],
+                              lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[0:2,:,:])
+    helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[2:4,:,:],
+                              lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:])

   def test_stack_slice(self):
     helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:])
@@ -1255,7 +1263,7 @@
     x = Tensor.randn(45, 65, 3)

     for dim in range(-1, 3):
-      helper_test_op([(45, 65, 3), (45, 65, 3), (45, 65, 3)], lambda x, y, z: torch.stack((x, y, z), dim=dim), lambda x, y, z: Tensor.stack([x, y, z], dim=dim))
+      helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack([x, y, z], dim))

     with self.assertRaises(IndexError):
       Tensor.stack([x], dim=77)
@@ -1353,10 +1361,14 @@

   def test_slice_fancy_indexing_with_idx(self):
     # indexing using idx with different dim
-    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)])
-    helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])])
-    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor([2,1,1])], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor([2,1,1])])
-    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,1,-1],[-1,-2,0]]), torch.tensor([2,1,-1])], lambda x: x[Tensor([[0,1,-1],[-1,-2,0]]), Tensor([2,1,-1])])
+    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)],
+                            lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)])
+    helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])],
+                            lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])])
+    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor([2,1,1])],
+                            lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor([2,1,1])])
+    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,1,-1],[-1,-2,0]]), torch.tensor([2,1,-1])],
+                            lambda x: x[Tensor([[0,1,-1],[-1,-2,0]]), Tensor([2,1,-1])])

   def test_slice_fancy_indexing_list_indices(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
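(What these cases pin down is fancy-indexing broadcast semantics: when several index tensors of different ranks are given, they broadcast against each other and the result takes the broadcast shape. A numpy sketch of the first case above:

  import numpy as np
  x = np.arange(6).reshape(2,3)
  rows = np.array([[0,0,0],[0,0,0]])  # shape (2,3)
  col = np.array(1)                   # 0-d index broadcasts against rows
  assert x[rows, col].shape == (2,3)  # result takes the broadcast shape of the indices
)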
@@ -1389,8 +1401,10 @@

   def test_slice_fancy_indexing_tuple_with_tensors(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
-    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,),], lambda x: x[(i,),]) TypeError: only integer tensors of a single element can be converted to an index
-    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,1),], lambda x: x[(i,1),]) TypeError: only integer tensors of a single element can be converted to an index
+    # # TypeError: only integer tensors of a single element can be converted to an index
+    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,),], lambda x: x[(i,),])
+    # # TypeError: only integer tensors of a single element can be converted to an index
+    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,1),], lambda x: x[(i,1),])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,[1,1])], lambda x: x[(i,[1,1])])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])
@@ -1419,21 +1433,32 @@
     helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1))
     helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2))
     helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
-    self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
-    self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))
+    self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0),
+                               lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
+    self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0),
+                               lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))

   def test_scaled_product_attention(self):
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
+                   lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
+                   lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m),
+                   lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))

   def test_scaled_product_attention_causal(self):
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
+                   lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True),
+                   lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))

   def test_binary_crossentropy(self):
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
+                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
+                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
+                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
+                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))

 if __name__ == '__main__':
   np.random.seed(1337)
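(For reference, the attention cases above all compute softmax(q @ k^T / sqrt(head_dim)) @ v; the cross-paired binary_crossentropy lines likewise rely on bce_with_logits(x, y) equaling bce(sigmoid(x), y). A minimal numpy sketch of the unmasked attention, with the same (batch, heads, seq, head_dim) shapes as the test:

  import numpy as np

  def sdpa(q, k, v):
    # softmax(q @ k^T / sqrt(head_dim)) @ v, no mask
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # numerically stable softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

  q = k = v = np.random.rand(32, 8, 16, 64).astype(np.float32)
  assert sdpa(q, k, v).shape == (32, 8, 16, 64)
)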