using image from mad branch saves 1ms on op model

This commit is contained in:
George Hotz
2023-03-05 14:38:42 -08:00
parent 7940ad258e
commit 7989f79820

View File

@@ -27,8 +27,10 @@ def image_conv2d_decorator(normal_conv):
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
cin_last = iy == 1 and ix == 1
if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
# contiguous creates the image, and early realize static weights (TODO: don't always realize)
x, w = x.contiguous(), w.contiguous().realize()
@@ -37,7 +39,8 @@ def image_conv2d_decorator(normal_conv):
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])