mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
using image from mad branch saves 1ms on op model
This commit is contained in:
@@ -27,8 +27,10 @@ def image_conv2d_decorator(normal_conv):
|
||||
|
||||
# packed (note: flipping bs and iy would make the auto-padding work)
|
||||
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
|
||||
cin_last = iy == 1 and ix == 1
|
||||
if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
|
||||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
|
||||
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
|
||||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: don't always realize)
|
||||
x, w = x.contiguous(), w.contiguous().realize()
|
||||
@@ -37,7 +39,8 @@ def image_conv2d_decorator(normal_conv):
|
||||
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
||||
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
||||
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
||||
w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
||||
|
||||
# padding
|
||||
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
|
||||
Reference in New Issue
Block a user