mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
better hlop image conv
This commit is contained in:
@@ -36,16 +36,28 @@ def image_conv2d_decorator(normal_conv):
|
||||
# contiguous creates the image, and early realize static weights (TODO: don't always realize)
|
||||
x, w = x.contiguous(), w.contiguous().realize()
|
||||
|
||||
# put it back for normal conv
|
||||
x = x.reshape(bs, iy, ix, groups*cin).permute(0,3,1,2)
|
||||
w = w.reshape(cout//4, H, cin//4 if cin >= 4 else 1, W, 4, 4 if cin >= 4 else 1).permute(0,4,2,5,1,3).reshape(cout, cin, H, W)
|
||||
# expand out
|
||||
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
||||
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
||||
w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
|
||||
# run normal conv
|
||||
ret = normal_conv(x, w, None, groups, stride, dilation, padding)
|
||||
# padding
|
||||
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
||||
|
||||
# make image sized
|
||||
oy, ox = ret.shape[2:]
|
||||
ret = ret.permute(0,2,3,1).reshape(bs*oy, ox*cout//4, 4)
|
||||
# prepare input
|
||||
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
||||
oy, ox = x.shape[4:6]
|
||||
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, groups, 1, 1, rcin_hi, rcin_lo, H, W)
|
||||
x = x.expand(bs, oy, ox, groups, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1, rcin_hi, rcin_lo, H, W)
|
||||
x = x.reshape(bs, oy, ox, cout//4, 4, rcin_hi, rcin_lo, H, W)
|
||||
|
||||
# prepare weights
|
||||
w = w.permute(0,4,2,5,1,3)
|
||||
w = w.reshape((1, 1, 1, cout//4, 4, rcin_hi, rcin_lo, H, W)) # needed or this is broadcasting?
|
||||
|
||||
# the conv!
|
||||
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
|
||||
if IMAGE >= 3: ret = ret.contiguous()
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
|
||||
@@ -195,7 +195,7 @@ class Tensor:
|
||||
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
|
||||
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
|
||||
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
|
||||
def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(arg))
|
||||
def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg)))
|
||||
|
||||
# ***** movement hlops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user