Refactor to class style (#4804)

This commit is contained in:
Elias Wahl
2024-06-04 23:08:31 +02:00
committed by GitHub
parent 1b8bed4a26
commit 04e237328b
5 changed files with 144 additions and 157 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad.device import Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_state_dict
from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
from examples.mlperf.dataloader import batch_load_val_bert
from examples.mlperf.model_train import eval_step_bert
@@ -27,14 +27,13 @@ if __name__ == "__main__":
assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index"]
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"
Tensor.training = False
model = get_mlperf_bert_model()
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
model = get_mlperf_bert_model(INIT_CKPT_DIR)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)