mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simplify onnx cubic (#11641)
we can drop the double where and abs since we know which ranges the inputs map into
This commit is contained in:
@@ -819,13 +819,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if mode == "cubic":
|
||||
A = cubic_coeff_a
|
||||
|
||||
def W(x:Tensor):
|
||||
# Keys weights
|
||||
# see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
x = x.abs()
|
||||
w0_1 = polyN(x, [A + 2, -(A + 3), 0, 1])
|
||||
w1_2 = polyN(x, [A, -5 * A, 8 * A, -4 * A])
|
||||
return (x <= 1).where(w0_1, (x < 2).where(w1_2, 0))
|
||||
# Keys weights
|
||||
# see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
def W0_1(x:Tensor): return polyN(x, [A + 2, -(A + 3), 0, 1])
|
||||
def W1_2(x: Tensor): return polyN(x, [A, -5 * A, 8 * A, -4 * A])
|
||||
|
||||
expand = list(X.shape)
|
||||
for i in range(-len(sizes), 0):
|
||||
@@ -834,12 +831,12 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
reshape[i] = expand[i] = sizes[i]
|
||||
|
||||
p = index.floor().int()
|
||||
ratio = index - p
|
||||
ratio = index - p # in [0, 1]
|
||||
|
||||
# Neighbor indices
|
||||
idx0, idx1, idx2, idx3 = [p + d for d in [-1, 0, 1, 2]]
|
||||
# Weights of distance from index and neighbor indices
|
||||
c0, c1, c2, c3 = [W(ratio - d) for d in [-1, 0, 1, 2]]
|
||||
c0, c1, c2, c3 = W1_2(ratio+1), W0_1(ratio), W0_1(-(ratio-1)), W1_2(-(ratio-2))
|
||||
|
||||
if exclude_outside:
|
||||
c0 = ((idx0 >= 0) & (idx0 < input_sz)).where(c0, 0)
|
||||
|
||||
Reference in New Issue
Block a user