Files
tinygrad/examples/mlperf/model_train.py
Yixiang Gao 094d3d71be with Tensor.train() (#1935)
* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
2023-09-28 18:02:31 -07:00

37 lines
597 B
Python

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
def train_resnet():
  """Placeholder for ResNet50-v1.5 training; not implemented yet, returns None."""
  # TODO: Resnet50-v1.5
  return None
def train_retinanet():
  """Placeholder for RetinaNet training; not implemented yet, returns None."""
  # TODO: Retinanet
  return None
def train_unet3d():
  """Placeholder for UNet3D training; not implemented yet, returns None."""
  # TODO: Unet3d
  return None
def train_rnnt():
  """Placeholder for RNN-T training; not implemented yet, returns None."""
  # TODO: RNN-T
  return None
def train_bert():
  """Placeholder for BERT training; not implemented yet, returns None."""
  # TODO: BERT
  return None
def train_maskrcnn():
  """Placeholder for Mask R-CNN training; not implemented yet, returns None."""
  # TODO: Mask RCNN
  return None
if __name__ == "__main__":
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()