mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-04 19:55:18 -05:00
cleanup test_ops.py (#3192)
- removed exactly duplicated tests
- only kept one function if torch_fxn is the same as tinygrad_fxn
- used tensor method style instead of class method style
- replaced unneeded `lambda x: f(x)` with just `f`
- re-enabled commented-out tests that work now
- removed some forward_only flags now that 0-shape tensors can run backward
This commit is contained in:
620
test/test_ops.py
620
test/test_ops.py
@@ -102,6 +102,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_eye(self):
|
||||
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
|
||||
helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True)
|
||||
|
||||
def test_split(self):
|
||||
test_cases = [
|
||||
@@ -149,8 +150,6 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True)
|
||||
def test_arange_simple(self):
|
||||
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
|
||||
def test_arange_big(self):
|
||||
helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True)
|
||||
|
||||
@@ -234,33 +233,30 @@ class TestOps(unittest.TestCase):
|
||||
(t*(t < 0)).sum().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
#@unittest.skip("this is broken with contiguous")
|
||||
def test_trunc(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True)
|
||||
helper_test_op([(45,65)], lambda x: x.trunc(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True)
|
||||
#@unittest.skip("this is broken with contiguous")
|
||||
def test_floor(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
|
||||
helper_test_op([(45,65)], lambda x: x.floor(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True)
|
||||
#@unittest.skip("this is broken with contiguous")
|
||||
def test_ceil(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True)
|
||||
helper_test_op([(45,65)], lambda x: x.ceil(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True)
|
||||
def test_tril(self):
|
||||
helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril())
|
||||
helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1), lambda x: x.tril(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(), lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1), lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril())
|
||||
helper_test_op([(3,3)], lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
def test_triu(self):
|
||||
helper_test_op([(3,3)], lambda x: x.triu(), lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1), lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1), lambda x: x.triu(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(), lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1), lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
def test_maximum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
|
||||
helper_test_op([(), ()], torch.maximum, Tensor.maximum)
|
||||
@@ -269,81 +265,97 @@ class TestOps(unittest.TestCase):
|
||||
def test_minimum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
|
||||
helper_test_op([(), ()], torch.minimum, Tensor.minimum)
|
||||
|
||||
def test_add(self):
|
||||
helper_test_op([(45,68), (45,68)], lambda x,y: x+y)
|
||||
helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
|
||||
def test_add_number(self):
|
||||
helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
|
||||
helper_test_op([(), ()], lambda x,y: x+y)
|
||||
def test_add3(self):
|
||||
helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: x+y+z)
|
||||
def test_add_simple(self):
|
||||
helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
|
||||
def test_broadcasted_add(self):
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y)
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x+y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x+y)
|
||||
def test_broadcasted_add_2(self):
|
||||
helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
|
||||
helper_test_op([(45,65), (65,)], lambda x,y: x+y)
|
||||
|
||||
def test_sub(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
|
||||
helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y)
|
||||
helper_test_op([(), ()], lambda x,y: x-y)
|
||||
def test_scalar_sub(self):
|
||||
helper_test_op([(45,65)], lambda x: x-2)
|
||||
helper_test_op([()], lambda x: x-2)
|
||||
def test_scalar_rsub(self):
|
||||
helper_test_op([(45,65)], lambda x: 2-x)
|
||||
helper_test_op([()], lambda x: 2-x)
|
||||
|
||||
def test_neg(self):
|
||||
helper_test_op([(45,65)], lambda x: -x)
|
||||
helper_test_op([()], lambda x: -x)
|
||||
|
||||
def test_mul(self):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
|
||||
def test_mul_number(self):
|
||||
helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
|
||||
def test_mul_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
|
||||
helper_test_op([(45,65)], lambda x: x*-1, lambda x: x*-1)
|
||||
helper_test_op([(45,65)], lambda x: 255*x, lambda x: 255*x)
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x*y)
|
||||
helper_test_op([(), ()], lambda x,y: x*y)
|
||||
def test_scalar_mul(self):
|
||||
helper_test_op([(45,65)], lambda x: x*2)
|
||||
helper_test_op([(45,65)], lambda x: x*-1)
|
||||
helper_test_op([(45,65)], lambda x: 255*x)
|
||||
helper_test_op([()], lambda x: x*2)
|
||||
def test_scalar_rmul(self):
|
||||
helper_test_op([(45,65)], lambda x: 2*x)
|
||||
helper_test_op([()], lambda x: 2*x)
|
||||
|
||||
def test_div(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
|
||||
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
|
||||
helper_test_op([(), ()], lambda x,y: x/y)
|
||||
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5],[1]], dtype=np.int32))
|
||||
def test_div_int(self):
|
||||
helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=np.array([[3]], dtype=np.int32))
|
||||
def test_div_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
|
||||
helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
|
||||
helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
|
||||
helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
|
||||
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
|
||||
helper_test_op([()], lambda x: x/2, lambda x: x/2)
|
||||
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
|
||||
helper_test_op([(45,65)], lambda x: x/255)
|
||||
helper_test_op([(45,65)], lambda x: x/1)
|
||||
helper_test_op([(45,65)], lambda x: 1/x)
|
||||
helper_test_op([(45,65)], lambda x: x/2)
|
||||
helper_test_op([(45,65)], lambda x: 2/x)
|
||||
helper_test_op([()], lambda x: x/2)
|
||||
helper_test_op([()], lambda x: 2/x)
|
||||
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "METAL has issues with -inf")
|
||||
def test_mul_const_naninf(self):
|
||||
helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf"))
|
||||
helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf"))
|
||||
helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan"))
|
||||
helper_test_op([(45,65)], lambda x: x*math.inf)
|
||||
helper_test_op([(45,65)], lambda x: x*-math.inf)
|
||||
helper_test_op([(45,65)], lambda x: x*math.nan)
|
||||
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "METAL has issues with -inf")
|
||||
def test_div_const_naninf(self):
|
||||
helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf"))
|
||||
helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf"))
|
||||
helper_test_op([(45,65)], lambda x: x/float("nan"), lambda x: x/float("nan"))
|
||||
helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x)
|
||||
helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x)
|
||||
helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x)
|
||||
helper_test_op([(45,65)], lambda x: x/math.inf)
|
||||
helper_test_op([(45,65)], lambda x: x/-math.inf)
|
||||
helper_test_op([(45,65)], lambda x: x/math.nan)
|
||||
helper_test_op([(45,65)], lambda x: math.inf/x)
|
||||
helper_test_op([(45,65)], lambda x: (-math.inf)/x)
|
||||
helper_test_op([(45,65)], lambda x: math.nan/x)
|
||||
|
||||
def test_pow_full(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, a=0)
|
||||
def test_pow(self):
|
||||
# TODO: why is a=0 for these tests?
|
||||
helper_test_op([(45,65)], lambda x: x**0, lambda x: Tensor.pow(x,0), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**1, lambda x: Tensor.pow(x,1), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**0, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**1, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**2, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**3, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**-2, a=0)
|
||||
helper_test_op([()], lambda x: x**2, a=0)
|
||||
helper_test_op([()], lambda x: x**-2, a=0)
|
||||
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1151
|
||||
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10)
|
||||
helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10)
|
||||
helper_test_op([(45,65)], lambda x: x**3, a=-10)
|
||||
helper_test_op([()], lambda x: x**3, a=-10)
|
||||
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1251
|
||||
helper_test_op([(45,65)], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10)
|
||||
helper_test_op([(45,65)], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10)
|
||||
helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10)
|
||||
helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10)
|
||||
helper_test_op([(45,65)], lambda x: x**0.2, a=-10)
|
||||
helper_test_op([(45,65)], lambda x: x**1.2, a=-10)
|
||||
helper_test_op([()], lambda x: x**0.2, a=-10)
|
||||
helper_test_op([()], lambda x: x**1.2, a=-10)
|
||||
a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True)
|
||||
helper_test_op([], lambda: b**1.1, lambda: a**1.1, )
|
||||
helper_test_op([], lambda: b**1.1, lambda: a**1.1)
|
||||
def test_pow_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
|
||||
helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0)
|
||||
@@ -352,12 +364,14 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
|
||||
helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0)
|
||||
helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x)
|
||||
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
|
||||
helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), a=0)
|
||||
helper_test_op([()], lambda x: x.sqrt(), a=0)
|
||||
def test_rsqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
|
||||
helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x.rsqrt(), a=0)
|
||||
helper_test_op([()], lambda x: x.rsqrt(), a=0)
|
||||
|
||||
def test_xor(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
@@ -366,17 +380,20 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
def test_sin(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x.sin(), a=0)
|
||||
helper_test_op([()], lambda x: x.sin(), a=0)
|
||||
def test_cos(self):
|
||||
helper_test_op([(45,65)], lambda x: x.cos(), Tensor.cos, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x.cos(), a=0)
|
||||
helper_test_op([()], lambda x: x.cos(), a=0)
|
||||
def test_tan(self):
|
||||
helper_test_op([(45,65)], lambda x: x.tan(), Tensor.tan, a=0)
|
||||
helper_test_op([(45,65)], lambda x: x.tan(), a=0)
|
||||
helper_test_op([()], lambda x: x.tan(), a=0)
|
||||
|
||||
def test_relu(self):
|
||||
helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
|
||||
helper_test_op([()], lambda x: x.relu(), Tensor.relu)
|
||||
helper_test_op([(64,64)], lambda x: x.relu())
|
||||
helper_test_op([()], lambda x: x.relu())
|
||||
def test_relu_exact(self):
|
||||
helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]])
|
||||
helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]])
|
||||
def test_relu_maximum_exact(self):
|
||||
helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
|
||||
def test_leakyrelu(self):
|
||||
@@ -386,38 +403,40 @@ class TestOps(unittest.TestCase):
|
||||
for val in range(1, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
|
||||
def test_abs(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
|
||||
helper_test_op([()], lambda x: torch.abs(x), Tensor.abs)
|
||||
helper_test_op([(45,65)], torch.abs, Tensor.abs)
|
||||
helper_test_op([()], torch.abs, Tensor.abs)
|
||||
def test_log(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
|
||||
helper_test_op([()], lambda x: torch.log(x), Tensor.log)
|
||||
helper_test_op([(45,65)], torch.log, Tensor.log)
|
||||
helper_test_op([()], torch.log, Tensor.log)
|
||||
def test_log2(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2)
|
||||
helper_test_op([()], lambda x: torch.log2(x), Tensor.log2)
|
||||
helper_test_op([(45,65)], torch.log2, Tensor.log2)
|
||||
helper_test_op([()], torch.log2, Tensor.log2)
|
||||
def test_exp(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
|
||||
helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
|
||||
helper_test_op([(45,65)], torch.exp, Tensor.exp)
|
||||
helper_test_op([()], torch.exp, Tensor.exp)
|
||||
def test_exp2(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.exp2(x), Tensor.exp2)
|
||||
helper_test_op([()], lambda x: torch.exp2(x), Tensor.exp2)
|
||||
helper_test_op([(45,65)], torch.exp2, Tensor.exp2)
|
||||
helper_test_op([()], torch.exp2, Tensor.exp2)
|
||||
def test_sign(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
|
||||
helper_test_op([()], lambda x: torch.sign(x), Tensor.sign)
|
||||
helper_test_op([(45,65)], torch.sign, Tensor.sign)
|
||||
helper_test_op([()], torch.sign, Tensor.sign)
|
||||
def test_softsign(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
|
||||
helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign)
|
||||
helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign)
|
||||
def test_sigmoid(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=100)
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100)
|
||||
helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
|
||||
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
|
||||
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=100)
|
||||
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=-100)
|
||||
helper_test_op([()], torch.sigmoid, Tensor.sigmoid)
|
||||
def test_softplus(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
def test_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
|
||||
#helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100)
|
||||
if not (CI and Device.DEFAULT == "METAL"):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100)
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100)
|
||||
def test_quick_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
|
||||
@@ -425,38 +444,18 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100)
|
||||
helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
|
||||
def test_elu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
|
||||
helper_test_op([(45,65)], torch.nn.functional.elu, Tensor.elu)
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1))
|
||||
helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu)
|
||||
helper_test_op([()], torch.nn.functional.elu, Tensor.elu)
|
||||
def test_relu6(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
|
||||
helper_test_op([(45,65)], torch.nn.functional.relu6, Tensor.relu6)
|
||||
helper_test_op([()], torch.nn.functional.relu6, Tensor.relu6)
|
||||
def test_hardswish(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], torch.nn.functional.hardswish, Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], torch.nn.functional.hardswish, Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
def test_mish(self):
|
||||
def _mish_pytorch(x):
|
||||
return x*torch.tanh(torch.nn.functional.softplus(x))
|
||||
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
|
||||
def test_dot_1d(self):
|
||||
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
with self.assertRaises(AssertionError):
|
||||
a = Tensor(3.14)
|
||||
a.matmul(a)
|
||||
helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish, atol=1e-4)
|
||||
helper_test_op([()], torch.nn.functional.mish, Tensor.mish, atol=1e-4)
|
||||
|
||||
def test_multinomial(self):
|
||||
# NOTE: this is random, so it has a very large atol
|
||||
@@ -477,17 +476,17 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_argmax(self):
|
||||
self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, True), forward_only=True)
|
||||
|
||||
def test_argmin(self):
|
||||
self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, True), forward_only=True)
|
||||
|
||||
def test_einsum(self):
|
||||
# matrix transpose
|
||||
@@ -545,6 +544,25 @@ class TestOps(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
Tensor.einsum('ij,jk->ij', a)
|
||||
|
||||
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
|
||||
def test_dot_1d(self):
|
||||
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
with self.assertRaises(AssertionError):
|
||||
a = Tensor(3.14)
|
||||
a.matmul(a)
|
||||
|
||||
def test_matmul_simple(self):
|
||||
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_matmul(self):
|
||||
@@ -583,126 +601,107 @@ class TestOps(unittest.TestCase):
|
||||
def test_multidot(self):
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
|
||||
def test_sum_simple(self):
|
||||
helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.,1.]])
|
||||
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
|
||||
def test_sum_full(self):
|
||||
helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum())
|
||||
def test_sum_small_full(self):
|
||||
helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum)
|
||||
helper_test_op([(16384)], lambda x: x.sum())
|
||||
def test_sum_relu(self):
|
||||
helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu())
|
||||
helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu())
|
||||
def test_sum(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
|
||||
helper_test_op([(45,3)], lambda x: x.sum())
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1))
|
||||
helper_test_op([()], lambda x: x.sum(), Tensor.sum)
|
||||
def test_sum_with_zeros_shape(self):
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)), lambda x: Tensor.sum(x, axis=(0,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)), lambda x: Tensor.sum(x, axis=(1,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,1)), lambda x: Tensor.sum(x, axis=(0,1)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,1)))
|
||||
def test_min(self):
|
||||
helper_test_op([(3,3)], lambda x: x.min(), Tensor.min)
|
||||
helper_test_op([(45,3)], lambda x: x.min(), Tensor.min)
|
||||
helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5))
|
||||
helper_test_op([()], lambda x: x.min(), Tensor.min)
|
||||
helper_test_op([(3,3)], lambda x: x.min())
|
||||
helper_test_op([(45,3)], lambda x: x.min())
|
||||
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
|
||||
helper_test_op([()], lambda x: x.min())
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
|
||||
helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],])
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
|
||||
helper_test_op([()], lambda x: x.max(), Tensor.max)
|
||||
helper_test_op([(45,3)], lambda x: x.max())
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5))
|
||||
helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],])
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
|
||||
helper_test_op([()], lambda x: x.max())
|
||||
def test_mean(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean())
|
||||
helper_test_op([()], lambda x: x.mean())
|
||||
def test_mean_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)))
|
||||
def test_var(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x), lambda x: Tensor.var(x))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, correction=0), lambda x: Tensor.var(x, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, correction=5), lambda x: Tensor.var(x, correction=5))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var())
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5))
|
||||
def test_var_axis(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=0), lambda x: Tensor.var(x, axis=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=2), lambda x: Tensor.var(x, axis=2))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=[1, 2]), lambda x: Tensor.var(x, axis=[1, 2]))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None), lambda x: Tensor.var(x, axis=None))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=0), lambda x: Tensor.var(x, axis=0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=2), lambda x: Tensor.var(x, axis=2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=[1, 2]), lambda x: Tensor.var(x, axis=[1, 2], correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, correction=0, dim=None), lambda x: Tensor.var(x, axis=None, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(2))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2]))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0))
|
||||
def test_var_keepdim(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=None, keepdim=True), lambda x: Tensor.var(x, keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.var(x, dim=0, keepdim=True, correction=0),
|
||||
lambda x: Tensor.var(x, keepdim=True, correction=0, axis=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))
|
||||
def test_std(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x), lambda x: Tensor.std(x))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std())
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(correction=5))
|
||||
def test_std_axis(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2]))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None), lambda x: Tensor.std(x, axis=None))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=0), lambda x: Tensor.std(x, axis=0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=2), lambda x: Tensor.std(x, axis=2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(2))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2]))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0))
|
||||
def test_std_keepdim(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0),
|
||||
lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))
|
||||
def test_softmax(self):
|
||||
# exceed per kernel buffer limit with backward
|
||||
forward_only = (Device.DEFAULT == "WEBGPU")
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.Softmax(dim=1)(x), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
helper_test_op([()], lambda x: torch.nn.Softmax(dim=0)(x), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
def test_log_softmax(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
def test_log_softmax_other_axis(self):
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
|
||||
def test_tanh(self):
|
||||
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, a=-100)
|
||||
helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=-100)
|
||||
helper_test_op([()], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
|
||||
def test_hardtanh(self):
|
||||
for val in range(10, 30, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
|
||||
def test_topo_sort(self):
|
||||
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
|
||||
|
||||
def test_scalar_mul(self):
|
||||
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
|
||||
helper_test_op([()], lambda x: x*2, lambda x: x*2)
|
||||
def test_scalar_rmul(self):
|
||||
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
|
||||
helper_test_op([()], lambda x: 2*x, lambda x: 2*x)
|
||||
def test_scalar_sub(self):
|
||||
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
|
||||
helper_test_op([()], lambda x: x-2, lambda x: x-2)
|
||||
def test_scalar_rsub(self):
|
||||
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
|
||||
helper_test_op([()], lambda x: 2-x, lambda x: 2-x)
|
||||
def test_flip_eye_crash(self):
|
||||
helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
|
||||
lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
|
||||
|
||||
def test_broadcast_full(self):
|
||||
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
|
||||
(torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]:
|
||||
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
|
||||
for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
|
||||
with self.subTest(op=torch_op.__name__, shapes=shapes):
|
||||
helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
|
||||
|
||||
def test_broadcast_simple(self):
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y)
|
||||
helper_test_op([(45,65), (45,1)], lambda x,y: x/y)
|
||||
helper_test_op([(45,65), ()], lambda x,y: x/y)
|
||||
|
||||
def test_broadcast_partial(self):
|
||||
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
|
||||
@@ -714,49 +713,49 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
|
||||
|
||||
def test_slice_in_bounds_1dim(self):
|
||||
helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3])
|
||||
helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2])
|
||||
helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2])
|
||||
helper_test_op([(3)], lambda x: x[1:3])
|
||||
helper_test_op([(3)], lambda x: x[0:2])
|
||||
helper_test_op([(3)], lambda x: x[-2:2])
|
||||
|
||||
def test_slice_on_0dim_tensor(self):
|
||||
helper_test_op([()], lambda x: x[None], lambda x: x[None])
|
||||
helper_test_op([()], lambda x: x[None])
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
a = Tensor(3.14)
|
||||
a[0]
|
||||
|
||||
def test_slice_int_indexing(self):
|
||||
helper_test_op([(3)], lambda x: x[0], lambda x: x[0])
|
||||
helper_test_op([(3)], lambda x: x[2], lambda x: x[2])
|
||||
helper_test_op([(3)], lambda x: x[-1], lambda x: x[-1])
|
||||
helper_test_op([(3)], lambda x: x[-3], lambda x: x[-3])
|
||||
helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1])
|
||||
helper_test_op([(3)], lambda x: x[0])
|
||||
helper_test_op([(3)], lambda x: x[2])
|
||||
helper_test_op([(3)], lambda x: x[-1])
|
||||
helper_test_op([(3)], lambda x: x[-3])
|
||||
helper_test_op([(10,10)], lambda x: x[1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1,1,1])
|
||||
|
||||
def test_slice_in_bounds_multidim(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1])
|
||||
|
||||
def test_slice_with_none(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2], lambda x: x[None, None, 1, None, 2, 0:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2])
|
||||
|
||||
def test_slice_one_endpoint_out_of_bounds(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1])
|
||||
helper_test_op([(3,3,3)], lambda x: x[0:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[-6:4])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50])
|
||||
helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1])
|
||||
|
||||
def test_slice_stride_gt_one(self):
|
||||
helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4])
|
||||
helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4])
|
||||
|
||||
def test_slice_negative_strides(self):
|
||||
# Torch doesn't support slicing with negative steps
|
||||
@@ -767,26 +766,24 @@ class TestOps(unittest.TestCase):
|
||||
np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy())
|
||||
np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy())
|
||||
np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy())
|
||||
if Device.DEFAULT not in ["CPU"]:
|
||||
# broken
|
||||
np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10)
|
||||
np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10)
|
||||
np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10)
|
||||
np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10)
|
||||
np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10)
|
||||
np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10)
|
||||
|
||||
def test_slice_both_endpoints_out_of_bounds(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[5:10])
|
||||
helper_test_op([(3,3,3)], lambda x: x[-15:-7])
|
||||
|
||||
def test_slice_start_gt_end(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:2])
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5])
|
||||
|
||||
def test_slice_empty(self):
|
||||
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True)
|
||||
helper_test_op([(10,10)], lambda x: x[1:1])
|
||||
|
||||
def test_slice_zero_in_shape(self):
|
||||
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) # x.shape = (0, 10)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) # x.shape = (0, 3, 3)
|
||||
helper_test_op([(10,10)], lambda x: x[1:1]) # x.shape = (0, 10)
|
||||
helper_test_op([(3,3,3)], lambda x: x[-2:-5]) # x.shape = (0, 3, 3)
|
||||
|
||||
def test_slice_errors(self):
|
||||
a = Tensor.ones(4, 3)
|
||||
@@ -799,11 +796,11 @@ class TestOps(unittest.TestCase):
|
||||
with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor
|
||||
|
||||
def test_slice_ellipsis(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0, ...], lambda x: x[0, ...])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0], lambda x: x[0, ..., 0])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3], lambda x: x[0:3, ..., 2:3])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None], lambda x: x[None, 0:3, ..., 0, None])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[..., 0])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0, ...])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3])
|
||||
helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None])
|
||||
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
|
||||
@@ -850,80 +847,82 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3])
|
||||
|
||||
def test_transpose(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
|
||||
helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
|
||||
helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(()))
|
||||
helper_test_op([(3,3)], lambda x: x.T)
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2))
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2))
|
||||
helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0)))
|
||||
helper_test_op([()], lambda x: x.permute(()))
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
|
||||
helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
|
||||
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
|
||||
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6)))
|
||||
helper_test_op([()], lambda x: x.reshape([]))
|
||||
helper_test_op([(1,)], lambda x: x.reshape([]))
|
||||
helper_test_op([()], lambda x: x.reshape([1]))
|
||||
helper_test_op([()], lambda x: x.reshape([1, 1, 1]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
x = Tensor.ones((4,3,6,6))
|
||||
x.reshape([])
|
||||
|
||||
def test_flip(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)))
|
||||
helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((0,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((3,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3)).flip(0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flip((-1,)))
|
||||
helper_test_op([()], lambda x: x.flip(()))
|
||||
helper_test_op([(1,)], lambda x: x.flip(()))
|
||||
helper_test_op([(4, 3, 6, 6)], lambda x: x.flip(()))
|
||||
|
||||
def test_squeeze(self):
|
||||
helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0))
|
||||
helper_test_op([(4,3,1,6)], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x, 3), lambda x: x.squeeze(dim=3))
|
||||
helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0))
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3))
|
||||
self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError)
|
||||
self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError)
|
||||
helper_test_op([(4,3,6,1)], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze())
|
||||
helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze())
|
||||
helper_test_op([(2,3,1)], lambda x: torch.squeeze(x), lambda x: x.squeeze())
|
||||
helper_test_op([()], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1))
|
||||
helper_test_op([()], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0))
|
||||
helper_test_op([()], lambda x: torch.squeeze(x), lambda x: x.squeeze())
|
||||
helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.squeeze())
|
||||
helper_test_op([(1,3,6,6)], lambda x: x.squeeze())
|
||||
helper_test_op([(2,3,1)], lambda x: x.squeeze())
|
||||
helper_test_op([()], lambda x: x.squeeze(-1))
|
||||
helper_test_op([()], lambda x: x.squeeze(0))
|
||||
helper_test_op([()], lambda x: x.squeeze())
|
||||
self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: torch.squeeze(x, -2), lambda x: x.squeeze(dim=-2), expected=IndexError)
|
||||
|
||||
def test_unsqueeze(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
|
||||
helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(4))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-3))
|
||||
helper_test_op([()], lambda x: x.unsqueeze(0))
|
||||
|
||||
def test_flatten(self):
|
||||
for axis in range(3):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(start_dim=axis))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=axis))
|
||||
for axis in range(3):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, end_dim=axis), lambda x: x.flatten(end_dim=axis))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=1, end_dim=3), lambda x: x.flatten(start_dim=1, end_dim=3))
|
||||
helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten())
|
||||
helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten())
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flatten(end_dim=axis))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=1, end_dim=3))
|
||||
helper_test_op([()], lambda x: x.flatten())
|
||||
helper_test_op([(1,)], lambda x: x.flatten())
|
||||
|
||||
def test_unflatten(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, 0, (2, 2)), lambda x: x.unflatten(0, (2, 2)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, 3, (3, 2)), lambda x: x.unflatten(3, (3, 2)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unflatten(x, -1, (3, 2, 1)), lambda x: x.unflatten(-1, (3, 2, 1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(0, (2, 2)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
|
||||
|
||||
def test_detach(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
|
||||
helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)
|
||||
helper_test_op([()], lambda x: x.detach(), forward_only=True)
|
||||
|
||||
def test_expand(self):
|
||||
arg = (4,3,2,6)
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
|
||||
helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[]))
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6)))
|
||||
helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6)))
|
||||
helper_test_op([()], lambda x: x.expand([]))
|
||||
|
||||
@unittest.skip("very slow")
|
||||
def test_sd_big_conv(self):
|
||||
@@ -1348,12 +1347,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
|
||||
|
||||
def test_clip(self):
|
||||
helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(0, 0), lambda x: x.clip(0, 0))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(10, 100), lambda x: x.clip(10, 100))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(0, 0.1), lambda x: x.clip(0, 0.1))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2), lambda x: x.clip(-0.3, -0.2))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(3, 0), lambda x: x.clip(3, 0))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(0, 0))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(10, 100))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(0, 0.1))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2))
|
||||
helper_test_op([(45,65)], lambda x: x.clip(3, 0))
|
||||
|
||||
def test_matvecmat(self):
|
||||
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4)
|
||||
@@ -1484,8 +1483,7 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))
|
||||
|
||||
def test_scaled_product_attention(self):
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
|
||||
lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
|
||||
lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m),
|
||||
lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
|
||||
|
||||
Reference in New Issue
Block a user