tinygrad/test/external/mlperf_stable_diffusion/external_test_train.py
hooved 1e8945a28c Training loop for Stable Diffusion mlperf (#12315)
* add diff
* fix edit error
* match master
* point reference to specific commit
* simplify wandb logging
* remove lr test, dehardcode device
* increase stack size limit
2025-10-03 02:45:38 -04:00


import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion

class TestTrain(unittest.TestCase):
  def test_train_to_ckpt(self):
    # train for num_steps, save checkpoint, and stop training
    num_steps = 42
    os.environ.update({"MODEL": "stable_diffusion", "TOTAL_CKPTS": "1", "CKPT_STEP_INTERVAL": str(num_steps), "GPUS": "8", "BS": "304"})
    # NOTE: update these based on where data/checkpoints are on your system
    if not getenv("DATADIR", ""): os.environ["DATADIR"] = "/raid/datasets/stable_diffusion"
    if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
    with TemporaryDirectory(prefix="test-train") as tmp:
      os.environ["UNET_CKPTDIR"] = tmp
      with Tensor.train():
        saved_ckpts = train_stable_diffusion()
      expected_ckpt = f"{tmp}/{num_steps}.safetensors"
      assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt

if __name__ == "__main__":
  unittest.main()
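
Since the test reads DATADIR and CKPTDIR via getenv before falling back to the hardcoded /raid paths, the locations can be overridden from a driver script instead of editing the file. Below is a minimal sketch of that, not part of the file above: the dataset/weights paths and the import path (which assumes the test is importable as test.external.mlperf_stable_diffusion.external_test_train from the repository root) are assumptions; adjust them to your checkout and storage layout.

import os, unittest

# Assumed locations for the mlperf Stable Diffusion dataset and pretrained weights;
# point these at wherever they live on your system.
os.environ["DATADIR"] = "/data/datasets/stable_diffusion"
os.environ["CKPTDIR"] = "/data/weights/stable_diffusion"

# Import path is an assumption based on running from the tinygrad repo root.
from test.external.mlperf_stable_diffusion.external_test_train import TestTrain

# Load and run just this test case with the overridden environment.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestTrain)
unittest.TextTestRunner(verbosity=2).run(suite)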