mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
redo faster sparse_categorical_crossentropy (#4461)
Also update the default LR and DECAY for resnet, which help convergence too.
This commit is contained in:
@@ -69,9 +69,9 @@ def train_resnet():
|
||||
epochs = config["epochs"] = getenv("EPOCHS", 37)
|
||||
BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
|
||||
base_lr = config["base_lr"] = getenv("LR", 7 * (BS/1536))
|
||||
base_lr = config["base_lr"] = getenv("LR", 7.2 * (BS/1536))
|
||||
lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
|
||||
decay = config["decay"] = getenv("DECAY", 5e-5)
|
||||
decay = config["decay"] = getenv("DECAY", 2e-4)
|
||||
|
||||
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export SUBMISSION_PLATFORM="tinybox_green"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export SUBMISSION_PLATFORM="tinybox_red"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48
|
||||
|
||||
export LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
|
||||
@@ -1327,9 +1327,9 @@ class Tensor:
|
||||
# NOTE: self is a logits input
|
||||
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
|
||||
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
|
||||
return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
|
||||
return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
|
||||
|
||||
# ***** cast ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user