test_train cleanup (#12140)

* test_train cleanup

remove skipIf due to buffer sizes, runs locally

* those are slow
This commit is contained in:
chenyu
2025-09-12 13:21:30 -04:00
committed by GitHub
parent 0fad07c684
commit 647965fb09
3 changed files with 4 additions and 9 deletions

View File

@@ -53,8 +53,8 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
def test_stable_diffusion(self):
params = unet_params
params["model_ch"] = 16
params["ctx_dim"] = 16
params["model_ch"] = 8
params["ctx_dim"] = 8
params["num_res_blocks"] = 1
params["n_heads"] = 2
model = UNetModel(**params)
@@ -144,6 +144,7 @@ class TestRealWorld(unittest.TestCase):
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow")
def test_bert(self):
with Tensor.train():
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,

View File

@@ -40,7 +40,6 @@ class TestTrain(unittest.TestCase):
check_gc()
@unittest.skipIf(CI, "slow")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_efficientnet(self):
model = EfficientNet(0)
X = np.zeros((BS,3,224,224), dtype=np.float32)
@@ -49,7 +48,6 @@ class TestTrain(unittest.TestCase):
check_gc()
@unittest.skipIf(CI, "slow")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_vit(self):
model = ViT()
X = np.zeros((BS,3,224,224), dtype=np.float32)
@@ -57,7 +55,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
@unittest.skipIf(CI, "slow")
def test_transformer(self):
# this should be small GPT-2, but the param count is wrong
# (real ff_dim is 768*4)

View File

@@ -516,10 +516,6 @@ class TestTinygrad(unittest.TestCase):
print(c)
def test_env_overwrite_default_device(self):
subprocess.run(['DISK=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"DISK\\""'],
shell=True, check=True)
subprocess.run(['NPY=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"NPY\\""'],
shell=True, check=True)
subprocess.run([f'{Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
subprocess.run([f'DISK=1 {Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],