mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 22:38:01 -05:00
Remove op decomposition from the v_diffusion.py (#210)
The PyTorch decomposition for the op `aten.upsample_bilinear2d.vec` is merged in the upstream repo and hence removed from this file.
This commit is contained in:
@@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from torch._decomp import register_decomposition
|
||||
import tempfile
|
||||
|
||||
import math
|
||||
@@ -25,117 +24,6 @@ from diffusion import get_model, sampling, utils
|
||||
import torch_mlir
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec)
|
||||
def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors):
|
||||
# get dimensions of original image
|
||||
n_batch, n_channels, in_h, in_w = input.shape
|
||||
|
||||
if output_size is not None:
|
||||
out_h = float(output_size[0])
|
||||
out_w = float(output_size[1])
|
||||
elif scale_factors is not None:
|
||||
out_h = in_h * scale_factors[0]
|
||||
out_w = in_w * scale_factors[1]
|
||||
|
||||
# Calculate horizontal and vertical scaling factor
|
||||
if out_h > 1:
|
||||
if align_corners:
|
||||
h_scale_factor = (in_h - 1) / (int(out_h) - 1)
|
||||
else:
|
||||
h_scale_factor = in_h / out_h
|
||||
else:
|
||||
h_scale_factor = 0.0
|
||||
|
||||
if out_w > 1:
|
||||
if align_corners:
|
||||
w_scale_factor = (in_w - 1) / (int(out_w) - 1)
|
||||
else:
|
||||
w_scale_factor = in_w / out_w
|
||||
else:
|
||||
w_scale_factor = 0.0
|
||||
|
||||
i = torch.arange(out_h, dtype=input.dtype, device=input.device)
|
||||
j = torch.arange(out_w, dtype=input.dtype, device=input.device)
|
||||
|
||||
if align_corners:
|
||||
x = h_scale_factor * i
|
||||
y = w_scale_factor * j
|
||||
else:
|
||||
x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0)
|
||||
y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0)
|
||||
|
||||
x_floor = torch.floor(x)
|
||||
x_ceil = torch.minimum(torch.ceil(x), torch.tensor(in_h - 1))
|
||||
y_floor = torch.floor(y)
|
||||
y_ceil = torch.minimum(torch.ceil(y), torch.tensor(in_w - 1))
|
||||
|
||||
x_view = x.view(1, 1, len(x), 1)
|
||||
x_floor_view = x_floor.view(1, 1, len(x_floor), 1)
|
||||
x_ceil_view = x_ceil.view(1, 1, len(x_ceil), 1)
|
||||
|
||||
y_view = y.view(1, 1, 1, len(y))
|
||||
y_floor_view = y_floor.view(1, 1, 1, len(y_floor))
|
||||
y_ceil_view = y_ceil.view(1, 1, 1, len(y_ceil))
|
||||
|
||||
v1 = input[:, :, x_floor.to(torch.int64), :][
|
||||
:, :, :, y_floor.to(torch.int64)
|
||||
]
|
||||
v2 = input[:, :, x_ceil.to(torch.int64), :][
|
||||
:, :, :, y_floor.to(torch.int64)
|
||||
]
|
||||
v3 = input[:, :, x_floor.to(torch.int64), :][
|
||||
:, :, :, y_ceil.to(torch.int64)
|
||||
]
|
||||
v4 = input[:, :, x_ceil.to(torch.int64), :][
|
||||
:, :, :, y_ceil.to(torch.int64)
|
||||
]
|
||||
q1 = torch.mul(v1, x_ceil_view - x_view) + torch.mul(
|
||||
v2, x_view - x_floor_view
|
||||
)
|
||||
q2 = torch.mul(v3, x_ceil_view - x_view) + torch.mul(
|
||||
v4, x_view - x_floor_view
|
||||
)
|
||||
result = torch.mul(q1, y_ceil_view - y_view) + torch.mul(
|
||||
q2, y_view - y_floor_view
|
||||
)
|
||||
|
||||
# When (x_ceil == x_floor) and (y_ceil == y_floor).
|
||||
result_cond1 = input[:, :, x.to(torch.int64), :][
|
||||
:, :, :, y.to(torch.int64)
|
||||
]
|
||||
|
||||
# When (x_ceil == x_floor).
|
||||
q1 = input[:, :, x.to(torch.int64), :][:, :, :, y_floor.to(torch.int64)]
|
||||
q2 = input[:, :, x.to(torch.int64), :][:, :, :, y_ceil.to(torch.int64)]
|
||||
result_cond2 = torch.mul(q1, y_ceil_view - y_view) + torch.mul(
|
||||
q2, y_view - y_floor_view
|
||||
)
|
||||
|
||||
# When (y_ceil == y_floor).
|
||||
q1 = input[:, :, x_floor.to(torch.int64), :][:, :, :, y.to(torch.int64)]
|
||||
q2 = input[:, :, x_ceil.to(torch.int64), :][:, :, :, y.to(torch.int64)]
|
||||
result_cond3 = torch.mul(q1, x_ceil_view - x_view) + torch.mul(
|
||||
q2, x_view - x_floor_view
|
||||
)
|
||||
|
||||
result = torch.where(
|
||||
torch.eq(x_ceil_view, x_floor_view), result_cond2, result
|
||||
)
|
||||
result = torch.where(
|
||||
torch.eq(y_ceil_view, y_floor_view), result_cond3, result
|
||||
)
|
||||
result = torch.where(
|
||||
torch.logical_and(
|
||||
torch.eq(x_ceil_view, x_floor_view),
|
||||
torch.eq(y_ceil_view, y_floor_view),
|
||||
),
|
||||
result_cond1,
|
||||
result,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Load the models
|
||||
model = get_model("cc12m_1_cfg")()
|
||||
_, side_y, side_x = model.shape
|
||||
|
||||
Reference in New Issue
Block a user