explicit dtypes in hlb_cifar (#3707)

prepared for the bfloat16 change. added float() and cast(default_float) in whitening, and explicitly set dtype in the places that convert between numpy and Tensor
chenyu
2024-03-12 18:20:23 -04:00
committed by GitHub
parent b6e2495fdd
commit b13457e4a7


@@ -23,10 +23,8 @@ for x in GPUS: Device[x]
 if getenv("HALF"):
   dtypes.default_float = dtypes.float16
-  np_dtype = np.float16
 else:
   dtypes.default_float = dtypes.float32
-  np_dtype = np.float32

 class UnsyncedBatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
@@ -158,7 +156,6 @@ def train_cifar():
   random.seed(seed)

   # ========== Model ==========
-  # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
   def whitening(X, kernel_size=hyp['net']['kernel_size']):
     def _cov(X):
       return (X.T @ X) / (X.shape[0] - 1)
@@ -175,10 +172,11 @@ def train_cifar():
       Λ, V = np.linalg.eigh(Σ, UPLO='U')
       return np.flip(Λ, 0), np.flip(V.T.reshape(c*h*w, c, h, w), 0)

-    Λ, V = _eigens(_patches(X.numpy()))
+    # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
+    Λ, V = _eigens(_patches(X.float().numpy()))
     W = V/np.sqrt(Λ+1e-2)[:,None,None,None]

-    return Tensor(W.astype(np_dtype), requires_grad=False)
+    return Tensor(W.astype(np.float32), requires_grad=False).cast(dtypes.default_float)

   # ========== Loss ==========
   def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
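As the commit message notes, the whitening change follows one pattern: do the numpy-only work (np.linalg.eigh) in float32, then cast the finished weights back to dtypes.default_float so they match the rest of the model under HALF and, later, bfloat16. A minimal sketch of that pattern, assuming tinygrad's public Tensor/dtypes API; get_whitening_weights is a hypothetical helper, not code from hlb_cifar10.py:

import numpy as np
from tinygrad import Tensor, dtypes

def get_whitening_weights(patches: Tensor) -> Tensor:
  # hypothetical helper sketching the whitening() pattern from the hunk above
  X = patches.float().numpy()                       # force float32 before handing the data to numpy
  cov = (X.T @ X) / (X.shape[0] - 1)                # covariance of the flattened patches
  eigvals, eigvecs = np.linalg.eigh(cov, UPLO='U')  # eigh works on float32/float64, not half precision
  W = eigvecs / np.sqrt(eigvals + 1e-2)             # simplified; the real script reshapes V into conv weights
  # build the Tensor in float32, then cast to whatever dtypes.default_float is configured to be
  return Tensor(W.astype(np.float32), requires_grad=False).cast(dtypes.default_float)

The trailing .cast(dtypes.default_float) also takes over the job of the np_dtype variable removed in the first hunk, which is why the HALF branch no longer needs a separate numpy dtype.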
@@ -218,8 +216,8 @@ def train_cifar():
     mask = make_square_mask(X.shape, mask_size)
     order = list(range(0, X.shape[0]))
     random.shuffle(order)
-    X_patch = Tensor(X.numpy()[order], device=X.device)
-    Y_patch = Tensor(Y.numpy()[order], device=Y.device)
+    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
+    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
     X_cutmix = mask.where(X_patch, X)
     mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
     Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
@@ -250,8 +248,8 @@ def train_cifar():
       for i in range(0, X.shape[0], BS):
         # pad the last batch # TODO: not correct for test
         batch_end = min(i+BS, Y.shape[0])
-        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device)
-        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device)
+        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
+        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
         step += 1
         yield x, y
       epoch += 1
@@ -259,7 +257,8 @@ def train_cifar():
   transform = [
     lambda x: x / 255.0,
-    lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1))
+    lambda x: x.reshape((-1,3,32,32)) - Tensor(cifar_mean, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
+    lambda x: x / Tensor(cifar_std, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
   ]

   class modelEMA():
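The remaining hunks apply the same rule wherever data crosses the numpy/Tensor boundary (the cutmix shuffle, the batch loader) or where small constants like cifar_mean/cifar_std are created: pass device and dtype explicitly rather than relying on what the numpy array or the defaults would produce, since the commit prepares for bfloat16, which numpy cannot represent. A minimal sketch of that pattern, with made-up shapes and approximate CIFAR-10 channel means (not the script's actual values):

import numpy as np
from tinygrad import Tensor, dtypes

X = Tensor(np.random.rand(4, 3, 32, 32).astype(np.float32)).cast(dtypes.default_float)

# numpy round-trip (e.g. shuffling with fancy indexing): pin the new Tensor to the source's device and dtype
order = [2, 0, 3, 1]
X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)

# small constant tensors: create them directly with the dtype/device of the data they normalize
mean = Tensor([0.4914, 0.4822, 0.4465], device=X.device, dtype=X.dtype).reshape((1, 3, 1, 1))
X_norm = X - mean  # subtraction happens in a single dtype, no implicit promotion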