cleanup pad_reflect and make_square_mask in hlb_cifar (#3206)

removed some complicated looking stuff. no wall time difference
This commit is contained in:
chenyu
2024-01-22 11:30:46 -05:00
committed by GitHub
parent 99884f4c98
commit 827b7a3c64

View File

@@ -146,37 +146,20 @@ def train_cifar():
raise NotImplementedError(reduction)
# ========== Preprocessing ==========
# TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
# implemented in recursive fashion but figuring out how to switch indexing dim
# during the loop was a bit tricky
# NOTE: this only works for RGB in format of NxCxHxW and pads the HxW
def pad_reflect(X, size=2) -> Tensor:
padding = ((0,0),(0,0),(size,size),(size,size))
p = padding[3]
s = X.shape[3]
X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
X = X.pad(((0,0),(0,0),(0,0),p)) + X_lr
p = padding[2]
s = X.shape[2]
X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
X = X.pad(((0,0),(0,0),p,(0,0))) + X_lr
X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1)
X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2)
return X
# return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
def make_square_mask(shape, mask_size) -> Tensor:
is_even = int(mask_size % 2 == 0)
center_max = shape[-2]-mask_size//2-is_even
center_min = mask_size//2-is_even
center_x = Tensor.randint(shape[0], low=center_min, high=center_max)
center_y = Tensor.randint(shape[0], low=center_min, high=center_max)
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
mask = d_y * d_x
return mask.cast(dtypes.bool)
BS, _, H, W = shape
low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
idx_x = Tensor.arange(W).reshape((1,1,1,W))
idx_y = Tensor.arange(H).reshape((1,1,H,1))
return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
def random_crop(X:Tensor, crop_size=32):
mask = make_square_mask(X.shape, crop_size)