mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
less lines and fix default device
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, DEFAULT_DEVICE
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import *
|
||||
from extra.utils import get_parameters
|
||||
import torch
|
||||
|
||||
@unittest.skipUnless(not DEFAULT_DEVICE, "Not Implemented")
|
||||
@unittest.skipUnless(Device.DEFAULT == Device.CPU, "Not Implemented")
|
||||
class TestNN(unittest.TestCase):
|
||||
|
||||
def test_batchnorm2d(self, training=False):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor, DEFAULT_DEVICE
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
|
||||
x_init = np.random.randn(1,3).astype(np.float32)
|
||||
@@ -72,7 +72,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
expected = n * (1 - rate)
|
||||
np.testing.assert_allclose(non_zeros, expected, rtol=1e-3)
|
||||
|
||||
@unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
|
||||
@unittest.skipUnless(Device.DEFAULT == Device.CPU, "float64 not supported on GPU")
|
||||
def test_jacobian(self):
|
||||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
@@ -91,7 +91,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
np.testing.assert_allclose(PJ, J, atol = 1e-5)
|
||||
np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
|
||||
|
||||
@unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
|
||||
@unittest.skipUnless(Device.DEFAULT == Device.CPU, "float64 not supported on GPU")
|
||||
def test_gradcheck(self):
|
||||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
|
||||
@@ -32,21 +32,21 @@ class ProfileOp:
|
||||
debug_times[self.name] += et
|
||||
print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} {'-> '+str(self.output.shape) if self.output is not None else ''}")
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
# **** enumerate supported devices ****
|
||||
|
||||
class Device:
|
||||
buffers = {}
|
||||
imports = {}
|
||||
_ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ops")))
|
||||
imports = dict(enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]))
|
||||
DEFAULT = None
|
||||
for i,o in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
|
||||
name = o[len("ops_"):].upper()
|
||||
if os.environ.get(name, 0) == "1":
|
||||
DEFAULT = i
|
||||
buffers = {}
|
||||
for i,op in imports.items():
|
||||
name = op[len("ops_"):].upper()
|
||||
vars()[name] = i
|
||||
imports[i] = o
|
||||
DEFAULT = i if os.environ.get(name, 0) == "1" else DEFAULT
|
||||
DEFAULT = CPU if DEFAULT is None else DEFAULT
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
training = True
|
||||
|
||||
Reference in New Issue
Block a user