mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
test_train cleanup (#12140)
* test_train cleanup remove skipIf due to buffer sizes, runs locally * those are slow
This commit is contained in:
@@ -53,8 +53,8 @@ class TestRealWorld(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_stable_diffusion(self):
|
||||
params = unet_params
|
||||
params["model_ch"] = 16
|
||||
params["ctx_dim"] = 16
|
||||
params["model_ch"] = 8
|
||||
params["ctx_dim"] = 8
|
||||
params["num_res_blocks"] = 1
|
||||
params["n_heads"] = 2
|
||||
model = UNetModel(**params)
|
||||
@@ -144,6 +144,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
|
||||
assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow")
|
||||
def test_bert(self):
|
||||
with Tensor.train():
|
||||
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
|
||||
|
||||
@@ -40,7 +40,6 @@ class TestTrain(unittest.TestCase):
|
||||
check_gc()
|
||||
|
||||
@unittest.skipIf(CI, "slow")
|
||||
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
|
||||
def test_efficientnet(self):
|
||||
model = EfficientNet(0)
|
||||
X = np.zeros((BS,3,224,224), dtype=np.float32)
|
||||
@@ -49,7 +48,6 @@ class TestTrain(unittest.TestCase):
|
||||
check_gc()
|
||||
|
||||
@unittest.skipIf(CI, "slow")
|
||||
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
|
||||
def test_vit(self):
|
||||
model = ViT()
|
||||
X = np.zeros((BS,3,224,224), dtype=np.float32)
|
||||
@@ -57,7 +55,7 @@ class TestTrain(unittest.TestCase):
|
||||
train_one_step(model,X,Y)
|
||||
check_gc()
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
|
||||
@unittest.skipIf(CI, "slow")
|
||||
def test_transformer(self):
|
||||
# this should be small GPT-2, but the param count is wrong
|
||||
# (real ff_dim is 768*4)
|
||||
|
||||
@@ -516,10 +516,6 @@ class TestTinygrad(unittest.TestCase):
|
||||
print(c)
|
||||
|
||||
def test_env_overwrite_default_device(self):
|
||||
subprocess.run(['DISK=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"DISK\\""'],
|
||||
shell=True, check=True)
|
||||
subprocess.run(['NPY=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"NPY\\""'],
|
||||
shell=True, check=True)
|
||||
subprocess.run([f'{Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
subprocess.run([f'DISK=1 {Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
|
||||
Reference in New Issue
Block a user