diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py new file mode 100644 index 0000000000..8db72fc15c --- /dev/null +++ b/examples/mlperf/model_train.py @@ -0,0 +1,37 @@ +from tinygrad.tensor import Tensor +from tinygrad.helpers import getenv + +def train_resnet(): + # TODO: Resnet50-v1.5 + pass + +def train_retinanet(): + # TODO: Retinanet + pass + +def train_unet3d(): + # TODO: Unet3d + pass + +def train_rnnt(): + # TODO: RNN-T + pass + +def train_bert(): + # TODO: BERT + pass + +def train_maskrcnn(): + # TODO: Mask RCNN + pass + +if __name__ == "__main__": + Tensor.training = True + + for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","): + nm = f"train_{m}" + if nm in globals(): + print(f"training {m}") + globals()[nm]() + +