simplify onnx cubic (#11641)

we can drop the double where and abs since we know which ranges the inputs map into
This commit is contained in:
chenyu
2025-08-12 16:57:31 -07:00
committed by GitHub
parent 18cdbec447
commit e9e5a08a04

View File

@@ -819,13 +819,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if mode == "cubic":
A = cubic_coeff_a
def W(x:Tensor):
# Keys weights
# see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
x = x.abs()
w0_1 = polyN(x, [A + 2, -(A + 3), 0, 1])
w1_2 = polyN(x, [A, -5 * A, 8 * A, -4 * A])
return (x <= 1).where(w0_1, (x < 2).where(w1_2, 0))
# Keys weights
# see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
def W0_1(x:Tensor): return polyN(x, [A + 2, -(A + 3), 0, 1])
def W1_2(x: Tensor): return polyN(x, [A, -5 * A, 8 * A, -4 * A])
expand = list(X.shape)
for i in range(-len(sizes), 0):
@@ -834,12 +831,12 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
reshape[i] = expand[i] = sizes[i]
p = index.floor().int()
ratio = index - p
ratio = index - p # in [0, 1]
# Neighbor indices
idx0, idx1, idx2, idx3 = [p + d for d in [-1, 0, 1, 2]]
# Weights of distance from index and neighbor indices
c0, c1, c2, c3 = [W(ratio - d) for d in [-1, 0, 1, 2]]
c0, c1, c2, c3 = W1_2(ratio+1), W0_1(ratio), W0_1(-(ratio-1)), W1_2(-(ratio-2))
if exclude_outside:
c0 = ((idx0 >= 0) & (idx0 < input_sz)).where(c0, 0)