smaller tests are faster tests [pr] (#10704)

* remove del spam from CI

* more

* preconstruct default buffer spec

* ignore those errors

* check exception

* more exception check

* skip stuff

* smaller tests mean faster tests

* a few more
This commit is contained in:
George Hotz
2025-06-08 10:54:19 -07:00
committed by GitHub
parent 67a1c92fc0
commit 4e2c3560b4
6 changed files with 19 additions and 17 deletions

View File

@@ -356,7 +356,7 @@ jobs:
- name: Test README
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
- name: Run unit tests
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
run: PYTHONPATH="." python -m pytest -n=auto test/unit/ --durations=20
- name: Run targetted tests on NULL backend
run: PYTHONPATH="." NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
# TODO: support fake weights

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import Context
class TestConv(unittest.TestCase):
def test_simple(self):
x = Tensor.ones(1,12,128,256).contiguous().realize()
x = Tensor.ones(1,12,16,32).contiguous().realize()
w = Tensor.ones(32,12,3,3).contiguous().realize()
ret = x.conv2d(w, stride=(2,2), padding=(1,1)).numpy()
# it's not 108 around the padding
@@ -14,7 +14,7 @@ class TestConv(unittest.TestCase):
assert ret[0,0,0,1] == 72
def test_simple_rand(self):
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
w = Tensor.rand(32,12,3,3)
x.conv2d(w, stride=(2,2), padding=(1,1)).numpy()
@@ -47,7 +47,7 @@ class TestConv(unittest.TestCase):
np.testing.assert_allclose(out.relu().numpy(), np.maximum(out.numpy(), 0))
def test_two_binops_no_rerun(self):
x = Tensor.randn(1,12,128,256)
x = Tensor.randn(1,12,16,32)
w = Tensor.randn(32,12,3,3)
out = x.conv2d(w, stride=(2,2), padding=(1,1))
r1, r2 = out.relu(), (out-1)
@@ -55,7 +55,7 @@ class TestConv(unittest.TestCase):
np.testing.assert_allclose(r2.numpy(), out.numpy() - 1)
def test_two_overlapping_binops_no_rerun(self):
x = Tensor.randn(1,12,128,256)
x = Tensor.randn(1,12,16,32)
w = Tensor.randn(32,12,3,3)
out = x.conv2d(w, stride=(2,2), padding=(1,1))
r1, r2 = out.relu(), out.elu()
@@ -72,7 +72,7 @@ class TestConv(unittest.TestCase):
np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
def test_first_three(self):
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
w = Tensor.rand(32,12,3,3)
x = x.conv2d(w, stride=(2,2), padding=(1,1)).elu()
@@ -87,7 +87,7 @@ class TestConv(unittest.TestCase):
print(x.shape)
def test_elu(self):
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
w = Tensor.rand(32,12,3,3)
x = x.conv2d(w, stride=(2,2), padding=(1,1))
@@ -99,13 +99,13 @@ class TestConv(unittest.TestCase):
x.numpy()
def test_reduce_relu(self):
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
x = x.sum(keepdim=True).relu()
x.numpy()
def test_bias(self):
from tinygrad.nn import Conv2d
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
c = Conv2d(12, 32, 3)
x = c(x).relu()
w = Tensor.uniform(32, 1, 3, 3)
@@ -118,14 +118,14 @@ class TestConv(unittest.TestCase):
(w+x).numpy()
def test_reorder(self):
x = Tensor.rand(1,12,128,256)
x = Tensor.rand(1,12,16,32)
w = Tensor.rand(12,12,3,3)
x = x.conv2d(w, padding=(1,1))
print(x.shape)
x = x.reshape((1, 12, 256, 128))
x = x.reshape((1, 12, 32, 16))
x += 1
x += 1
x = x.reshape((1, 12, 128, 256))
x = x.reshape((1, 12, 16, 32))
x.numpy()
if __name__ == '__main__':

View File

@@ -301,11 +301,11 @@ class TestRandomness(unittest.TestCase):
lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
def test_kaiming_uniform(self):
for shape in [(256, 128, 3, 3), (80, 44), (3, 55, 35)]:
for shape in [(32, 128, 3, 3), (80, 44), (3, 55, 35)]:
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
def test_kaiming_normal(self):
for shape in [(256, 128, 3, 3), (80, 44), (3, 55, 35)]:
for shape in [(32, 128, 3, 3), (80, 44), (3, 55, 35)]:
self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
def test_multinomial(self):

View File

@@ -1233,8 +1233,8 @@ class TestSchedule(unittest.TestCase):
def test_adam_step_fusion(self):
with Tensor.train():
x = Tensor.empty(4, 64, 768)
layer = nn.Linear(768, 768*4)
x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4)
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()

View File

@@ -67,7 +67,9 @@ class CLAllocator(LRUAllocator['CLDevice']):
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec):
try: check(cl.clReleaseMemObject(opaque[0]))
except AttributeError: pass
def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview):
if dest[1].image is not None:
check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),