mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
From teeny (#2426)
* changes from teenygrad work * support not supporting ImageDType/PtrDType * fixups from teeny
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
||||
import numpy as np
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes, Context, NOOPT
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
|
||||
from tinygrad.ops import Device
|
||||
|
||||
if CI:
|
||||
@@ -271,6 +271,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
|
||||
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
|
||||
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
|
||||
def test_div_int(self):
|
||||
helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]])
|
||||
def test_div_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
|
||||
@@ -700,7 +701,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
|
||||
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
x = Tensor.ones((4,3,6,6))
|
||||
x.reshape([])
|
||||
|
||||
@@ -785,11 +786,6 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_simple_conv2d_noopt(self):
|
||||
# useful with IMAGE enabled
|
||||
with Context(NOOPT=1):
|
||||
self.test_simple_conv2d()
|
||||
|
||||
@unittest.skipIf(IMAGE>0, "no conv3d on images")
|
||||
def test_simple_conv3d(self):
|
||||
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
|
||||
@@ -1166,7 +1162,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
|
||||
helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
x.repeat((2, 4))
|
||||
|
||||
np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())
|
||||
|
||||
Reference in New Issue
Block a user