Files
tinygrad/examples/mlperf/model_train.py
2023-05-30 07:10:40 -07:00

38 lines
590 B
Python

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
def train_resnet() -> None:
    """Train Resnet50-v1.5 (placeholder: not implemented, does nothing)."""
    # TODO: Resnet50-v1.5
def train_retinanet() -> None:
    """Train Retinanet (placeholder: not implemented, does nothing)."""
    # TODO: Retinanet
def train_unet3d() -> None:
    """Train Unet3d (placeholder: not implemented, does nothing)."""
    # TODO: Unet3d
def train_rnnt() -> None:
    """Train RNN-T (placeholder: not implemented, does nothing)."""
    # TODO: RNN-T
def train_bert() -> None:
    """Train BERT (placeholder: not implemented, does nothing)."""
    # TODO: BERT
def train_maskrcnn() -> None:
    """Train Mask RCNN (placeholder: not implemented, does nothing)."""
    # TODO: Mask RCNN
if __name__ == "__main__":
    # Script entry point: run the train_<model> function for each requested
    # model, in the order the names are listed.

    # Set the training flag on Tensor before any trainer runs.
    Tensor.training = True

    # MODEL is read through tinygrad's getenv helper (presumably the MODEL
    # environment variable -- confirm against tinygrad.helpers); its value is a
    # comma-separated list of model names. The default lists every trainer
    # defined in this file.
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
        nm = f"train_{m}"
        # Dispatch by name via module globals; a name with no matching
        # train_<name> global is skipped without any message.
        if nm in globals():
            print(f"training {m}")
            globals()[nm]()