mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Refactor to class style (#4804)
This commit is contained in:
@@ -12,7 +12,7 @@ from tinygrad.device import Device
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
|
||||
from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
|
||||
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
|
||||
from examples.mlperf.dataloader import batch_load_val_bert
|
||||
from examples.mlperf.model_train import eval_step_bert
|
||||
|
||||
@@ -27,14 +27,13 @@ if __name__ == "__main__":
|
||||
assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
|
||||
f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"
|
||||
|
||||
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index"]
|
||||
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
|
||||
assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
|
||||
f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"
|
||||
|
||||
Tensor.training = False
|
||||
|
||||
model = get_mlperf_bert_model()
|
||||
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
|
||||
model = get_mlperf_bert_model(INIT_CKPT_DIR)
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
x.realize().to_(GPUS)
|
||||
|
||||
Reference in New Issue
Block a user