mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Revert the BERT loss scaler default back to 2**10 (#6934)
Getting 2 NaNs with the 2**13 loss scaler, so revert back to 2**10.
This commit is contained in:
@@ -658,7 +658,7 @@ def train_bert():
|
||||
save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
|
||||
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
|
||||
|
||||
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**13 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
decay = config["DECAY"] = getenv("DECAY", 0.01)
|
||||
epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
|
||||
poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
|
||||
|
||||
Reference in New Issue
Block a user