Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
cleanup pad_reflect and make_square_mask in hlb_cifar (#3206)
removed some complicated-looking code. no wall time difference
@@ -146,37 +146,20 @@ def train_cifar():
   raise NotImplementedError(reduction)
 
 # ========== Preprocessing ==========
-# TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
-# implemented in recursive fashion but figuring out how to switch indexing dim
-# during the loop was a bit tricky
+# NOTE: this only works for RGB in format of NxCxHxW and pads the HxW
 def pad_reflect(X, size=2) -> Tensor:
-  padding = ((0,0),(0,0),(size,size),(size,size))
-  p = padding[3]
-  s = X.shape[3]
-
-  X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
-  X = X.pad(((0,0),(0,0),(0,0),p)) + X_lr
-
-  p = padding[2]
-  s = X.shape[2]
-  X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
-  X = X.pad(((0,0),(0,0),p,(0,0))) + X_lr
-
+  X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1)
+  X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2)
   return X
 
 # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
 def make_square_mask(shape, mask_size) -> Tensor:
-  is_even = int(mask_size % 2 == 0)
-  center_max = shape[-2]-mask_size//2-is_even
-  center_min = mask_size//2-is_even
-  center_x = Tensor.randint(shape[0], low=center_min, high=center_max)
-  center_y = Tensor.randint(shape[0], low=center_min, high=center_max)
-  d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
-  d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
-  d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
-  d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
-  mask = d_y * d_x
-  return mask.cast(dtypes.bool)
+  BS, _, H, W = shape
+  low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
+  low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
+  idx_x = Tensor.arange(W).reshape((1,1,1,W))
+  idx_y = Tensor.arange(H).reshape((1,1,H,1))
+  return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
 
 def random_crop(X:Tensor, crop_size=32):
   mask = make_square_mask(X.shape, crop_size)
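Note on the new pad_reflect: slicing the `size` pixels adjacent to each edge (the edge pixel itself excluded), flipping them, and concatenating gives one cat per axis instead of the old slice/pad/add construction. A minimal sanity check against NumPy's reflect padding, assuming tinygrad and numpy are installed (an illustrative script, not part of the commit):

import numpy as np
from tinygrad.tensor import Tensor

def pad_reflect(X, size=2) -> Tensor:
  # mirror `size` pixels next to each edge (edge pixel excluded): width, then height
  X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1)
  X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2)
  return X

x = np.arange(2*3*8*8, dtype=np.float32).reshape(2,3,8,8)
out = pad_reflect(Tensor(x), size=2).numpy()
ref = np.pad(x, ((0,0),(0,0),(2,2),(2,2)), mode="reflect")
assert out.shape == (2,3,12,12) and np.allclose(out, ref)

Excluding the edge pixel is what makes this np.pad's "reflect" mode rather than "symmetric".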
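Note on the new make_square_mask: instead of the old center-plus-parity arithmetic, it draws a random top-left corner per batch element and builds the mask from four broadcasted index comparisons. The same trick in plain NumPy, with square_mask_np as a hypothetical name for illustration:

import numpy as np

def square_mask_np(shape, mask_size, seed=0):
  # random top-left corner per batch element; mask via broadcasted comparisons
  BS, _, H, W = shape
  rng = np.random.default_rng(seed)
  low_x = rng.integers(0, W - mask_size, size=(BS,1,1,1))
  low_y = rng.integers(0, H - mask_size, size=(BS,1,1,1))
  idx_x = np.arange(W).reshape(1,1,1,W)
  idx_y = np.arange(H).reshape(1,1,H,1)
  return (idx_x >= low_x) & (idx_x < low_x + mask_size) & (idx_y >= low_y) & (idx_y < low_y + mask_size)

m = square_mask_np((4,3,32,32), mask_size=8)
# every mask in the batch covers exactly mask_size x mask_size pixels
assert m.shape == (4,1,32,32) and (m.sum(axis=(2,3)) == 8*8).all()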