From cc4fa96831c262cb89df45c65507c581e28cd8b3 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 25 Jul 2022 17:36:26 +0530 Subject: [PATCH] Remove op decomposition from the v_diffusion.py (#210) The PyTorch decomposition for the op `aten.upsample_bilinear2d.vec` is merged in the upstream repo and hence removed from this file. --- tank/pytorch/v_diffusion/v_diffusion.py | 112 ------------------------ 1 file changed, 112 deletions(-) diff --git a/tank/pytorch/v_diffusion/v_diffusion.py b/tank/pytorch/v_diffusion/v_diffusion.py index 9ae95620..47176190 100644 --- a/tank/pytorch/v_diffusion/v_diffusion.py +++ b/tank/pytorch/v_diffusion/v_diffusion.py @@ -7,7 +7,6 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions -from torch._decomp import register_decomposition import tempfile import math @@ -25,117 +24,6 @@ from diffusion import get_model, sampling, utils import torch_mlir -@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) -def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors): - # get dimensions of original image - n_batch, n_channels, in_h, in_w = input.shape - - if output_size is not None: - out_h = float(output_size[0]) - out_w = float(output_size[1]) - elif scale_factors is not None: - out_h = in_h * scale_factors[0] - out_w = in_w * scale_factors[1] - - # Calculate horizontal and vertical scaling factor - if out_h > 1: - if align_corners: - h_scale_factor = (in_h - 1) / (int(out_h) - 1) - else: - h_scale_factor = in_h / out_h - else: - h_scale_factor = 0.0 - - if out_w > 1: - if align_corners: - w_scale_factor = (in_w - 1) / (int(out_w) - 1) - else: - w_scale_factor = in_w / out_w - else: - w_scale_factor = 0.0 - - i = torch.arange(out_h, dtype=input.dtype, device=input.device) - j = torch.arange(out_w, dtype=input.dtype, device=input.device) - - if align_corners: - x = h_scale_factor * i - y = w_scale_factor * j - else: - x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0) - y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0) - - x_floor = torch.floor(x) - x_ceil = torch.minimum(torch.ceil(x), torch.tensor(in_h - 1)) - y_floor = torch.floor(y) - y_ceil = torch.minimum(torch.ceil(y), torch.tensor(in_w - 1)) - - x_view = x.view(1, 1, len(x), 1) - x_floor_view = x_floor.view(1, 1, len(x_floor), 1) - x_ceil_view = x_ceil.view(1, 1, len(x_ceil), 1) - - y_view = y.view(1, 1, 1, len(y)) - y_floor_view = y_floor.view(1, 1, 1, len(y_floor)) - y_ceil_view = y_ceil.view(1, 1, 1, len(y_ceil)) - - v1 = input[:, :, x_floor.to(torch.int64), :][ - :, :, :, y_floor.to(torch.int64) - ] - v2 = input[:, :, x_ceil.to(torch.int64), :][ - :, :, :, y_floor.to(torch.int64) - ] - v3 = input[:, :, x_floor.to(torch.int64), :][ - :, :, :, y_ceil.to(torch.int64) - ] - v4 = input[:, :, x_ceil.to(torch.int64), :][ - :, :, :, y_ceil.to(torch.int64) - ] - q1 = torch.mul(v1, x_ceil_view - x_view) + torch.mul( - v2, x_view - x_floor_view - ) - q2 = torch.mul(v3, x_ceil_view - x_view) + torch.mul( - v4, x_view - x_floor_view - ) - result = torch.mul(q1, y_ceil_view - y_view) + torch.mul( - q2, y_view - y_floor_view - ) - - # When (x_ceil == x_floor) and (y_ceil == y_floor). - result_cond1 = input[:, :, x.to(torch.int64), :][ - :, :, :, y.to(torch.int64) - ] - - # When (x_ceil == x_floor). - q1 = input[:, :, x.to(torch.int64), :][:, :, :, y_floor.to(torch.int64)] - q2 = input[:, :, x.to(torch.int64), :][:, :, :, y_ceil.to(torch.int64)] - result_cond2 = torch.mul(q1, y_ceil_view - y_view) + torch.mul( - q2, y_view - y_floor_view - ) - - # When (y_ceil == y_floor). - q1 = input[:, :, x_floor.to(torch.int64), :][:, :, :, y.to(torch.int64)] - q2 = input[:, :, x_ceil.to(torch.int64), :][:, :, :, y.to(torch.int64)] - result_cond3 = torch.mul(q1, x_ceil_view - x_view) + torch.mul( - q2, x_view - x_floor_view - ) - - result = torch.where( - torch.eq(x_ceil_view, x_floor_view), result_cond2, result - ) - result = torch.where( - torch.eq(y_ceil_view, y_floor_view), result_cond3, result - ) - result = torch.where( - torch.logical_and( - torch.eq(x_ceil_view, x_floor_view), - torch.eq(y_ceil_view, y_floor_view), - ), - result_cond1, - result, - ) - - return result - - # Load the models model = get_model("cc12m_1_cfg")() _, side_y, side_x = model.shape