lint mlperf model_train (#10038)
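Summary: CI gains a ruff pass over examples/mlperf/model_train.py (ignoring only line length), and the findings it surfaced are fixed below: a misspelled BASE_DIR environment variable, a redundant retinanet import, an unnecessary list comprehension inside sum(), two f-strings with no placeholders, and stray whitespace in a dict literal.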
.github/workflows/test.yml (vendored)
@@ -338,6 +338,7 @@ jobs:
       run: |
         pip3 install --upgrade --force-reinstall ruff==0.11.0
         python3 -m ruff check .
+        python3 -m ruff check examples/mlperf/model_train.py --ignore E501
     - name: Lint tinygrad with pylint
       run: python -m pylint tinygrad/
     - name: Run mypy
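Note: the added step lints only model_train.py and passes --ignore E501 (pycodestyle's line-too-long rule), so that file's long lines don't fail CI while every other ruff check still applies; ruff is pinned to 0.11.0 for reproducible results. A minimal sketch for reproducing the step locally, assuming a tinygrad checkout with ruff installed:

    import subprocess

    # Mirror the two CI lint commands from .github/workflows/test.yml.
    subprocess.run(["python3", "-m", "ruff", "check", "."], check=True)
    subprocess.run(["python3", "-m", "ruff", "check",
                    "examples/mlperf/model_train.py", "--ignore", "E501"], check=True)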
@@ -532,7 +532,7 @@ if __name__ == "__main__":
   def load_retinanet(val):
     from extra.datasets.openimages import BASEDIR, download_dataset
     from pycocotools.coco import COCO
-    dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train"))
+    dataset = COCO(download_dataset(base_dir:=getenv("BASEDIR", BASEDIR), "validation" if val else "train"))
     with tqdm(total=len(dataset.imgs.keys())) as pbar:
       for x in batch_load_retinanet(dataset, val, base_dir):
         pbar.update(x[0].shape[0])
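Note: this hunk sits in the dataloader benchmark under if __name__ == "__main__" (most likely examples/mlperf/dataloader.py). The fix is an environment-variable rename: the canonical name elsewhere in the tree is BASEDIR (the same rename lands in extra/datasets/openimages.py below), so getenv("BASE_DIR", BASEDIR) could never see a user's override and always fell back to the default. A self-contained sketch of the failure mode, with hypothetical paths and a stand-in for tinygrad.helpers.getenv:

    import os

    BASEDIR = "/data/openimages"                                   # hypothetical default
    def getenv(key, default): return os.environ.get(key, default)  # stand-in helper

    os.environ["BASEDIR"] = "/mnt/fast/openimages"  # what a user actually sets
    print(getenv("BASE_DIR", BASEDIR))              # /data/openimages: override silently lost
    print(base_dir := getenv("BASEDIR", BASEDIR))   # /mnt/fast/openimages: fixed lookup

The walrus binding matters too: base_dir is reused by the batch_load_retinanet call that follows, so both reads now agree.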
@@ -347,12 +347,11 @@ def train_retinanet():
   from examples.mlperf.dataloader import batch_load_retinanet
   from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
   from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize, get_dataset_count
-  from extra.models import resnet
+  from extra.models import resnet, retinanet
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
   from tinygrad.helpers import colored
   from typing import Iterator
-  import extra.models.retinanet as retinanet

   import numpy as np

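Note: redundant import cleanup. `import extra.models.retinanet as retinanet` binds the same module object that `from extra.models import retinanet` yields, so the alias line is dropped and retinanet joins the existing resnet import. A stdlib analogue of the equivalence:

    from os import path as p1   # "from package import submodule"
    import os.path as p2        # "import package.submodule as name"
    assert p1 is p2             # both name the same module object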
@@ -399,7 +398,7 @@ def train_retinanet():
       optim.zero_grad()

       losses = model(normalize(x, GPUS), **kwargs)
-      loss = sum([l for l in losses.values()])
+      loss = sum(losses.values())

       (loss * loss_scaler).backward()
       for t in optim.params: t.grad = t.grad / loss_scaler
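Note: sum() consumes any iterable, so wrapping losses.values() in a list comprehension only allocates a throwaway list; dropping it is the kind of cleanup ruff's flake8-comprehensions (C4xx) rules target, though the commit doesn't show which rule fired. Quick equivalence check with hypothetical integer losses:

    losses = {"classification": 2, "regression": 1}  # hypothetical values
    assert sum([l for l in losses.values()]) == sum(losses.values()) == 3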
@@ -826,7 +825,7 @@ def train_unet3d():

       if mean_dice >= TARGET_METRIC:
         is_successful = True
-        save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
+        save_checkpoint(get_state_dict(model), "./ckpts/unet3d.safe")
       elif mean_dice < 1e-6:
         print("Model diverging. Aborting.")
         diverged = True
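Note: this and the train_bert change below are the same fix, pyflakes' F541 (f-string without any placeholders). The f prefix interpolates nothing in "./ckpts/unet3d.safe", so a plain string literal is equivalent.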
@@ -913,7 +912,7 @@ def train_bert():
     MLLOGGER.logger.propagate = False

     if INITMLPERF:
-      assert BENCHMARK, f"BENCHMARK must be set for INITMLPERF"
+      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
       MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
       MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
       MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
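Note: same F541 cleanup; the assert message contains no {} substitutions, so the f prefix is dropped.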
@@ -1187,7 +1186,7 @@ def train_bert():
         if MLLOGGER and RUNMLPERF:
           if previous_step:
             MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
-          MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
+          MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num": i})
         if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
         if WANDB and wandb.run is not None:
           fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
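Note: pure formatting; the space before the colon in {"step_num" : i} goes away, consistent with pycodestyle's E203 (whitespace before ':').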
@@ -205,5 +205,5 @@ def get_dataset_count(base_dir:Path, val:bool) -> int:
   return len(files)

 if __name__ == "__main__":
-  download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
+  download_dataset(base_dir:=getenv("BASEDIR", BASEDIR), "train")
   download_dataset(base_dir, "validation")
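Note: the same BASE_DIR to BASEDIR rename as above, here in the download entry point next to get_dataset_count (i.e. extra/datasets/openimages.py), so the walrus-bound base_dir honors the override for both the train and validation downloads.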