From 827b7a3c6450a6ea3dfd449852b038344ec310a5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 22 Jan 2024 11:30:46 -0500 Subject: [PATCH] cleanup pad_reflect and make_square_mask in hlb_cifar (#3206) removed some complicated looking stuff. no wall time difference --- examples/hlb_cifar10.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 847557e1ce..dcb2ef1e42 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -146,37 +146,20 @@ def train_cifar(): raise NotImplementedError(reduction) # ========== Preprocessing ========== - # TODO currently this only works for RGB in format of NxCxHxW and pads the HxW - # implemented in recursive fashion but figuring out how to switch indexing dim - # during the loop was a bit tricky + # NOTE: this only works for RGB in format of NxCxHxW and pads the HxW def pad_reflect(X, size=2) -> Tensor: - padding = ((0,0),(0,0),(size,size),(size,size)) - p = padding[3] - s = X.shape[3] - - X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0))) - X = X.pad(((0,0),(0,0),(0,0),p)) + X_lr - - p = padding[2] - s = X.shape[2] - X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0))) - X = X.pad(((0,0),(0,0),p,(0,0))) + X_lr - + X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1) + X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2) return X # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask def make_square_mask(shape, mask_size) -> Tensor: - is_even = int(mask_size % 2 == 0) - center_max = shape[-2]-mask_size//2-is_even - center_min = mask_size//2-is_even - center_x = Tensor.randint(shape[0], low=center_min, high=center_max) - center_y = Tensor.randint(shape[0], low=center_min, high=center_max) - d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1)) - d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1)) - d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2) - d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2) - mask = d_y * d_x - return mask.cast(dtypes.bool) + BS, _, H, W = shape + low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1) + low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1) + idx_x = Tensor.arange(W).reshape((1,1,1,W)) + idx_y = Tensor.arange(H).reshape((1,1,H,1)) + return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size)) def random_crop(X:Tensor, crop_size=32): mask = make_square_mask(X.shape, crop_size)