back to 2**10 for bert loss scaler (#6934)

getting 2 NaN for this, revert back to 2**10
This commit is contained in:
chenyu
2024-10-07 10:17:21 -04:00
committed by GitHub
parent 9250452da4
commit 102dfe5510

View File

@@ -658,7 +658,7 @@ def train_bert():
   save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
-  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**13 if dtypes.default_float == dtypes.float16 else 1.0)
+  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
   decay = config["DECAY"] = getenv("DECAY", 0.01)
   epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
   poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)