Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
move hyp out of the train so it can be imported
@@ -82,37 +82,37 @@ class SpeedyResNet:
     forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
     return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5

-def train_cifar():
-
-  # hyper-parameters were exactly the same as the original repo
-  bias_scaler = 58
-  hyp: Dict[str, Any] = {
-    'seed' : 209,
-    'opt': {
-      'bias_lr': 1.76 * bias_scaler/512,
-      'non_bias_lr': 1.76 / 512,
-      'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
-      'non_bias_decay': 1.08 * 6.45e-4 * BS,
-      'final_lr_ratio': 0.025,
-      'initial_div_factor': 1e6,
-      'label_smoothing': 0.20,
-      'momentum': 0.85,
-      'percent_start': 0.23,
-      'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
-    },
-    'net': {
-      'kernel_size': 2, # kernel size for the whitening layer
-      'cutmix_size': 3,
-      'cutmix_steps': 499,
-      'pad_amount': 2
-    },
-    'ema': {
-      'steps': 399,
-      'decay_base': .95,
-      'decay_pow': 1.6,
-      'every_n_steps': 5,
-    }
+# hyper-parameters were exactly the same as the original repo
+bias_scaler = 58
+hyp: Dict[str, Any] = {
+  'seed' : 209,
+  'opt': {
+    'bias_lr': 1.76 * bias_scaler/512,
+    'non_bias_lr': 1.76 / 512,
+    'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
+    'non_bias_decay': 1.08 * 6.45e-4 * BS,
+    'final_lr_ratio': 0.025,
+    'initial_div_factor': 1e6,
+    'label_smoothing': 0.20,
+    'momentum': 0.85,
+    'percent_start': 0.23,
+    'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
+  },
+  'net': {
+    'kernel_size': 2, # kernel size for the whitening layer
+    'cutmix_size': 3,
+    'cutmix_steps': 499,
+    'pad_amount': 2
+  },
+  'ema': {
+    'steps': 399,
+    'decay_base': .95,
+    'decay_pow': 1.6,
+    'every_n_steps': 5,
+  }
   }
+}
+
+def train_cifar():

   def set_seed(seed):
     Tensor.manual_seed(getenv('SEED', seed))
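
Note that the 'opt' entries are written as formulas against the batch size and bias_scaler rather than as literal numbers: biases get a learning rate bias_scaler times larger, and a weight decay bias_scaler times smaller, than the other parameters. A minimal sketch of the resolved values, assuming BS = 512 (BS is defined elsewhere in the script and is an assumption here):

# Sketch only: resolves the 'opt' formulas from hyp above; BS = 512 is an assumption.
BS = 512
bias_scaler = 58

bias_lr        = 1.76 * bias_scaler / 512           # 0.199375
non_bias_lr    = 1.76 / 512                         # 0.0034375
bias_decay     = 1.08 * 6.45e-4 * BS / bias_scaler  # ~0.006149 at BS=512
non_bias_decay = 1.08 * 6.45e-4 * BS                # ~0.356659 at BS=512

print(f"bias_lr={bias_lr:.6f}  non_bias_lr={non_bias_lr:.7f}")
print(f"bias_decay={bias_decay:.6f}  non_bias_decay={non_bias_decay:.6f}")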
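
With hyp at module scope, other code can import and tweak it before launching training, which is the stated point of this commit. A minimal sketch, assuming the hunk applies to examples/hlb_cifar10.py and is run from the repository root (both assumptions, not shown in this hunk); whether train_cifar picks up an override depends on how it reads hyp:

# Sketch only: the module path examples.hlb_cifar10 is an assumption.
from examples.hlb_cifar10 import hyp, train_cifar

hyp['seed'] = 1337                   # tweak the now-importable hyper-parameters
hyp['opt']['label_smoothing'] = 0.1

if __name__ == "__main__":
  train_cifar()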