This commit is contained in:
Christopher Milan
2025-12-18 03:04:11 +00:00
parent b23eaa8b01
commit 8ef81d0eb9

View File

@@ -3892,7 +3892,8 @@ class Tensor(OpMixin):
w = w.pad_to(None, None, cin, None, None)
x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)
# hack for pitch alignment
# hacks for pitch alignment
assert isinstance(ix, int) and isinstance(H, int)
added_width = 0
if (ix*groups*cin) % (64 // dtsz):
added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
@@ -3941,10 +3942,10 @@ class Tensor(OpMixin):
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W))
added_ox = 0
assert isinstance(ox, int) and isinstance(cout, int)
if (ox * cout) % (64 // dtsz):
added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox
ox = ox + added_ox
# bs oy ox *group_shape 1 1, *rcin, H, W
x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None)
# the conv!