mirror of https://github.com/tinygrad/tinygrad.git
cherry pick mlperf5.0 branch to master (#10089)
@@ -357,11 +357,43 @@ def train_retinanet():
  config, target_metric = {}, 0.34
  config["SEED"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
  Tensor.manual_seed(SEED)

  NUM_CLASSES = len(MLPERF_CLASSES)
  BASEDIR = getenv("BASEDIR", BASEDIR)
  BENCHMARK = getenv("BENCHMARK")
  # INITMLPERF = getenv("INITMLPERF")
  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")

  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants

    mllog.config(filename=f"result_retinanet_{SEED}.log")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
    MLLOGGER = mllog.get_mllogger()
    MLLOGGER.logger.propagate = False

    if INITMLPERF:
      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)

      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RETINANET)

      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START)

    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START)
      MLLOGGER.event(key=mllog_constants.SEED, value=SEED)
  else:
    MLLOGGER = None

  config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]

  for x in GPUS: Device[x]
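Everything in this hunk is driven by tinygrad's `getenv` helper: each setting is an environment variable coerced to the type of its default, so a run can be reconfigured from the shell (e.g. `GPUS=8 LOGMLPERF=1 python3 examples/mlperf/model_train.py`). A simplified stand-in for that helper, showing the pattern rather than tinygrad's exact implementation:

```
import os

def getenv(key: str, default=0):
  # read an env var and coerce it to the type of the default
  val = os.getenv(key)
  return type(default)(val) if val is not None else default

SEED = getenv("SEED", 12345)     # SEED=42 on the command line overrides this
LOGMLPERF = getenv("LOGMLPERF")  # unset -> 0, which is falsy, so mllog setup is skipped
```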
@@ -414,24 +446,21 @@ def train_retinanet():
    return out.to(GPUS[0]).realize()

  # ** hyperparameters **
  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
  config["bs"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
  config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
  config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["eval_beam"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
  config["lr"] = lr = getenv("LR", 9.5e-5 * (BS / 96))
  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
  config["default_float"] = dtypes.default_float.name
  config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)
  config["BS"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
  config["EVAL_BS"] = EVAL_BS = getenv("EVAL_BS", BS)
  config["EPOCHS"] = EPOCHS = getenv("EPOCHS", 4)
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
  config["LR"] = lr = getenv("LR", 9.5e-5 * (BS / 96))
  config["LOSS_SCALER"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["EVAL_FREQ"] = eval_freq = getenv("EVAL_FREQ", 1)
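A worked pass through those defaults, assuming 6 GPUs and `DEFAULT_FLOAT=HALF` (illustrative arithmetic only, matching the expressions above):

```
GPUS = 6
BS = 16 * GPUS              # float16 default: 96 samples per global step
LR = 9.5e-5 * (BS / 96)     # linear LR scaling; at the reference BS=96 this is the base 9.5e-5
LOSS_SCALER = 2**11         # 2048, the static loss scale used with float16 gradients
print(BS, LR, LOSS_SCALER)  # 96 9.5e-05 2048
```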
  # ** initialize wandb **
  if (WANDB:=getenv("WANDB")):
    import wandb
    wandb.init(config=config, project="MLPerf-RetinaNet")

  if SEED: Tensor.manual_seed(SEED)

  # ** model initializers **
  resnet.BatchNorm = FrozenBatchNorm2dRetinaNet
  resnet.Linear = Linear
@@ -464,8 +493,24 @@ def train_retinanet():
  optim = Adam(params, lr=lr)

  # ** dataset **
  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASEDIR)), False), BS) // BS
  config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
  config["STEPS_IN_TRAIN_EPOCH"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASEDIR)), False), BS) // BS
  config["STEPS_IN_VAL_EPOCH"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
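`round_up(n, bs) // bs` is ceiling division: the number of batches needed to cover every sample, with the last batch padded. A sketch with a stand-in `round_up` (the helper below matches the semantics implied by this usage; the dataset count is a made-up number for illustration):

```
def round_up(n: int, amt: int) -> int:
  # smallest multiple of amt that is >= n (stand-in matching the usage above)
  return (n + amt - 1) // amt * amt

dataset_count, BS = 117_267, 96            # illustrative count, not the real one
steps = round_up(dataset_count, BS) // BS  # 1222; plain dataset_count // BS would give 1221
```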
  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["BS"])
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["STEPS_IN_TRAIN_EPOCH"])
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["STEPS_IN_VAL_EPOCH"])
      MLLOGGER.event(key=mllog_constants.EPOCH_COUNT, value=config["EPOCHS"])
      MLLOGGER.event(key=mllog_constants.FIRST_EPOCH_NUM, value=start_epoch)

      MLLOGGER.event(key=mllog_constants.OPT_NAME, value=mllog_constants.ADAM)
      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["LR"])
      MLLOGGER.event(key=mllog_constants.OPT_WEIGHT_DECAY, value=0)
      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_EPOCHS, value=0)
      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_FACTOR, value=0)
      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)

  if RUNMLPERF:
    train_dataset = COCO(download_dataset(BASEDIR, "train"))
@@ -476,13 +521,16 @@ def train_retinanet():
  for e in range(start_epoch, EPOCHS):
    # ** training loop **
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e + 1, metadata={"epoch_num": e + 1})

    BEAM.value = TRAIN_BEAM

    if not RUNMLPERF:
      i, proc = 0, _fake_data_get(BS)
    else:
      train_dataloader = batch_load_retinanet(train_dataset, False, base_dir_path, batch_size=BS, seed=SEED)
      it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
      it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e + 1}", disable=BENCHMARK))
      i, proc = 0, _data_get(it)

    prev_cookies = []
@@ -544,8 +592,14 @@ def train_retinanet():
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return

    if MLLOGGER and RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e + 1, metadata={"epoch_num": e + 1})

    # ** eval loop **
    if (e + 1) % eval_freq == 0:
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e + 1, metadata={"epoch_num": e + 1})

      BEAM.value = EVAL_BEAM

      if getenv("RESET_STEP", 1): _train_step.reset()
@@ -593,12 +647,15 @@ def train_retinanet():
        proc, next_proc = next_proc, None
        i += 1

        if i == BENCHMARK:
          return

        et = time.time()
        eval_times.append(et - st)

        if i == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP)
          return

      if getenv("RESET_STEP", 1): _eval_step.reset()
      total_fw_time = sum(eval_times) / len(eval_times)
@@ -616,8 +673,16 @@ def train_retinanet():
      if WANDB:
        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})

      if MLLOGGER:
        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=val_metric, metadata={"epoch_num": e + 1}, clear_line=True)
        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=e + 1, metadata={"epoch_num": e + 1})

      if val_metric >= target_metric:
        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))

        if MLLOGGER:
          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata={"status": mllog_constants.SUCCESS})

        break

def train_unet3d():
@@ -0,0 +1,69 @@
# 1. Problem

This problem uses BERT for NLP.

## Requirements

Install tinygrad and mlperf-logging (uncomment the mlperf extra in setup.py) from the mlperf_training_v5.0 branch.
```
git clone https://github.com/tinygrad/tinygrad.git
python3 -m pip install -e ".[mlperf]"
```
Also install gdown (for the dataset), numpy, tqdm and tensorflow.
```
pip install gdown numpy tqdm tensorflow
```

### tinybox_green
Install the p2p driver per [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md).
This is the default on production tinybox green.

# 2. Directions

## Steps to download and verify data

### 1. Download raw data

```
BASEDIR="/raid/datasets/wiki" WIKI_TRAIN=1 VERIFY_CHECKSUM=1 python3 extra/datasets/wikipedia_download.py
```
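`VERIFY_CHECKSUM=1` makes the download script validate the fetched archives against known digests. A generic sketch of that kind of check, not the script's actual implementation (the file name and digest below are placeholders):

```
import hashlib

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
  # stream the file so large archives do not need to fit in memory
  h = hashlib.sha256()
  with open(path, "rb") as f:
    for block in iter(lambda: f.read(chunk), b""):
      h.update(block)
  return h.hexdigest()

# placeholder name and digest, for illustration only
ok = sha256_of("wiki_dump.xml.bz2") == "0123...placeholder"
print("checksum ok" if ok else "checksum mismatch, re-download")
```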
### 2. Preprocess train and validation data

Note: The number of threads used for preprocessing is limited by available memory. With 128GB of RAM, a maximum of 16 threads is recommended (a worker-count sizing sketch follows the training command below).

#### Training:
```
BASEDIR="/raid/datasets/wiki" NUM_WORKERS=16 python3 extra/datasets/wikipedia.py pre-train all
```
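The 128GB / 16-thread guidance works out to roughly 8 GiB of RAM per worker. A hedged sketch for sizing `NUM_WORKERS` on Linux (the 8 GiB-per-worker figure is an assumption extrapolated from that note, and the 16-worker cap mirrors it):

```
import os

# total physical memory in GiB (Linux; these os.sysconf keys are POSIX)
total_gib = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / 2**30
workers = max(1, min(16, int(total_gib // 8)))  # assumption: ~8 GiB per preprocessing worker
print(f"NUM_WORKERS={workers}")
```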
To generate a specific topic (between 0 and 499):
```
BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-train 42
```

#### Validation:
```
BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
```

## Running

### tinybox_green

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```

### tinybox_red

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```

### tinybox_8xMI300X

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
```
@@ -4,14 +4,14 @@ This problem uses BERT for NLP.
## Requirements

Install tinygrad and mlperf-logging from master.
Install tinygrad and mlperf-logging (uncomment mlperf from setup.py) from branch mlperf_training_v5.0.
```
git clone https://github.com/tinygrad/tinygrad.git
python3 -m pip install -e ".[mlperf]"
```
Also install tqdm and tensorflow.
Also install gdown (for dataset), numpy, tqdm and tensorflow.
```
pip install tqdm tensorflow
pip install gdown numpy tqdm tensorflow
```

### tinybox_green
@@ -52,12 +52,18 @@ BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
#### Steps to run benchmark
```
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```

### tinybox_red

#### Steps to run benchmark
```
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```
### tinybox_8xMI300X

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
```
@@ -4,14 +4,14 @@ This problem uses BERT for NLP.
## Requirements

Install tinygrad and mlperf-logging from master.
Install tinygrad and mlperf-logging (uncomment mlperf from setup.py) from branch mlperf_training_v5.0.
```
git clone https://github.com/tinygrad/tinygrad.git
python3 -m pip install -e ".[mlperf]"
```
Also install tqdm and tensorflow.
Also install gdown (for dataset), numpy, tqdm and tensorflow.
```
pip install tqdm tensorflow
pip install gdown numpy tqdm tensorflow
```

### tinybox_green
@@ -52,12 +52,18 @@ BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
#### Steps to run benchmark
```
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```

### tinybox_red

#### Steps to run benchmark
```
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```
### tinybox_8xMI300X

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
```
@@ -6,7 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0

export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BEAM_LOG_SURPASS_MAX=1
export BASEDIR="/raid/datasets/wiki"
@@ -6,7 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0

export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
@@ -7,7 +7,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0

export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
@@ -0,0 +1,38 @@
# 1. Problem

This problem uses RetinaNet for SSD.

## Requirements

Install tinygrad and mlperf-logging (uncomment the mlperf extra in setup.py) from the mlperf_training_v5.0 branch.
```
git clone https://github.com/tinygrad/tinygrad.git
python3 -m pip install -e ".[mlperf]"
```

Also install the following dependencies:
```
pip install tqdm numpy pycocotools boto3 pandas torch torchvision
```
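Among those dependencies, pycocotools provides the COCOeval harness used to score detections against the 0.34 mAP target. A minimal sketch of how such an evaluation is driven (the annotation and prediction paths are placeholders, not the repo's actual files):

```
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("openimages-mlperf-val.json")   # placeholder ground-truth path
coco_dt = coco_gt.loadRes("predictions.json")  # placeholder detections path

ev = COCOeval(coco_gt, coco_dt, iouType="bbox")
ev.evaluate()
ev.accumulate()
ev.summarize()
mean_ap = ev.stats[0]  # AP @ IoU=0.50:0.95, the value compared to the 0.34 target
```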
### tinybox_green
Install the p2p driver per [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md).
This is the default on production tinybox green.

# 2. Directions

## Steps to download data

Run the following:
```
BASEDIR=/raid/datasets/openimages python3 extra/datasets/openimages.py
```

## Running

### tinybox_green

#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/retinanet/implementations/tinybox_green/run_and_time.sh
```
@@ -0,0 +1,23 @@
#!/bin/bash

export PYTHONPATH="." NV=1
export MODEL="retinanet"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=96 EVAL_BS=96

export TRAIN_BEAM=2 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/openimages"

# pip install -e ".[mlperf]"
export LOGMLPERF=1

export SEED=$RANDOM
DATETIME=$(date "+%m%d%H%M")
LOGFILE="retinanet_green_${DATETIME}_${SEED}.log"

# init
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE

# run
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit TBD",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
  "python": "3.10.16",
  "ROCm": "3.0.0+94441cb"
@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit b5546912e24e0a864b35924da4efa5d71cfe368b",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
  "python": "3.10.12",
  "CUDA": "12.4"
@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit b5546912e24e0a864b35924da4efa5d71cfe368b",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
  "python": "3.10.12"
},