From teeny (#2426)

* changes from teenygrad work

* support not supporting ImageDType/PtrDType

* fixups from teeny
This commit is contained in:
George Hotz
2023-11-24 12:50:56 -08:00
committed by GitHub
parent 9ae83fba04
commit 8ff2e13550
11 changed files with 24 additions and 22 deletions

View File

@@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes, Context, NOOPT
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.ops import Device
if CI:
@@ -271,6 +271,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
def test_div_int(self):
helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]])
def test_div_const(self):
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
@@ -700,7 +701,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
x = Tensor.ones((4,3,6,6))
x.reshape([])
@@ -785,11 +786,6 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
def test_simple_conv2d_noopt(self):
# useful with IMAGE enabled
with Context(NOOPT=1):
self.test_simple_conv2d()
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
@@ -1166,7 +1162,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
x.repeat((2, 4))
np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())