Files
tinygrad/test/models/test_train.py
Diogo a9a1df785f Webgpu support (#1077)
* initial commit

* 81 passing

* 105 passing tests

* 148 passing

* CI tests

* install dep on ci

* try opencl pkgs

* try using vulkan

* down to only 6 failing

* refactor

* cleaning up

* another test skipped due to buffer limit

* linter

* segfault

* indent fix

* another segfault found

* small touchups

* Fix max and maxpool tests

* Add constant folding

* Add javascript export script

* better asserts in codegen

* manual upcasting

* reverted token type change

* skip safetensor test due to unsupported type

* Fix efficientnet and all other model tests

* Remove np copy

* fixed indent and missing import

* manually destroy the buffer

* revert back to length

* linter errors

* removed extra val

* skip broken tests

* skipping more tests

* Make the page pretty

* Save model weights as safetensor

* Fix imagenet to c test

* Fix second imagenet to c bug

* Async and parallel kernel compilation

* workgroup support

* reversed local size

* fixed non local bug

* correct local groups

* ci experiment

* removed typo

* Fix define local by using shared memory

* Refactor

* try running on mac

* match metal tests

* add more workers

* scope down tests

* trying windows runner

* fixed windows env

* see how many it can do

* merged master

* refactor

* missed refactor

* increase test suite coverage

* missing import

* whitespace in test_efficientnet.py

* getting there

* fixed reset

* fixed bufs

* switched to cstyle

* cleanup

* min/max rename

* one more linter issue

* fixed demo

* linter

* testing ci chrome

* add unsafe webgpu arg

* add build step

* remove WEBGPU from cmd line

* use module

* try forcing directx

* trying forced metal backend

* temp disable conv2d for CI

* disable conv_transpose2d

---------

Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2023-07-12 12:52:06 -07:00

81 lines
2.3 KiB
Python

import unittest
import time
import numpy as np
from tinygrad.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Device
from tinygrad.helpers import getenv
from extra.training import train
from models.convnext import ConvNeXt
from models.efficientnet import EfficientNet
from models.transformer import Transformer
from models.vit import ViT
from models.resnet import ResNet18
# Batch size used for the input arrays and the training step in this file.
# getenv is called with the key "BS" and 2 -- presumably 2 is the fallback when
# BS is not set in the environment (confirm against tinygrad.helpers.getenv).
BS = getenv("BS", 2)
def train_one_step(model,X,Y):
  """Run one SGD training step on `model` and print how long it took.

  The model's parameters are collected with `get_parameters`, wrapped in an
  `optim.SGD` optimizer (lr=0.001) and handed to `train` for a single step
  using the module-level batch size `BS`. The parameter count (in millions)
  and the elapsed time (in milliseconds) are printed.

  Args:
    model: model object passed to `get_parameters` and `train`.
    X: input data forwarded unchanged to `train`.
    Y: target data forwarded unchanged to `train`.
  """
  params = get_parameters(model)
  # total number of scalar parameters; only used for the log line below
  pcount = sum(np.prod(p.shape) for p in params)
  optimizer = optim.SGD(params, lr=0.001)
  print("stepping %r with %.1fM params bs %d" % (type(model), pcount/1e6, BS))
  # perf_counter is monotonic, so the measured interval cannot be skewed by
  # wall-clock adjustments the way time.time() can
  st = time.perf_counter()
  train(model, X, Y, optimizer, steps=1, BS=BS)
  et = time.perf_counter()-st
  print("done in %.2f ms" % (et*1000.))
def check_gc():
  """Assert that `print_objects()` reports 0 when the default device is "GPU".

  On any other default device this is a no-op.
  """
  if Device.DEFAULT != "GPU":
    return
  # imported lazily so the introspection module is only loaded on the GPU path
  from extra.introspection import print_objects
  remaining = print_objects()
  assert remaining == 0
class TestTrain(unittest.TestCase):
  # Smoke tests: each case builds a model, runs train_one_step on zero-filled
  # numpy inputs sized by BS, then calls check_gc().

  def test_convnext(self):
    model = ConvNeXt(depths=[1], dims=[16])
    X = np.zeros((BS,3,224,224), dtype=np.float32)
    Y = np.zeros((BS,), dtype=np.int32)
    train_one_step(model,X,Y)
    check_gc()

  def test_efficientnet(self):
    model = EfficientNet(0)
    X = np.zeros((BS,3,224,224), dtype=np.float32)
    Y = np.zeros((BS,), dtype=np.int32)
    train_one_step(model,X,Y)
    check_gc()

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
  def test_vit(self):
    model = ViT()
    X = np.zeros((BS,3,224,224), dtype=np.float32)
    Y = np.zeros((BS,), dtype=np.int32)
    train_one_step(model,X,Y)
    check_gc()

  def test_transformer(self):
    # this should be small GPT-2, but the param count is wrong
    # (real ff_dim is 768*4)
    model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768//4)
    # the second dimension (6) equals the maxlen given to the model above
    X = np.zeros((BS,6), dtype=np.float32)
    Y = np.zeros((BS,6), dtype=np.int32)
    train_one_step(model,X,Y)
    check_gc()

  def test_resnet(self):
    X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
    Y = np.zeros((BS,), dtype=np.int32)
    # written as a loop so further ResNet variants can be added to the list
    for resnet_v in [ResNet18]:
      model = resnet_v()
      model.load_from_pretrained()
      train_one_step(model, X, Y)
    check_gc()

  # Reported as skipped rather than as a vacuous pass until the test exists.
  @unittest.skip("bert training test is not written yet")
  def test_bert(self):
    # TODO: write this
    pass
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()