Remove op decomposition from v_diffusion.py (#210)

The PyTorch decomposition for the op `aten.upsample_bilinear2d.vec`
has been merged into the upstream PyTorch repo, so the local copy is removed from this file.
Author: Vivek Khandelwal (committed via GitHub)
Date: 2022-07-25 17:36:26 +05:30
Commit: cc4fa96831 (parent: 921ccdc40b)
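With the decomposition now upstream, the op no longer needs a local `@register_decomposition`; it can be requested straight from `torch._decomp` and handed to the tracer. A minimal sketch of that pattern, assuming a bare `torch.nn.functional.interpolate` call stands in for the model's upsampling (illustrative only, not the exact code in v_diffusion.py):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

def upsample(x):
    # Traces to aten.upsample_bilinear2d.vec under make_fx.
    return torch.nn.functional.interpolate(
        x, scale_factor=2.0, mode="bilinear", align_corners=False
    )

# Pull the now-upstream decomposition instead of registering a local copy;
# make_fx then lowers the op to primitives torch-mlir already supports.
fx_graph = make_fx(
    upsample,
    decomposition_table=get_decompositions(
        [torch.ops.aten.upsample_bilinear2d.vec]
    ),
)(torch.randn(1, 3, 64, 64))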

@@ -7,7 +7,6 @@ import torch
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch._decomp import get_decompositions
-from torch._decomp import register_decomposition
 import tempfile
 import math
@@ -25,117 +24,6 @@ from diffusion import get_model, sampling, utils
 import torch_mlir
-@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec)
-def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors):
-    # get dimensions of original image
-    n_batch, n_channels, in_h, in_w = input.shape
-    if output_size is not None:
-        out_h = float(output_size[0])
-        out_w = float(output_size[1])
-    elif scale_factors is not None:
-        out_h = in_h * scale_factors[0]
-        out_w = in_w * scale_factors[1]
-    # Calculate horizontal and vertical scaling factor
-    if out_h > 1:
-        if align_corners:
-            h_scale_factor = (in_h - 1) / (int(out_h) - 1)
-        else:
-            h_scale_factor = in_h / out_h
-    else:
-        h_scale_factor = 0.0
-    if out_w > 1:
-        if align_corners:
-            w_scale_factor = (in_w - 1) / (int(out_w) - 1)
-        else:
-            w_scale_factor = in_w / out_w
-    else:
-        w_scale_factor = 0.0
-    i = torch.arange(out_h, dtype=input.dtype, device=input.device)
-    j = torch.arange(out_w, dtype=input.dtype, device=input.device)
-    if align_corners:
-        x = h_scale_factor * i
-        y = w_scale_factor * j
-    else:
-        x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0)
-        y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0)
-    x_floor = torch.floor(x)
-    x_ceil = torch.minimum(torch.ceil(x), torch.tensor(in_h - 1))
-    y_floor = torch.floor(y)
-    y_ceil = torch.minimum(torch.ceil(y), torch.tensor(in_w - 1))
-    x_view = x.view(1, 1, len(x), 1)
-    x_floor_view = x_floor.view(1, 1, len(x_floor), 1)
-    x_ceil_view = x_ceil.view(1, 1, len(x_ceil), 1)
-    y_view = y.view(1, 1, 1, len(y))
-    y_floor_view = y_floor.view(1, 1, 1, len(y_floor))
-    y_ceil_view = y_ceil.view(1, 1, 1, len(y_ceil))
-    v1 = input[:, :, x_floor.to(torch.int64), :][
-        :, :, :, y_floor.to(torch.int64)
-    ]
-    v2 = input[:, :, x_ceil.to(torch.int64), :][
-        :, :, :, y_floor.to(torch.int64)
-    ]
-    v3 = input[:, :, x_floor.to(torch.int64), :][
-        :, :, :, y_ceil.to(torch.int64)
-    ]
-    v4 = input[:, :, x_ceil.to(torch.int64), :][
-        :, :, :, y_ceil.to(torch.int64)
-    ]
-    q1 = torch.mul(v1, x_ceil_view - x_view) + torch.mul(
-        v2, x_view - x_floor_view
-    )
-    q2 = torch.mul(v3, x_ceil_view - x_view) + torch.mul(
-        v4, x_view - x_floor_view
-    )
-    result = torch.mul(q1, y_ceil_view - y_view) + torch.mul(
-        q2, y_view - y_floor_view
-    )
-    # When (x_ceil == x_floor) and (y_ceil == y_floor).
-    result_cond1 = input[:, :, x.to(torch.int64), :][
-        :, :, :, y.to(torch.int64)
-    ]
-    # When (x_ceil == x_floor).
-    q1 = input[:, :, x.to(torch.int64), :][:, :, :, y_floor.to(torch.int64)]
-    q2 = input[:, :, x.to(torch.int64), :][:, :, :, y_ceil.to(torch.int64)]
-    result_cond2 = torch.mul(q1, y_ceil_view - y_view) + torch.mul(
-        q2, y_view - y_floor_view
-    )
-    # When (y_ceil == y_floor).
-    q1 = input[:, :, x_floor.to(torch.int64), :][:, :, :, y.to(torch.int64)]
-    q2 = input[:, :, x_ceil.to(torch.int64), :][:, :, :, y.to(torch.int64)]
-    result_cond3 = torch.mul(q1, x_ceil_view - x_view) + torch.mul(
-        q2, x_view - x_floor_view
-    )
-    result = torch.where(
-        torch.eq(x_ceil_view, x_floor_view), result_cond2, result
-    )
-    result = torch.where(
-        torch.eq(y_ceil_view, y_floor_view), result_cond3, result
-    )
-    result = torch.where(
-        torch.logical_and(
-            torch.eq(x_ceil_view, x_floor_view),
-            torch.eq(y_ceil_view, y_floor_view),
-        ),
-        result_cond1,
-        result,
-    )
-    return result
 
 # Load the models
 model = get_model("cc12m_1_cfg")()
 _, side_y, side_x = model.shape
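For reference, the deleted function implemented textbook bilinear interpolation: each output index is mapped back to a fractional source coordinate, the four surrounding input pixels are blended by distance, and the trailing torch.where branches patch the degenerate cases where a coordinate lands exactly on an input row or column (ceil == floor), since both blend weights along that axis would otherwise vanish. A hand-worked pass through the deleted index math for one output row with align_corners=False (numbers chosen for illustration):

in_h, out_h = 4, 8
h_scale_factor = in_h / out_h                    # 0.5
# Source coordinate for output row i = 3 (half-pixel centers):
x = max(h_scale_factor * (3 + 0.5) - 0.5, 0.0)   # 1.25
x_floor, x_ceil = 1.0, 2.0                       # neighbouring input rows
w_floor = x_ceil - x                             # 0.75, weight on input row 1
w_ceil = x - x_floor                             # 0.25, weight on input row 2

This is the same half-pixel convention eager PyTorch uses for mode="bilinear" with align_corners=False, which is why the upstream decomposition can replace this copy unchanged.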