mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test lazy also, make TestMNIST faster
This commit is contained in:
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -52,6 +52,8 @@ jobs:
|
||||
run: pip install -e '.[testing]'
|
||||
- name: Run Pytest
|
||||
run: python -m pytest -s -v
|
||||
- name: Run Pytest (lazy)
|
||||
run: LAZY=1 python -m pytest -s -v
|
||||
|
||||
testtorch:
|
||||
name: Torch Tests
|
||||
@@ -70,6 +72,8 @@ jobs:
|
||||
run: pip install -e '.[testing]'
|
||||
- name: Run Pytest
|
||||
run: TORCH=1 python -m pytest -s -v
|
||||
- name: Run Pytest (lazy)
|
||||
run: LAZY=1 TORCH=1 python -m pytest -s -v
|
||||
|
||||
testgpu:
|
||||
name: GPU Tests
|
||||
@@ -94,6 +98,8 @@ jobs:
|
||||
run: pip install -e '.[gpu,testing]'
|
||||
- name: Run Pytest
|
||||
run: GPU=1 python -m pytest -s -v
|
||||
- name: Run Pytest (lazy)
|
||||
run: LAZY=1 GPU=1 python -m pytest -s -v
|
||||
|
||||
testmypy:
|
||||
name: Mypy Tests
|
||||
|
||||
@@ -81,21 +81,21 @@ class TestMNIST(unittest.TestCase):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
train(model, X_train, Y_train, optimizer, steps=200)
|
||||
train(model, X_train, Y_train, optimizer, steps=100)
|
||||
assert evaluate(model, X_test, Y_test) > 0.95
|
||||
|
||||
def test_sgd(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
train(model, X_train, Y_train, optimizer, steps=1000)
|
||||
train(model, X_train, Y_train, optimizer, steps=600)
|
||||
assert evaluate(model, X_test, Y_test) > 0.95
|
||||
|
||||
def test_rmsprop(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
|
||||
train(model, X_train, Y_train, optimizer, steps=1000)
|
||||
train(model, X_train, Y_train, optimizer, steps=400)
|
||||
assert evaluate(model, X_test, Y_test) > 0.95
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user