import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion

class TestTrain(unittest.TestCase):
  def test_train_to_ckpt(self):
    # train for num_steps, save checkpoint, and stop training
    num_steps = 42
    os.environ.update({"MODEL": "stable_diffusion", "TOTAL_CKPTS": "1", "CKPT_STEP_INTERVAL": str(num_steps), "GPUS": "8", "BS": "304"})
    # NOTE: update these based on where data/checkpoints are on your system
    if not getenv("DATADIR", ""): os.environ["DATADIR"] = "/raid/datasets/stable_diffusion"
    if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
    with TemporaryDirectory(prefix="test-train") as tmp:
      os.environ["UNET_CKPTDIR"] = tmp
      with Tensor.train():
        saved_ckpts = train_stable_diffusion()
      expected_ckpt = f"{tmp}/{num_steps}.safetensors"
      assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt

if __name__=="__main__":
  unittest.main()
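
# A minimal sketch of invoking this test directly, assuming the tinygrad repo root is the
# working directory (so examples.mlperf.model_train resolves) and the MLPerf stable diffusion
# dataset and base weights are available at DATADIR/CKPTDIR. The module path below is a
# hypothetical placeholder; substitute wherever this file is checked in:
#   DATADIR=/path/to/datasets CKPTDIR=/path/to/weights python -m unittest path.to.this_test -v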