lint mlperf model_train (#10038)

chenyu
2025-04-24 16:19:44 -04:00
committed by GitHub
parent 51ca19d061
commit 74c6cf8be3
4 changed files with 8 additions and 8 deletions

.github/workflows/test.yml

@@ -338,6 +338,7 @@ jobs:
       run: |
         pip3 install --upgrade --force-reinstall ruff==0.11.0
         python3 -m ruff check .
+        python3 -m ruff check examples/mlperf/model_train.py --ignore E501
     - name: Lint tinygrad with pylint
       run: python -m pylint tinygrad/
     - name: Run mypy

examples/mlperf/dataloader.py

@@ -532,7 +532,7 @@ if __name__ == "__main__":
   def load_retinanet(val):
     from extra.datasets.openimages import BASEDIR, download_dataset
     from pycocotools.coco import COCO
-    dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train"))
+    dataset = COCO(download_dataset(base_dir:=getenv("BASEDIR", BASEDIR), "validation" if val else "train"))
     with tqdm(total=len(dataset.imgs.keys())) as pbar:
       for x in batch_load_retinanet(dataset, val, base_dir):
         pbar.update(x[0].shape[0])
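Note: this one is a behavioral fix riding along with the lint pass. The string passed to getenv is the name of the environment variable to read, and both call sites (here and in the openimages.py hunk at the end) now read BASEDIR, matching the BASEDIR constant they fall back to. A minimal sketch of the pattern, with a hypothetical default path standing in for the real BASEDIR constant:

    # Sketch of the getenv fallback pattern, assuming tinygrad's helpers.
    from tinygrad.helpers import getenv

    BASEDIR = "/raid/datasets/openimages"  # hypothetical stand-in for the real default
    # getenv reads the BASEDIR environment variable if set, else returns the default;
    # the resolved path is what later gets passed to batch_load_retinanet.
    base_dir = getenv("BASEDIR", BASEDIR)
    print(base_dir)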

examples/mlperf/model_train.py

@@ -347,12 +347,11 @@ def train_retinanet():
   from examples.mlperf.dataloader import batch_load_retinanet
   from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
   from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize, get_dataset_count
-  from extra.models import resnet
+  from extra.models import resnet, retinanet
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
   from tinygrad.helpers import colored
   from typing import Iterator
-  import extra.models.retinanet as retinanet
   import numpy as np
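Note: the two import forms are interchangeable here; `from extra.models import resnet, retinanet` binds the same module object as the removed `import extra.models.retinanet as retinanet`, just in one line and in sorted position. A stdlib stand-in demonstrating the equivalence, since the behavior is generic:

    # Both forms bind the same module object.
    import os.path as path_a
    from os import path as path_b
    assert path_a is path_b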
@@ -399,7 +398,7 @@ def train_retinanet():
       optim.zero_grad()
       losses = model(normalize(x, GPUS), **kwargs)
-      loss = sum([l for l in losses.values()])
+      loss = sum(losses.values())
       (loss * loss_scaler).backward()
       for t in optim.params: t.grad = t.grad / loss_scaler
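Note: sum() accepts any iterable, so materializing the dict values into a list first is redundant work (the kind of pattern ruff's flake8-comprehensions rules flag when enabled). A quick check with a hypothetical loss dict:

    # The comprehension only copies the values into a list before summing;
    # sum() can consume the dict view directly.
    losses = {"classification": 0.5, "regression": 0.25}  # hypothetical values
    assert sum(losses.values()) == sum([l for l in losses.values()]) == 0.75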
@@ -826,7 +825,7 @@ def train_unet3d():
       if mean_dice >= TARGET_METRIC:
         is_successful = True
-        save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
+        save_checkpoint(get_state_dict(model), "./ckpts/unet3d.safe")
       elif mean_dice < 1e-6:
         print("Model diverging. Aborting.")
         diverged = True
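Note: this hunk and the next both fix ruff F541 (f-string without any placeholders): an f-prefix on a string containing no {} expressions does nothing, so the plain literal is used instead.

    # F541: without a placeholder, the f-prefix is dead weight.
    assert f"./ckpts/unet3d.safe" == "./ckpts/unet3d.safe"  # noqa: F541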
@@ -913,7 +912,7 @@ def train_bert():
   MLLOGGER.logger.propagate = False
   if INITMLPERF:
-    assert BENCHMARK, f"BENCHMARK must be set for INITMLPERF"
+    assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
     MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
     MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
     MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
@@ -1187,7 +1186,7 @@ def train_bert():
       if MLLOGGER and RUNMLPERF:
         if previous_step:
           MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
-        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
+        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num": i})
       if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       if WANDB and wandb.run is not None:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"

extra/datasets/openimages.py

@@ -205,5 +205,5 @@ def get_dataset_count(base_dir:Path, val:bool) -> int:
   return len(files)
 if __name__ == "__main__":
-  download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
+  download_dataset(base_dir:=getenv("BASEDIR", BASEDIR), "train")
   download_dataset(base_dir, "validation")