mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-02 07:05:07 -05:00
feat(nodes): add NormalMapInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager. All code related to the invocation now lives in the Invoke repo. Unfortunately, this includes a whole git repo for EfficientNet. I believe we could use the package `timm` instead of this, but it's beyond me.
This commit is contained in:
committed by
Kent Keirsey
parent
fd42da5a36
commit
b3d60bd56a
31
invokeai/app/invocations/normal_bae.py
Normal file
31
invokeai/app/invocations/normal_bae.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.normal_bae import NormalMapDetector
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
|
||||
|
||||
@invocation(
|
||||
"normal_map",
|
||||
title="Normal Map",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates a normal map."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, NNET)
|
||||
detector = NormalMapDetector(model)
|
||||
normal_map = detector.run(image=image)
|
||||
|
||||
image_dto = context.images.save(image=normal_map)
|
||||
return ImageOutput.build(image_dto)
|
||||
21
invokeai/backend/image_util/normal_bae/LICENSE
Normal file
21
invokeai/backend/image_util/normal_bae/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Caroline Chan
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
93
invokeai/backend/image_util/normal_bae/__init__.py
Normal file
93
invokeai/backend/image_util/normal_bae/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
import types
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
|
||||
|
||||
|
||||
class NormalMapDetector:
|
||||
"""Simple wrapper around the Normal BAE model for normal map generation."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "scannet.pt"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> NNET:
|
||||
"""Load the model from a file."""
|
||||
|
||||
args = types.SimpleNamespace()
|
||||
args.mode = "client"
|
||||
args.architecture = "BN"
|
||||
args.pretrained = "scannet"
|
||||
args.sampling_ratio = 0.4
|
||||
args.importance_ratio = 0.7
|
||||
|
||||
model = NNET(args)
|
||||
|
||||
ckpt = torch.load(model_path, map_location="cpu")["model"]
|
||||
load_dict = {}
|
||||
for k, v in ckpt.items():
|
||||
if k.startswith("module."):
|
||||
k_ = k.replace("module.", "")
|
||||
load_dict[k_] = v
|
||||
else:
|
||||
load_dict[k] = v
|
||||
|
||||
model.load_state_dict(load_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
def __init__(self, model: NNET) -> None:
|
||||
self.model = model
|
||||
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image):
|
||||
"""Processes an image and returns the detected normal map."""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
||||
# The model requires the image to be a multiple of 8
|
||||
np_image = resize_to_multiple(np_image, 8)
|
||||
|
||||
image_normal = np_image
|
||||
|
||||
with torch.no_grad():
|
||||
image_normal = torch.from_numpy(image_normal).float().to(device)
|
||||
image_normal = image_normal / 255.0
|
||||
image_normal = rearrange(image_normal, "h w c -> 1 c h w")
|
||||
image_normal = self.norm(image_normal)
|
||||
|
||||
normal = self.model(image_normal)
|
||||
normal = normal[0][-1][:, :3]
|
||||
normal = ((normal + 1) * 0.5).clip(0, 1)
|
||||
|
||||
normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
|
||||
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
# Back to the original size
|
||||
output_image = cv2.resize(normal_image, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
return np_to_pil(output_image)
|
||||
22
invokeai/backend/image_util/normal_bae/nets/NNET.py
Normal file
22
invokeai/backend/image_util/normal_bae/nets/NNET.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .submodules.encoder import Encoder
|
||||
from .submodules.decoder import Decoder
|
||||
|
||||
|
||||
class NNET(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(NNET, self).__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder(args)
|
||||
|
||||
def get_1x_lr_params(self): # lr/10 learning rate
|
||||
return self.encoder.parameters()
|
||||
|
||||
def get_10x_lr_params(self): # lr learning rate
|
||||
return self.decoder.parameters()
|
||||
|
||||
def forward(self, img, **kwargs):
|
||||
return self.decoder(self.encoder(img), **kwargs)
|
||||
85
invokeai/backend/image_util/normal_bae/nets/baseline.py
Normal file
85
invokeai/backend/image_util/normal_bae/nets/baseline.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .submodules.submodules import UpSampleBN, norm_normalize
|
||||
|
||||
|
||||
# This is the baseline encoder-decoder we used in the ablation study
|
||||
class NNET(nn.Module):
|
||||
def __init__(self, args=None):
|
||||
super(NNET, self).__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder(num_classes=4)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
out = self.decoder(self.encoder(x), **kwargs)
|
||||
|
||||
# Bilinearly upsample the output to match the input resolution
|
||||
up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
|
||||
|
||||
# L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
|
||||
up_out = norm_normalize(up_out)
|
||||
return up_out
|
||||
|
||||
def get_1x_lr_params(self): # lr/10 learning rate
|
||||
return self.encoder.parameters()
|
||||
|
||||
def get_10x_lr_params(self): # lr learning rate
|
||||
modules = [self.decoder]
|
||||
for m in modules:
|
||||
yield from m.parameters()
|
||||
|
||||
|
||||
# Encoder
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
basemodel_name = 'tf_efficientnet_b5_ap'
|
||||
basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
|
||||
|
||||
# Remove last layer
|
||||
basemodel.global_pool = nn.Identity()
|
||||
basemodel.classifier = nn.Identity()
|
||||
|
||||
self.original_model = basemodel
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for k, v in self.original_model._modules.items():
|
||||
if (k == 'blocks'):
|
||||
for ki, vi in v._modules.items():
|
||||
features.append(vi(features[-1]))
|
||||
else:
|
||||
features.append(v(features[-1]))
|
||||
return features
|
||||
|
||||
|
||||
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, num_classes=4):
|
||||
super(Decoder, self).__init__()
|
||||
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
||||
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
||||
self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, features):
|
||||
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
||||
x_d0 = self.conv2(x_block4)
|
||||
x_d1 = self.up1(x_d0, x_block3)
|
||||
x_d2 = self.up2(x_d1, x_block2)
|
||||
x_d3 = self.up3(x_d2, x_block1)
|
||||
x_d4 = self.up4(x_d3, x_block0)
|
||||
out = self.conv3(x_d4)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Baseline()
|
||||
x = torch.rand(2, 3, 480, 640)
|
||||
out = model(x)
|
||||
print(out.shape)
|
||||
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
# hyper-parameter for sampling
|
||||
self.sampling_ratio = args.sampling_ratio
|
||||
self.importance_ratio = args.importance_ratio
|
||||
|
||||
# feature-map
|
||||
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
||||
if args.architecture == 'BN':
|
||||
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
||||
|
||||
elif args.architecture == 'GN':
|
||||
self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
|
||||
self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
|
||||
self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
|
||||
self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
|
||||
|
||||
else:
|
||||
raise Exception('invalid architecture')
|
||||
|
||||
# produces 1/8 res output
|
||||
self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# produces 1/4 res output
|
||||
self.out_conv_res4 = nn.Sequential(
|
||||
nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
# produces 1/2 res output
|
||||
self.out_conv_res2 = nn.Sequential(
|
||||
nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
# produces 1/1 res output
|
||||
self.out_conv_res1 = nn.Sequential(
|
||||
nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
||||
nn.Conv1d(128, 4, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, features, gt_norm_mask=None, mode='test'):
|
||||
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
||||
|
||||
# generate feature-map
|
||||
|
||||
x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
|
||||
x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
|
||||
x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
|
||||
x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
|
||||
x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
|
||||
|
||||
# 1/8 res output
|
||||
out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
|
||||
out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
|
||||
|
||||
################################################################################################################
|
||||
# out_res4
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
# upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
|
||||
out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res8_res4.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res4 = out_res8_res4
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
|
||||
init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
# try all pixels
|
||||
out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
|
||||
out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
|
||||
out_res4 = out_res4.view(B, 4, H, W)
|
||||
samples_pred_res4 = point_coords_res4 = None
|
||||
|
||||
################################################################################################################
|
||||
# out_res2
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
|
||||
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
||||
out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res4_res2.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res2 = out_res4_res2
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
|
||||
init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
|
||||
out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
|
||||
out_res2 = out_res2.view(B, 4, H, W)
|
||||
samples_pred_res2 = point_coords_res2 = None
|
||||
|
||||
################################################################################################################
|
||||
# out_res1
|
||||
################################################################################################################
|
||||
|
||||
if mode == 'train':
|
||||
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
||||
out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
B, _, H, W = out_res2_res1.shape
|
||||
|
||||
# samples: [B, 1, N, 2]
|
||||
point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
|
||||
sampling_ratio=self.sampling_ratio,
|
||||
beta=self.importance_ratio)
|
||||
|
||||
# output (needed for evaluation / visualization)
|
||||
out_res1 = out_res2_res1
|
||||
|
||||
# grid_sample feature-map
|
||||
feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
|
||||
init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
||||
feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
|
||||
|
||||
# prediction (needed to compute loss)
|
||||
samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
|
||||
samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
|
||||
|
||||
for i in range(B):
|
||||
out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
|
||||
|
||||
else:
|
||||
# grid_sample feature-map
|
||||
feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
||||
B, _, H, W = feat_map.shape
|
||||
|
||||
out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
|
||||
out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
|
||||
out_res1 = out_res1.view(B, 4, H, W)
|
||||
samples_pred_res1 = point_coords_res1 = None
|
||||
|
||||
return [out_res8, out_res4, out_res2, out_res1], \
|
||||
[out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
|
||||
[None, point_coords_res4, point_coords_res2, point_coords_res1]
|
||||
|
||||
109
invokeai/backend/image_util/normal_bae/nets/submodules/efficientnet_repo/.gitignore
vendored
Normal file
109
invokeai/backend/image_util/normal_bae/nets/submodules/efficientnet_repo/.gitignore
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# pytorch stuff
|
||||
*.pth
|
||||
*.onnx
|
||||
*.pb
|
||||
|
||||
trained_models/
|
||||
.fuse_hidden*
|
||||
@@ -0,0 +1,555 @@
|
||||
# Model Performance Benchmarks
|
||||
|
||||
All benchmarks run as per:
|
||||
|
||||
```
|
||||
python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx
|
||||
python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx
|
||||
python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3
|
||||
python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt
|
||||
python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb
|
||||
python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb
|
||||
```
|
||||
|
||||
## EfficientNet-B0
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897
|
||||
Time per operator type:
|
||||
29.7378 ms. 60.5145%. Conv
|
||||
12.1785 ms. 24.7824%. Sigmoid
|
||||
3.62811 ms. 7.38297%. SpatialBN
|
||||
2.98444 ms. 6.07314%. Mul
|
||||
0.326902 ms. 0.665225%. AveragePool
|
||||
0.197317 ms. 0.401528%. FC
|
||||
0.0852877 ms. 0.173555%. Add
|
||||
0.0032607 ms. 0.00663532%. Squeeze
|
||||
49.1416 ms in Total
|
||||
FLOP per operator type:
|
||||
0.76907 GFLOP. 95.2696%. Conv
|
||||
0.0269508 GFLOP. 3.33857%. SpatialBN
|
||||
0.00846444 GFLOP. 1.04855%. Mul
|
||||
0.002561 GFLOP. 0.317248%. FC
|
||||
0.000210112 GFLOP. 0.0260279%. Add
|
||||
0.807256 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
58.5253 MB. 43.0891%. Mul
|
||||
43.2015 MB. 31.807%. Conv
|
||||
27.2869 MB. 20.0899%. SpatialBN
|
||||
5.12912 MB. 3.77631%. FC
|
||||
1.6809 MB. 1.23756%. Add
|
||||
135.824 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
33.8578 MB. 38.1965%. Mul
|
||||
26.9881 MB. 30.4465%. Conv
|
||||
26.9508 MB. 30.4044%. SpatialBN
|
||||
0.840448 MB. 0.948147%. Add
|
||||
0.004 MB. 0.00451258%. FC
|
||||
88.6412 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
15.8248 MB. 74.9391%. Conv
|
||||
5.124 MB. 24.265%. FC
|
||||
0.168064 MB. 0.795877%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
21.1168 MB in Total
|
||||
```
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996
|
||||
Time per operator type:
|
||||
29.776 ms. 65.002%. Conv
|
||||
12.2803 ms. 26.8084%. Sigmoid
|
||||
3.15073 ms. 6.87815%. Mul
|
||||
0.328651 ms. 0.717456%. AveragePool
|
||||
0.186237 ms. 0.406563%. FC
|
||||
0.0832429 ms. 0.181722%. Add
|
||||
0.0026184 ms. 0.00571606%. Squeeze
|
||||
45.8078 ms in Total
|
||||
FLOP per operator type:
|
||||
0.76907 GFLOP. 98.5601%. Conv
|
||||
0.00846444 GFLOP. 1.08476%. Mul
|
||||
0.002561 GFLOP. 0.328205%. FC
|
||||
0.000210112 GFLOP. 0.0269269%. Add
|
||||
0.780305 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
58.5253 MB. 53.8803%. Mul
|
||||
43.2855 MB. 39.8501%. Conv
|
||||
5.12912 MB. 4.72204%. FC
|
||||
1.6809 MB. 1.54749%. Add
|
||||
108.621 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
33.8578 MB. 54.8834%. Mul
|
||||
26.9881 MB. 43.7477%. Conv
|
||||
0.840448 MB. 1.36237%. Add
|
||||
0.004 MB. 0.00648399%. FC
|
||||
61.6904 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
15.8248 MB. 75.5403%. Conv
|
||||
5.124 MB. 24.4597%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
20.9488 MB in Total
|
||||
```
|
||||
|
||||
## EfficientNet-B1
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256
|
||||
Time per operator type:
|
||||
45.7915 ms. 66.3206%. Conv
|
||||
17.8718 ms. 25.8841%. Sigmoid
|
||||
4.44132 ms. 6.43244%. Mul
|
||||
0.51001 ms. 0.738658%. AveragePool
|
||||
0.233283 ms. 0.337868%. Add
|
||||
0.194986 ms. 0.282402%. FC
|
||||
0.00268255 ms. 0.00388519%. Squeeze
|
||||
69.0456 ms in Total
|
||||
FLOP per operator type:
|
||||
1.37105 GFLOP. 98.7673%. Conv
|
||||
0.0138759 GFLOP. 0.99959%. Mul
|
||||
0.002561 GFLOP. 0.184489%. FC
|
||||
0.000674432 GFLOP. 0.0485847%. Add
|
||||
1.38816 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
94.624 MB. 54.0789%. Mul
|
||||
69.8255 MB. 39.9062%. Conv
|
||||
5.39546 MB. 3.08357%. Add
|
||||
5.12912 MB. 2.93136%. FC
|
||||
174.974 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
55.5035 MB. 54.555%. Mul
|
||||
43.5333 MB. 42.7894%. Conv
|
||||
2.69773 MB. 2.65163%. Add
|
||||
0.004 MB. 0.00393165%. FC
|
||||
101.739 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
25.7479 MB. 83.4024%. Conv
|
||||
5.124 MB. 16.5976%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
30.8719 MB in Total
|
||||
```
|
||||
|
||||
## EfficientNet-B2
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366
|
||||
Time per operator type:
|
||||
61.4627 ms. 67.5845%. Conv
|
||||
22.7458 ms. 25.0113%. Sigmoid
|
||||
5.59931 ms. 6.15701%. Mul
|
||||
0.642567 ms. 0.706568%. AveragePool
|
||||
0.272795 ms. 0.299965%. Add
|
||||
0.216178 ms. 0.237709%. FC
|
||||
0.00268895 ms. 0.00295677%. Squeeze
|
||||
90.942 ms in Total
|
||||
FLOP per operator type:
|
||||
1.98431 GFLOP. 98.9343%. Conv
|
||||
0.0177039 GFLOP. 0.882686%. Mul
|
||||
0.002817 GFLOP. 0.140451%. FC
|
||||
0.000853984 GFLOP. 0.0425782%. Add
|
||||
2.00568 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
120.609 MB. 54.9637%. Mul
|
||||
86.3512 MB. 39.3519%. Conv
|
||||
6.83187 MB. 3.11341%. Add
|
||||
5.64163 MB. 2.571%. FC
|
||||
219.433 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
70.8155 MB. 54.6573%. Mul
|
||||
55.3273 MB. 42.7031%. Conv
|
||||
3.41594 MB. 2.63651%. Add
|
||||
0.004 MB. 0.00308731%. FC
|
||||
129.563 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
30.4721 MB. 84.3913%. Conv
|
||||
5.636 MB. 15.6087%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
36.1081 MB in Total
|
||||
```
|
||||
|
||||
## MixNet-M
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448
|
||||
Time per operator type:
|
||||
48.1139 ms. 75.2052%. Conv
|
||||
7.1341 ms. 11.1511%. Sigmoid
|
||||
2.63706 ms. 4.12189%. SpatialBN
|
||||
1.73186 ms. 2.70701%. Mul
|
||||
1.38707 ms. 2.16809%. Split
|
||||
1.29322 ms. 2.02139%. Concat
|
||||
1.00093 ms. 1.56452%. Relu
|
||||
0.235309 ms. 0.367803%. Add
|
||||
0.221579 ms. 0.346343%. FC
|
||||
0.219315 ms. 0.342803%. AveragePool
|
||||
0.00250145 ms. 0.00390993%. Squeeze
|
||||
63.9768 ms in Total
|
||||
FLOP per operator type:
|
||||
0.675273 GFLOP. 95.5827%. Conv
|
||||
0.0221072 GFLOP. 3.12921%. SpatialBN
|
||||
0.00538445 GFLOP. 0.762152%. Mul
|
||||
0.003073 GFLOP. 0.434973%. FC
|
||||
0.000642488 GFLOP. 0.0909421%. Add
|
||||
0 GFLOP. 0%. Concat
|
||||
0 GFLOP. 0%. Relu
|
||||
0.70648 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
46.8424 MB. 30.502%. Conv
|
||||
36.8626 MB. 24.0036%. Mul
|
||||
22.3152 MB. 14.5309%. SpatialBN
|
||||
22.1074 MB. 14.3955%. Concat
|
||||
14.1496 MB. 9.21372%. Relu
|
||||
6.15414 MB. 4.00735%. FC
|
||||
5.1399 MB. 3.34692%. Add
|
||||
153.571 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
32.7672 MB. 28.4331%. Conv
|
||||
22.1072 MB. 19.1831%. Concat
|
||||
22.1072 MB. 19.1831%. SpatialBN
|
||||
21.5378 MB. 18.689%. Mul
|
||||
14.1496 MB. 12.2781%. Relu
|
||||
2.56995 MB. 2.23003%. Add
|
||||
0.004 MB. 0.00347092%. FC
|
||||
115.243 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
13.7059 MB. 68.674%. Conv
|
||||
6.148 MB. 30.8049%. FC
|
||||
0.104 MB. 0.521097%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Concat
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
19.9579 MB in Total
|
||||
```
|
||||
|
||||
## TF MobileNet-V3 Large 1.0
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525
|
||||
Time per operator type:
|
||||
17.437 ms. 80.0087%. Conv
|
||||
1.27662 ms. 5.8577%. Add
|
||||
1.12759 ms. 5.17387%. Div
|
||||
0.701155 ms. 3.21721%. Mul
|
||||
0.562654 ms. 2.58171%. Relu
|
||||
0.431144 ms. 1.97828%. Clip
|
||||
0.156902 ms. 0.719936%. FC
|
||||
0.0996858 ms. 0.457402%. AveragePool
|
||||
0.00112455 ms. 0.00515993%. Flatten
|
||||
21.7939 ms in Total
|
||||
FLOP per operator type:
|
||||
0.43062 GFLOP. 98.1484%. Conv
|
||||
0.002561 GFLOP. 0.583713%. FC
|
||||
0.00210867 GFLOP. 0.480616%. Mul
|
||||
0.00193868 GFLOP. 0.441871%. Add
|
||||
0.00151532 GFLOP. 0.345377%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.438743 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7967 MB. 43.9391%. Conv
|
||||
14.496 MB. 18.3046%. Mul
|
||||
9.44828 MB. 11.9307%. Add
|
||||
9.26157 MB. 11.6949%. Relu
|
||||
6.0614 MB. 7.65395%. Div
|
||||
5.12912 MB. 6.47673%. FC
|
||||
79.193 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6247 MB. 35.8656%. Conv
|
||||
9.26157 MB. 18.847%. Relu
|
||||
8.43469 MB. 17.1643%. Mul
|
||||
7.75472 MB. 15.7806%. Add
|
||||
6.06128 MB. 12.3345%. Div
|
||||
0.004 MB. 0.00813985%. FC
|
||||
49.1409 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6851 MB. 76.5052%. Conv
|
||||
5.124 MB. 23.4948%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8091 MB in Total
|
||||
```
|
||||
|
||||
## MobileNet-V3 (RW)
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712
|
||||
Time per operator type:
|
||||
15.9266 ms. 69.2624%. Conv
|
||||
2.36551 ms. 10.2873%. SpatialBN
|
||||
1.39102 ms. 6.04936%. Add
|
||||
1.30327 ms. 5.66773%. Div
|
||||
0.737014 ms. 3.20517%. Mul
|
||||
0.639697 ms. 2.78195%. Relu
|
||||
0.375681 ms. 1.63378%. Clip
|
||||
0.153126 ms. 0.665921%. FC
|
||||
0.0993787 ms. 0.432184%. AveragePool
|
||||
0.0032632 ms. 0.0141912%. Squeeze
|
||||
22.9946 ms in Total
|
||||
FLOP per operator type:
|
||||
0.430616 GFLOP. 94.4041%. Conv
|
||||
0.0175992 GFLOP. 3.85829%. SpatialBN
|
||||
0.002561 GFLOP. 0.561449%. FC
|
||||
0.00210961 GFLOP. 0.46249%. Mul
|
||||
0.00173891 GFLOP. 0.381223%. Add
|
||||
0.00151626 GFLOP. 0.33241%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.456141 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7354 MB. 36.4363%. Conv
|
||||
17.7944 MB. 18.6658%. SpatialBN
|
||||
14.5035 MB. 15.2137%. Mul
|
||||
9.25778 MB. 9.71113%. Relu
|
||||
7.84641 MB. 8.23064%. Add
|
||||
6.06516 MB. 6.36216%. Div
|
||||
5.12912 MB. 5.38029%. FC
|
||||
95.3317 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6246 MB. 26.7264%. Conv
|
||||
17.5992 MB. 26.6878%. SpatialBN
|
||||
9.25778 MB. 14.0387%. Relu
|
||||
8.43843 MB. 12.7962%. Mul
|
||||
6.95565 MB. 10.5477%. Add
|
||||
6.06502 MB. 9.19713%. Div
|
||||
0.004 MB. 0.00606568%. FC
|
||||
65.9447 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6778 MB. 76.1564%. Conv
|
||||
5.124 MB. 23.3979%. FC
|
||||
0.0976 MB. 0.445674%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8994 MB in Total
|
||||
|
||||
```
|
||||
### Optimized
|
||||
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527
|
||||
Time per operator type:
|
||||
17.146 ms. 78.8965%. Conv
|
||||
1.38453 ms. 6.37084%. Add
|
||||
1.30991 ms. 6.02749%. Div
|
||||
0.685417 ms. 3.15391%. Mul
|
||||
0.532589 ms. 2.45068%. Relu
|
||||
0.418263 ms. 1.92461%. Clip
|
||||
0.15128 ms. 0.696106%. FC
|
||||
0.102065 ms. 0.469648%. AveragePool
|
||||
0.0022143 ms. 0.010189%. Squeeze
|
||||
21.7323 ms in Total
|
||||
FLOP per operator type:
|
||||
0.430616 GFLOP. 98.1927%. Conv
|
||||
0.002561 GFLOP. 0.583981%. FC
|
||||
0.00210961 GFLOP. 0.481051%. Mul
|
||||
0.00173891 GFLOP. 0.396522%. Add
|
||||
0.00151626 GFLOP. 0.34575%. Div
|
||||
0 GFLOP. 0%. Relu
|
||||
0.438542 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.7842 MB. 44.833%. Conv
|
||||
14.5035 MB. 18.6934%. Mul
|
||||
9.25778 MB. 11.9323%. Relu
|
||||
7.84641 MB. 10.1132%. Add
|
||||
6.06516 MB. 7.81733%. Div
|
||||
5.12912 MB. 6.61087%. FC
|
||||
77.5861 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
17.6246 MB. 36.4556%. Conv
|
||||
9.25778 MB. 19.1492%. Relu
|
||||
8.43843 MB. 17.4544%. Mul
|
||||
6.95565 MB. 14.3874%. Add
|
||||
6.06502 MB. 12.5452%. Div
|
||||
0.004 MB. 0.00827378%. FC
|
||||
48.3455 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
16.6778 MB. 76.4973%. Conv
|
||||
5.124 MB. 23.5027%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Div
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
21.8018 MB in Total
|
||||
|
||||
```
|
||||
|
||||
## MnasNet-A1
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345
|
||||
Time per operator type:
|
||||
24.4656 ms. 79.0905%. Conv
|
||||
4.14958 ms. 13.4144%. SpatialBN
|
||||
1.60598 ms. 5.19169%. Relu
|
||||
0.295219 ms. 0.95436%. Mul
|
||||
0.187609 ms. 0.606486%. FC
|
||||
0.120556 ms. 0.389724%. AveragePool
|
||||
0.09036 ms. 0.292109%. Add
|
||||
0.015727 ms. 0.050841%. Sigmoid
|
||||
0.00306205 ms. 0.00989875%. Squeeze
|
||||
30.9337 ms in Total
|
||||
FLOP per operator type:
|
||||
0.620598 GFLOP. 95.6434%. Conv
|
||||
0.0248873 GFLOP. 3.8355%. SpatialBN
|
||||
0.002561 GFLOP. 0.394688%. FC
|
||||
0.000597408 GFLOP. 0.0920695%. Mul
|
||||
0.000222656 GFLOP. 0.0343146%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.648867 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
35.5457 MB. 38.4109%. Conv
|
||||
25.1552 MB. 27.1829%. SpatialBN
|
||||
22.5235 MB. 24.339%. Relu
|
||||
5.12912 MB. 5.54256%. FC
|
||||
2.40586 MB. 2.59978%. Mul
|
||||
1.78125 MB. 1.92483%. Add
|
||||
92.5406 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
24.9042 MB. 32.9424%. Conv
|
||||
24.8873 MB. 32.92%. SpatialBN
|
||||
22.5235 MB. 29.7932%. Relu
|
||||
2.38963 MB. 3.16092%. Mul
|
||||
0.890624 MB. 1.17809%. Add
|
||||
0.004 MB. 0.00529106%. FC
|
||||
75.5993 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
10.2732 MB. 66.1459%. Conv
|
||||
5.124 MB. 32.9917%. FC
|
||||
0.133952 MB. 0.86247%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
15.5312 MB in Total
|
||||
```
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597
|
||||
Time per operator type:
|
||||
22.0547 ms. 91.1375%. Conv
|
||||
1.49096 ms. 6.16116%. Relu
|
||||
0.253417 ms. 1.0472%. Mul
|
||||
0.18506 ms. 0.76473%. FC
|
||||
0.112942 ms. 0.466717%. AveragePool
|
||||
0.086769 ms. 0.358559%. Add
|
||||
0.0127889 ms. 0.0528479%. Sigmoid
|
||||
0.0027346 ms. 0.0113003%. Squeeze
|
||||
24.1994 ms in Total
|
||||
FLOP per operator type:
|
||||
0.620598 GFLOP. 99.4581%. Conv
|
||||
0.002561 GFLOP. 0.41043%. FC
|
||||
0.000597408 GFLOP. 0.0957417%. Mul
|
||||
0.000222656 GFLOP. 0.0356832%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.623979 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
35.6127 MB. 52.7968%. Conv
|
||||
22.5235 MB. 33.3917%. Relu
|
||||
5.12912 MB. 7.60406%. FC
|
||||
2.40586 MB. 3.56675%. Mul
|
||||
1.78125 MB. 2.64075%. Add
|
||||
67.4524 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
24.9042 MB. 49.1092%. Conv
|
||||
22.5235 MB. 44.4145%. Relu
|
||||
2.38963 MB. 4.71216%. Mul
|
||||
0.890624 MB. 1.75624%. Add
|
||||
0.004 MB. 0.00788768%. FC
|
||||
50.712 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
10.2732 MB. 66.7213%. Conv
|
||||
5.124 MB. 33.2787%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Mul
|
||||
0 MB. 0%. Relu
|
||||
15.3972 MB in Total
|
||||
```
|
||||
## MnasNet-B1
|
||||
|
||||
### Unoptimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322
|
||||
Time per operator type:
|
||||
29.1121 ms. 83.3081%. Conv
|
||||
4.14959 ms. 11.8746%. SpatialBN
|
||||
1.35823 ms. 3.88675%. Relu
|
||||
0.186188 ms. 0.532802%. FC
|
||||
0.116244 ms. 0.332647%. Add
|
||||
0.018641 ms. 0.0533437%. AveragePool
|
||||
0.0040904 ms. 0.0117052%. Squeeze
|
||||
34.9451 ms in Total
|
||||
FLOP per operator type:
|
||||
0.626272 GFLOP. 96.2088%. Conv
|
||||
0.0218266 GFLOP. 3.35303%. SpatialBN
|
||||
0.002561 GFLOP. 0.393424%. FC
|
||||
0.000291648 GFLOP. 0.0448034%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.650951 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.4354 MB. 41.3788%. Conv
|
||||
22.1299 MB. 26.5921%. SpatialBN
|
||||
19.1923 MB. 23.0622%. Relu
|
||||
5.12912 MB. 6.16333%. FC
|
||||
2.33318 MB. 2.80364%. Add
|
||||
83.2199 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
21.8266 MB. 34.0955%. Conv
|
||||
21.8266 MB. 34.0955%. SpatialBN
|
||||
19.1923 MB. 29.9805%. Relu
|
||||
1.16659 MB. 1.82234%. Add
|
||||
0.004 MB. 0.00624844%. FC
|
||||
64.016 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
12.2576 MB. 69.9104%. Conv
|
||||
5.124 MB. 29.2245%. FC
|
||||
0.15168 MB. 0.865099%. SpatialBN
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Relu
|
||||
17.5332 MB in Total
|
||||
```
|
||||
|
||||
### Optimized
|
||||
```
|
||||
Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426
|
||||
Time per operator type:
|
||||
24.9888 ms. 94.0962%. Conv
|
||||
1.26147 ms. 4.75011%. Relu
|
||||
0.176234 ms. 0.663619%. FC
|
||||
0.113309 ms. 0.426672%. Add
|
||||
0.0138708 ms. 0.0522311%. AveragePool
|
||||
0.00295685 ms. 0.0111341%. Squeeze
|
||||
26.5566 ms in Total
|
||||
FLOP per operator type:
|
||||
0.626272 GFLOP. 99.5466%. Conv
|
||||
0.002561 GFLOP. 0.407074%. FC
|
||||
0.000291648 GFLOP. 0.0463578%. Add
|
||||
0 GFLOP. 0%. Relu
|
||||
0.629124 GFLOP in Total
|
||||
Feature Memory Read per operator type:
|
||||
34.5112 MB. 56.4224%. Conv
|
||||
19.1923 MB. 31.3775%. Relu
|
||||
5.12912 MB. 8.3856%. FC
|
||||
2.33318 MB. 3.81452%. Add
|
||||
61.1658 MB in Total
|
||||
Feature Memory Written per operator type:
|
||||
21.8266 MB. 51.7346%. Conv
|
||||
19.1923 MB. 45.4908%. Relu
|
||||
1.16659 MB. 2.76513%. Add
|
||||
0.004 MB. 0.00948104%. FC
|
||||
42.1895 MB in Total
|
||||
Parameter Memory per operator type:
|
||||
12.2576 MB. 70.5205%. Conv
|
||||
5.124 MB. 29.4795%. FC
|
||||
0 MB. 0%. Add
|
||||
0 MB. 0%. Relu
|
||||
17.3816 MB in Total
|
||||
```
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,323 @@
|
||||
# (Generic) EfficientNets for PyTorch
|
||||
|
||||
A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.
|
||||
|
||||
All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py))
|
||||
|
||||
## What's New
|
||||
|
||||
### Aug 19, 2020
|
||||
* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
|
||||
* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
|
||||
* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
|
||||
* ONNX runtime based validation script added
|
||||
* activations (mostly) brought in sync with `timm` equivalents
|
||||
|
||||
|
||||
### April 5, 2020
|
||||
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
|
||||
* 3.5M param MobileNet-V2 100 @ 73%
|
||||
* 4.5M param MobileNet-V2 110d @ 75%
|
||||
* 6.1M param MobileNet-V2 140 @ 76.5%
|
||||
* 5.8M param MobileNet-V2 120d @ 77.3%
|
||||
|
||||
### March 23, 2020
|
||||
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
||||
* Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
|
||||
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
|
||||
|
||||
### Feb 12, 2020
|
||||
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
|
||||
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
|
||||
* Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)
|
||||
|
||||
### Jan 22, 2020
|
||||
* Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models)
|
||||
* Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
|
||||
* Test models, torchscript, onnx export with PyTorch 1.4 -- no issues
|
||||
|
||||
### Nov 22, 2019
|
||||
* New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different
|
||||
preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.
|
||||
|
||||
### Nov 15, 2019
|
||||
* Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
|
||||
* Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine
|
||||
|
||||
### Oct 30, 2019
|
||||
* Many of the models will now work with torch.jit.script, MixNet being the biggest exception
|
||||
* Improved interface for enabling torchscript or ONNX export compatible modes (via config)
|
||||
* Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn
|
||||
* Activation factory to select best version of activation by name or override one globally
|
||||
* Add pretrained checkpoint load helper that handles input conv and classifier changes
|
||||
|
||||
### Oct 27, 2019
|
||||
* Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
||||
* Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
||||
* Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
|
||||
* Switch activations and global pooling to modules
|
||||
* Add memory-efficient Swish/Mish impl
|
||||
* Add as_sequential() method to all models and allow as an argument in entrypoint fns
|
||||
* Move MobileNetV3 into own file since it has a different head
|
||||
* Remove ChamNet, MobileNet V2/V1 since they will likely never be used here
|
||||
|
||||
## Models
|
||||
|
||||
Implemented models include:
|
||||
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
|
||||
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
|
||||
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
|
||||
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
|
||||
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
|
||||
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
||||
* MixNet (https://arxiv.org/abs/1907.09595)
|
||||
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
|
||||
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
|
||||
* FBNet-C (https://arxiv.org/abs/1812.03443)
|
||||
* Single-Path NAS (https://arxiv.org/abs/1904.02877)
|
||||
|
||||
I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.
|
||||
|
||||
## Pretrained
|
||||
|
||||
I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models
|
||||
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop |
|
||||
|---|---|---|---|---|---|---|---|
|
||||
| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 |
|
||||
| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 |
|
||||
| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 |
|
||||
| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 |
|
||||
| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 |
|
||||
| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 |
|
||||
| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 |
|
||||
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
|
||||
| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 |
|
||||
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
|
||||
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 |
|
||||
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
|
||||
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 |
|
||||
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
|
||||
| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 |
|
||||
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
|
||||
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
|
||||
| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 |
|
||||
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 |
|
||||
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 |
|
||||
| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 |
|
||||
|
||||
|
||||
More pretrained models to come...
|
||||
|
||||
|
||||
## Ported Weights
|
||||
|
||||
The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
|
||||
|
||||
**IMPORTANT:**
|
||||
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
|
||||
* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
|
||||
|
||||
To run validation for tf_efficientnet_b5:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic`
|
||||
|
||||
To run validation w/ TF preprocessing for tf_efficientnet_b5:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing`
|
||||
|
||||
To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp:
|
||||
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5`
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop |
|
||||
|---|---|---|---|---|---|---|
|
||||
| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A |
|
||||
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 |
|
||||
| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 |
|
||||
| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A |
|
||||
| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A |
|
||||
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A |
|
||||
| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 |
|
||||
| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 |
|
||||
| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 |
|
||||
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A |
|
||||
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 |
|
||||
| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 |
|
||||
| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 |
|
||||
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A |
|
||||
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 |
|
||||
| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 |
|
||||
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A |
|
||||
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 |
|
||||
| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 |
|
||||
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A |
|
||||
| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 |
|
||||
| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
|
||||
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
|
||||
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
|
||||
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
|
||||
| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
|
||||
| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
|
||||
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
|
||||
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
|
||||
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
|
||||
| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A |
|
||||
| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A |
|
||||
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 |
|
||||
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 |
|
||||
| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
|
||||
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
|
||||
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
|
||||
| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
|
||||
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
|
||||
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
|
||||
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
|
||||
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
|
||||
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
|
||||
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
|
||||
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A |
|
||||
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 |
|
||||
| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A |
|
||||
| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 |
|
||||
|
||||
|
||||
*tfp models validated with `tf-preprocessing` pipeline
|
||||
|
||||
Google tf and tflite weights ported from official Tensorflow repositories
|
||||
* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
||||
* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
||||
* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment
|
||||
|
||||
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.
|
||||
|
||||
Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.
|
||||
|
||||
PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.
|
||||
|
||||
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
|
||||
```
|
||||
conda create -n torch-env
|
||||
conda activate torch-env
|
||||
conda install -c pytorch pytorch torchvision cudatoolkit=10.2
|
||||
```
|
||||
|
||||
### PyTorch Hub
|
||||
|
||||
Models can be accessed via the PyTorch Hub API
|
||||
|
||||
```
|
||||
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
|
||||
['efficientnet_b0', ...]
|
||||
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
|
||||
>>> model.eval()
|
||||
>>> output = model(torch.randn(1,3,224,224))
|
||||
```
|
||||
|
||||
### Pip
|
||||
This package can be installed via pip.
|
||||
|
||||
Install (after conda env/install):
|
||||
```
|
||||
pip install geffnet
|
||||
```
|
||||
|
||||
Eval use:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
|
||||
>>> m.eval()
|
||||
```
|
||||
|
||||
Train use:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> # models can also be created by using the entrypoint directly
|
||||
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
|
||||
>>> m.train()
|
||||
```
|
||||
|
||||
Create in a nn.Sequential container, for fast.ai, etc:
|
||||
```
|
||||
>>> import geffnet
|
||||
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
|
||||
```
|
||||
|
||||
### Exporting
|
||||
|
||||
Scripts are included to
|
||||
* export models to ONNX (`onnx_export.py`)
|
||||
* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
|
||||
* validate with ONNX runtime (`onnx_validate.py`)
|
||||
* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
|
||||
* validate in Caffe2 (`caffe2_validate.py`)
|
||||
* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)
|
||||
|
||||
As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
|
||||
```
|
||||
python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
|
||||
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
|
||||
```
|
||||
|
||||
These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible
|
||||
export now requires additional args mentioned in the export script (not needed in earlier versions).
|
||||
|
||||
#### Export Notes
|
||||
1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
|
||||
2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
|
||||
3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization.
|
||||
3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
""" Caffe2 validation script
|
||||
|
||||
This script runs Caffe2 benchmark on exported ONNX model.
|
||||
It is a useful tool for reporting model FLOPS.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
|
||||
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
||||
help='caffe2 model pb name prefix')
|
||||
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model init .pb')
|
||||
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model predict .pb')
|
||||
parser.add_argument('-b', '--batch-size', default=1, type=int,
|
||||
metavar='N', help='mini-batch size (default: 1)')
|
||||
parser.add_argument('--img-size', default=224, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
if args.c2_prefix:
|
||||
args.c2_init = args.c2_prefix + '.init.pb'
|
||||
args.c2_predict = args.c2_prefix + '.predict.pb'
|
||||
|
||||
model = model_helper.ModelHelper(name="le_net", init_params=False)
|
||||
|
||||
# Bring in the init net from init_net.pb
|
||||
init_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_init, "rb") as f:
|
||||
init_net_proto.ParseFromString(f.read())
|
||||
model.param_init_net = core.Net(init_net_proto)
|
||||
|
||||
# bring in the predict net from predict_net.pb
|
||||
predict_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_predict, "rb") as f:
|
||||
predict_net_proto.ParseFromString(f.read())
|
||||
model.net = core.Net(predict_net_proto)
|
||||
|
||||
# CUDA performance not impressive
|
||||
#device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
||||
#model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
#model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
|
||||
input_blob = model.net.external_inputs[0]
|
||||
model.param_init_net.GaussianFill(
|
||||
[],
|
||||
input_blob.GetUnscopedName(),
|
||||
shape=(args.batch_size, 3, args.img_size, args.img_size),
|
||||
mean=0.0,
|
||||
std=1.0)
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
workspace.CreateNet(model.net, overwrite=True)
|
||||
workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,138 @@
|
||||
""" Caffe2 validation script
|
||||
|
||||
This script is created to verify exported ONNX models running in Caffe2
|
||||
It utilizes the same PyTorch dataloader/processing pipeline for a
|
||||
fair comparison against the originals.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from data import create_loader, resolve_data_config, Dataset
|
||||
from utils import AverageMeter
|
||||
import time
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
||||
help='caffe2 model pb name prefix')
|
||||
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model init .pb')
|
||||
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
||||
help='caffe2 model predict .pb')
|
||||
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
if args.c2_prefix:
|
||||
args.c2_init = args.c2_prefix + '.init.pb'
|
||||
args.c2_predict = args.c2_prefix + '.predict.pb'
|
||||
|
||||
model = model_helper.ModelHelper(name="validation_net", init_params=False)
|
||||
|
||||
# Bring in the init net from init_net.pb
|
||||
init_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_init, "rb") as f:
|
||||
init_net_proto.ParseFromString(f.read())
|
||||
model.param_init_net = core.Net(init_net_proto)
|
||||
|
||||
# bring in the predict net from predict_net.pb
|
||||
predict_net_proto = caffe2_pb2.NetDef()
|
||||
with open(args.c2_predict, "rb") as f:
|
||||
predict_net_proto.ParseFromString(f.read())
|
||||
model.net = core.Net(predict_net_proto)
|
||||
|
||||
data_config = resolve_data_config(None, args)
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
# this is so obvious, wonderful interface </sarcasm>
|
||||
input_blob = model.net.external_inputs[0]
|
||||
output_blob = model.net.external_outputs[0]
|
||||
|
||||
if True:
|
||||
device_opts = None
|
||||
else:
|
||||
# CUDA is crashing, no idea why, awesome error message, give it a try for kicks
|
||||
device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
||||
model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
||||
|
||||
model.param_init_net.GaussianFill(
|
||||
[], input_blob.GetUnscopedName(),
|
||||
shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
workspace.CreateNet(model.net, overwrite=True)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
end = time.time()
|
||||
for i, (input, target) in enumerate(loader):
|
||||
# run the net and return prediction
|
||||
caffe2_in = input.data.numpy()
|
||||
workspace.FeedBlob(input_blob, caffe2_in, device_opts)
|
||||
workspace.RunNet(model.net, num_iter=1)
|
||||
output = workspace.FetchBlob(output_blob)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy_np(output.data, target.numpy())
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
|
||||
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
def accuracy_np(output, target):
|
||||
max_indices = np.argsort(output, axis=1)[:, ::-1]
|
||||
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
|
||||
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
|
||||
return top1, top5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,5 @@
|
||||
from .gen_efficientnet import *
|
||||
from .mobilenetv3 import *
|
||||
from .model_factory import create_model
|
||||
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
|
||||
from .activations import *
|
||||
@@ -0,0 +1,137 @@
|
||||
from geffnet import config
|
||||
from geffnet.activations.activations_me import *
|
||||
from geffnet.activations.activations_jit import *
|
||||
from geffnet.activations.activations import *
|
||||
import torch
|
||||
|
||||
_has_silu = 'silu' in dir(torch.nn.functional)
|
||||
|
||||
_ACT_FN_DEFAULT = dict(
|
||||
silu=F.silu if _has_silu else swish,
|
||||
swish=F.silu if _has_silu else swish,
|
||||
mish=mish,
|
||||
relu=F.relu,
|
||||
relu6=F.relu6,
|
||||
sigmoid=sigmoid,
|
||||
tanh=tanh,
|
||||
hard_sigmoid=hard_sigmoid,
|
||||
hard_swish=hard_swish,
|
||||
)
|
||||
|
||||
_ACT_FN_JIT = dict(
|
||||
silu=F.silu if _has_silu else swish_jit,
|
||||
swish=F.silu if _has_silu else swish_jit,
|
||||
mish=mish_jit,
|
||||
)
|
||||
|
||||
_ACT_FN_ME = dict(
|
||||
silu=F.silu if _has_silu else swish_me,
|
||||
swish=F.silu if _has_silu else swish_me,
|
||||
mish=mish_me,
|
||||
hard_swish=hard_swish_me,
|
||||
hard_sigmoid_jit=hard_sigmoid_me,
|
||||
)
|
||||
|
||||
_ACT_LAYER_DEFAULT = dict(
|
||||
silu=nn.SiLU if _has_silu else Swish,
|
||||
swish=nn.SiLU if _has_silu else Swish,
|
||||
mish=Mish,
|
||||
relu=nn.ReLU,
|
||||
relu6=nn.ReLU6,
|
||||
sigmoid=Sigmoid,
|
||||
tanh=Tanh,
|
||||
hard_sigmoid=HardSigmoid,
|
||||
hard_swish=HardSwish,
|
||||
)
|
||||
|
||||
_ACT_LAYER_JIT = dict(
|
||||
silu=nn.SiLU if _has_silu else SwishJit,
|
||||
swish=nn.SiLU if _has_silu else SwishJit,
|
||||
mish=MishJit,
|
||||
)
|
||||
|
||||
_ACT_LAYER_ME = dict(
|
||||
silu=nn.SiLU if _has_silu else SwishMe,
|
||||
swish=nn.SiLU if _has_silu else SwishMe,
|
||||
mish=MishMe,
|
||||
hard_swish=HardSwishMe,
|
||||
hard_sigmoid=HardSigmoidMe
|
||||
)
|
||||
|
||||
_OVERRIDE_FN = dict()
|
||||
_OVERRIDE_LAYER = dict()
|
||||
|
||||
|
||||
def add_override_act_fn(name, fn):
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN[name] = fn
|
||||
|
||||
|
||||
def update_override_act_fn(overrides):
|
||||
assert isinstance(overrides, dict)
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN.update(overrides)
|
||||
|
||||
|
||||
def clear_override_act_fn():
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN = dict()
|
||||
|
||||
|
||||
def add_override_act_layer(name, fn):
|
||||
_OVERRIDE_LAYER[name] = fn
|
||||
|
||||
|
||||
def update_override_act_layer(overrides):
|
||||
assert isinstance(overrides, dict)
|
||||
global _OVERRIDE_LAYER
|
||||
_OVERRIDE_LAYER.update(overrides)
|
||||
|
||||
|
||||
def clear_override_act_layer():
|
||||
global _OVERRIDE_LAYER
|
||||
_OVERRIDE_LAYER = dict()
|
||||
|
||||
|
||||
def get_act_fn(name='relu'):
|
||||
""" Activation Function Factory
|
||||
Fetching activation fns by name with this function allows export or torch script friendly
|
||||
functions to be returned dynamically based on current config.
|
||||
"""
|
||||
if name in _OVERRIDE_FN:
|
||||
return _OVERRIDE_FN[name]
|
||||
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
||||
if use_me and name in _ACT_FN_ME:
|
||||
# If not exporting or scripting the model, first look for a memory optimized version
|
||||
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
|
||||
return _ACT_FN_ME[name]
|
||||
if config.is_exportable() and name in ('silu', 'swish'):
|
||||
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
||||
return swish
|
||||
use_jit = not (config.is_exportable() or config.is_no_jit())
|
||||
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
||||
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
||||
return _ACT_FN_JIT[name]
|
||||
return _ACT_FN_DEFAULT[name]
|
||||
|
||||
|
||||
def get_act_layer(name='relu'):
|
||||
""" Activation Layer Factory
|
||||
Fetching activation layers by name with this function allows export or torch script friendly
|
||||
functions to be returned dynamically based on current config.
|
||||
"""
|
||||
if name in _OVERRIDE_LAYER:
|
||||
return _OVERRIDE_LAYER[name]
|
||||
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
||||
if use_me and name in _ACT_LAYER_ME:
|
||||
return _ACT_LAYER_ME[name]
|
||||
if config.is_exportable() and name in ('silu', 'swish'):
|
||||
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
||||
return Swish
|
||||
use_jit = not (config.is_exportable() or config.is_no_jit())
|
||||
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
||||
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
||||
return _ACT_LAYER_JIT[name]
|
||||
return _ACT_LAYER_DEFAULT[name]
|
||||
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
""" Activations
|
||||
|
||||
A collection of activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def swish(x, inplace: bool = False):
|
||||
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return swish(x, self.inplace)
|
||||
|
||||
|
||||
def mish(x, inplace: bool = False):
|
||||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
"""
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class Mish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Mish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return mish(x, self.inplace)
|
||||
|
||||
|
||||
def sigmoid(x, inplace: bool = False):
|
||||
return x.sigmoid_() if inplace else x.sigmoid()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Sigmoid(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Sigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.sigmoid_() if self.inplace else x.sigmoid()
|
||||
|
||||
|
||||
def tanh(x, inplace: bool = False):
|
||||
return x.tanh_() if inplace else x.tanh()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Tanh(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(Tanh, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.tanh_() if self.inplace else x.tanh()
|
||||
|
||||
|
||||
def hard_swish(x, inplace: bool = False):
|
||||
inner = F.relu6(x + 3.).div_(6.)
|
||||
return x.mul_(inner) if inplace else x.mul(inner)
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_swish(x, self.inplace)
|
||||
|
||||
|
||||
def hard_sigmoid(x, inplace: bool = False):
|
||||
if inplace:
|
||||
return x.add_(3.).clamp_(0., 6.).div_(6.)
|
||||
else:
|
||||
return F.relu6(x + 3.) / 6.
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_sigmoid(x, self.inplace)
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
""" Activations (jit)
|
||||
|
||||
A collection of jit-scripted activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
|
||||
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
|
||||
versions if they contain in-place ops.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
|
||||
'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit(x, inplace: bool = False):
|
||||
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
return x.mul(x.sigmoid())
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit(x, _inplace: bool = False):
|
||||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
"""
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class SwishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(SwishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return swish_jit(x)
|
||||
|
||||
|
||||
class MishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(MishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return mish_jit(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit(x, inplace: bool = False):
|
||||
# return F.relu6(x + 3.) / 6.
|
||||
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||
|
||||
|
||||
class HardSigmoidJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoidJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return hard_sigmoid_jit(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit(x, inplace: bool = False):
|
||||
# return x * (F.relu6(x + 3.) / 6)
|
||||
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||
|
||||
|
||||
class HardSwishJit(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwishJit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return hard_swish_jit(x)
|
||||
@@ -0,0 +1,174 @@
|
||||
""" Activations (memory-efficient w/ custom autograd)
|
||||
|
||||
A collection of activations fn and modules with a common interface so that they can
|
||||
easily be swapped. All have an `inplace` arg even if not used.
|
||||
|
||||
These activations are not compatible with jit scripting or ONNX export of the model, please use either
|
||||
the JIT or basic versions of the activations.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
|
||||
'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_fwd(x):
|
||||
return x.mul(torch.sigmoid(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
||||
|
||||
|
||||
class SwishJitAutoFn(torch.autograd.Function):
|
||||
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
|
||||
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
||||
https://twitter.com/jeremyphoward/status/1188251041835315200
|
||||
|
||||
Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
||||
and also as Swish (https://arxiv.org/abs/1710.05941).
|
||||
|
||||
TODO Rename to SiLU with addition to PyTorch
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def swish_me(x, inplace=False):
|
||||
return SwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class SwishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(SwishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return SwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_fwd(x):
|
||||
return x.mul(torch.tanh(F.softplus(x)))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
x_tanh_sp = F.softplus(x).tanh()
|
||||
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
|
||||
|
||||
|
||||
class MishJitAutoFn(torch.autograd.Function):
|
||||
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
A memory efficient, jit scripted variant of Mish
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return mish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return mish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def mish_me(x, inplace=False):
|
||||
return MishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class MishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(MishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return MishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
|
||||
return (x + 3).clamp(min=0, max=6).div(6.)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_sigmoid_jit_bwd(x, grad_output):
|
||||
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
|
||||
return grad_output * m
|
||||
|
||||
|
||||
class HardSigmoidJitAutoFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return hard_sigmoid_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return hard_sigmoid_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def hard_sigmoid_me(x, inplace: bool = False):
|
||||
return HardSigmoidJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class HardSigmoidMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSigmoidMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return HardSigmoidJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit_fwd(x):
|
||||
return x * (x + 3).clamp(min=0, max=6).div(6.)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def hard_swish_jit_bwd(x, grad_output):
|
||||
m = torch.ones_like(x) * (x >= 3.)
|
||||
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
|
||||
return grad_output * m
|
||||
|
||||
|
||||
class HardSwishJitAutoFn(torch.autograd.Function):
|
||||
"""A memory efficient, jit-scripted HardSwish activation"""
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return hard_swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return hard_swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def hard_swish_me(x, inplace=False):
|
||||
return HardSwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
class HardSwishMe(nn.Module):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(HardSwishMe, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return HardSwishJitAutoFn.apply(x)
|
||||
@@ -0,0 +1,123 @@
|
||||
""" Global layer config state
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
|
||||
__all__ = [
|
||||
'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
|
||||
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
|
||||
]
|
||||
|
||||
# Set to True if prefer to have layers with no jit optimization (includes activations)
|
||||
_NO_JIT = False
|
||||
|
||||
# Set to True if prefer to have activation layers with no jit optimization
|
||||
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
|
||||
# the jit flags so far are activations. This will change as more layers are updated and/or added.
|
||||
_NO_ACTIVATION_JIT = False
|
||||
|
||||
# Set to True if exporting a model with Same padding via ONNX
|
||||
_EXPORTABLE = False
|
||||
|
||||
# Set to True if wanting to use torch.jit.script on a model
|
||||
_SCRIPTABLE = False
|
||||
|
||||
|
||||
def is_no_jit():
|
||||
return _NO_JIT
|
||||
|
||||
|
||||
class set_no_jit:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _NO_JIT
|
||||
self.prev = _NO_JIT
|
||||
_NO_JIT = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _NO_JIT
|
||||
_NO_JIT = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def is_exportable():
|
||||
return _EXPORTABLE
|
||||
|
||||
|
||||
class set_exportable:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _EXPORTABLE
|
||||
self.prev = _EXPORTABLE
|
||||
_EXPORTABLE = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _EXPORTABLE
|
||||
_EXPORTABLE = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def is_scriptable():
|
||||
return _SCRIPTABLE
|
||||
|
||||
|
||||
class set_scriptable:
|
||||
def __init__(self, mode: bool) -> None:
|
||||
global _SCRIPTABLE
|
||||
self.prev = _SCRIPTABLE
|
||||
_SCRIPTABLE = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _SCRIPTABLE
|
||||
_SCRIPTABLE = self.prev
|
||||
return False
|
||||
|
||||
|
||||
class set_layer_config:
|
||||
""" Layer config context manager that allows setting all layer config flags at once.
|
||||
If a flag arg is None, it will not change the current value.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
no_activation_jit: Optional[bool] = None):
|
||||
global _SCRIPTABLE
|
||||
global _EXPORTABLE
|
||||
global _NO_JIT
|
||||
global _NO_ACTIVATION_JIT
|
||||
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
|
||||
if scriptable is not None:
|
||||
_SCRIPTABLE = scriptable
|
||||
if exportable is not None:
|
||||
_EXPORTABLE = exportable
|
||||
if no_jit is not None:
|
||||
_NO_JIT = no_jit
|
||||
if no_activation_jit is not None:
|
||||
_NO_ACTIVATION_JIT = no_activation_jit
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
global _SCRIPTABLE
|
||||
global _EXPORTABLE
|
||||
global _NO_JIT
|
||||
global _NO_ACTIVATION_JIT
|
||||
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
|
||||
return False
|
||||
|
||||
|
||||
def layer_config_kwargs(kwargs):
|
||||
""" Consume config kwargs and return contextmgr obj """
|
||||
return set_layer_config(
|
||||
scriptable=kwargs.pop('scriptable', None),
|
||||
exportable=kwargs.pop('exportable', None),
|
||||
no_jit=kwargs.pop('no_jit', None))
|
||||
@@ -0,0 +1,304 @@
|
||||
""" Conv2D w/ SAME padding, CondConv, MixedConv
|
||||
|
||||
A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
|
||||
MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import collections.abc
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import repeat
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .config import *
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
_triple = _ntuple(3)
|
||||
_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
||||
def _calc_same_pad(i: int, k: int, s: int, d: int):
|
||||
return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
|
||||
def _same_pad_arg(input_size, kernel_size, stride, dilation):
|
||||
ih, iw = input_size
|
||||
kh, kw = kernel_size
|
||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
||||
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
|
||||
split = [num_chan // num_groups for _ in range(num_groups)]
|
||||
split[0] += num_chan - sum(split)
|
||||
return split
|
||||
|
||||
|
||||
def conv2d_same(
|
||||
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
|
||||
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = weight.size()[-2:]
|
||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
|
||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSame, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class Conv2dSameExport(nn.Conv2d):
|
||||
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
|
||||
NOTE: This does not currently work with torch.jit.script
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSameExport, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||
self.pad = None
|
||||
self.pad_input_size = (0, 0)
|
||||
|
||||
def forward(self, x):
|
||||
input_size = x.size()[-2:]
|
||||
if self.pad is None:
|
||||
pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
|
||||
self.pad = nn.ZeroPad2d(pad_arg)
|
||||
self.pad_input_size = input_size
|
||||
|
||||
if self.pad is not None:
|
||||
x = self.pad(x)
|
||||
return F.conv2d(
|
||||
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs):
|
||||
dynamic = False
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if _is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
else:
|
||||
# dynamic padding
|
||||
padding = 0
|
||||
dynamic = True
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
padding = 0
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
return padding, dynamic
|
||||
|
||||
|
||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', '')
|
||||
kwargs.setdefault('bias', False)
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
||||
if is_dynamic:
|
||||
if is_exportable():
|
||||
assert not is_scriptable()
|
||||
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.ModuleDict):
|
||||
""" Mixed Grouped Convolution
|
||||
Based on MDConv and GroupedConv in MixNet impl:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
||||
super(MixedConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
||||
num_groups = len(kernel_size)
|
||||
in_splits = _split_channels(in_channels, num_groups)
|
||||
out_splits = _split_channels(out_channels, num_groups)
|
||||
self.in_channels = sum(in_splits)
|
||||
self.out_channels = sum(out_splits)
|
||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||
conv_groups = out_ch if depthwise else 1
|
||||
self.add_module(
|
||||
str(idx),
|
||||
create_conv2d_pad(
|
||||
in_ch, out_ch, k, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
||||
)
|
||||
self.splits = in_splits
|
||||
|
||||
def forward(self, x):
|
||||
x_split = torch.split(x, self.splits, 1)
|
||||
x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
|
||||
x = torch.cat(x_out, 1)
|
||||
return x
|
||||
|
||||
|
||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
||||
def condconv_initializer(weight):
|
||||
"""CondConv initializer function."""
|
||||
num_params = np.prod(expert_shape)
|
||||
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
||||
weight.shape[1] != num_params):
|
||||
raise (ValueError(
|
||||
'CondConv variables must have shape [num_experts, num_params]'))
|
||||
for i in range(num_experts):
|
||||
initializer(weight[i].view(expert_shape))
|
||||
return condconv_initializer
|
||||
|
||||
|
||||
class CondConv2d(nn.Module):
|
||||
""" Conditional Convolution
|
||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||
|
||||
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||
https://github.com/pytorch/pytorch/issues/17983
|
||||
"""
|
||||
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
||||
super(CondConv2d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
self.stride = _pair(stride)
|
||||
padding_val, is_padding_dynamic = get_padding_value(
|
||||
padding, kernel_size, stride=stride, dilation=dilation)
|
||||
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
||||
self.padding = _pair(padding_val)
|
||||
self.dilation = _pair(dilation)
|
||||
self.groups = groups
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight_num_param = 1
|
||||
for wd in self.weight_shape:
|
||||
weight_num_param *= wd
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
||||
|
||||
if bias:
|
||||
self.bias_shape = (self.out_channels,)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_weight = get_condconv_initializer(
|
||||
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
||||
init_weight(self.weight)
|
||||
if self.bias is not None:
|
||||
fan_in = np.prod(self.weight_shape[1:])
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
init_bias = get_condconv_initializer(
|
||||
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
||||
init_bias(self.bias)
|
||||
|
||||
def forward(self, x, routing_weights):
|
||||
B, C, H, W = x.shape
|
||||
weight = torch.matmul(routing_weights, self.weight)
|
||||
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight = weight.view(new_weight_shape)
|
||||
bias = None
|
||||
if self.bias is not None:
|
||||
bias = torch.matmul(routing_weights, self.bias)
|
||||
bias = bias.view(B * self.out_channels)
|
||||
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
||||
x = x.view(1, B * C, H, W)
|
||||
if self.dynamic_padding:
|
||||
out = conv2d_same(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
else:
|
||||
out = F.conv2d(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
||||
|
||||
# Literal port (from TF definition)
|
||||
# x = torch.split(x, 1, 0)
|
||||
# weight = torch.split(weight, 1, 0)
|
||||
# if self.bias is not None:
|
||||
# bias = torch.matmul(routing_weights, self.bias)
|
||||
# bias = torch.split(bias, 1, 0)
|
||||
# else:
|
||||
# bias = [None] * B
|
||||
# out = []
|
||||
# for xi, wi, bi in zip(x, weight, bias):
|
||||
# wi = wi.view(*self.weight_shape)
|
||||
# if bi is not None:
|
||||
# bi = bi.view(*self.bias_shape)
|
||||
# out.append(self.conv_fn(
|
||||
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
||||
# dilation=self.dilation, groups=self.groups))
|
||||
# out = torch.cat(out, 0)
|
||||
return out
|
||||
|
||||
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||
if isinstance(kernel_size, list):
|
||||
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
||||
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
||||
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
||||
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
depthwise = kwargs.pop('depthwise', False)
|
||||
groups = out_chs if depthwise else 1
|
||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
else:
|
||||
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
return m
|
||||
@@ -0,0 +1,683 @@
|
||||
""" EfficientNet / MobileNetV3 Blocks and Builder
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from .conv2d_layers import *
|
||||
from geffnet.activations import *
|
||||
|
||||
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
|
||||
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
|
||||
'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
|
||||
'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
|
||||
]
|
||||
|
||||
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
||||
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
|
||||
# NOTE: momentum varies btw .99 and .9997 depending on source
|
||||
# .99 in official TF TPU impl
|
||||
# .9997 (/w .999 in search space) for paper
|
||||
#
|
||||
# PyTorch defaults are momentum = .1, eps = 1e-5
|
||||
#
|
||||
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
||||
BN_EPS_TF_DEFAULT = 1e-3
|
||||
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
||||
|
||||
|
||||
def get_bn_args_tf():
|
||||
return _BN_ARGS_TF.copy()
|
||||
|
||||
|
||||
def resolve_bn_args(kwargs):
|
||||
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
|
||||
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||
if bn_momentum is not None:
|
||||
bn_args['momentum'] = bn_momentum
|
||||
bn_eps = kwargs.pop('bn_eps', None)
|
||||
if bn_eps is not None:
|
||||
bn_args['eps'] = bn_eps
|
||||
return bn_args
|
||||
|
||||
|
||||
_SE_ARGS_DEFAULT = dict(
|
||||
gate_fn=sigmoid,
|
||||
act_layer=None, # None == use containing block's activation layer
|
||||
reduce_mid=False,
|
||||
divisor=1)
|
||||
|
||||
|
||||
def resolve_se_args(kwargs, in_chs, act_layer=None):
|
||||
se_kwargs = kwargs.copy() if kwargs is not None else {}
|
||||
# fill in args that aren't specified with the defaults
|
||||
for k, v in _SE_ARGS_DEFAULT.items():
|
||||
se_kwargs.setdefault(k, v)
|
||||
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
||||
if not se_kwargs.pop('reduce_mid'):
|
||||
se_kwargs['reduced_base_chs'] = in_chs
|
||||
# act_layer override, if it remains None, the containing block's act_layer will be used
|
||||
if se_kwargs['act_layer'] is None:
|
||||
assert act_layer is not None
|
||||
se_kwargs['act_layer'] = act_layer
|
||||
return se_kwargs
|
||||
|
||||
|
||||
def resolve_act_layer(kwargs, default='relu'):
|
||||
act_layer = kwargs.pop('act_layer', default)
|
||||
if isinstance(act_layer, str):
|
||||
act_layer = get_act_layer(act_layer)
|
||||
return act_layer
|
||||
|
||||
|
||||
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
|
||||
min_value = min_value or divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if not multiplier:
|
||||
return channels
|
||||
channels *= multiplier
|
||||
return make_divisible(channels, divisor, channel_min)
|
||||
|
||||
|
||||
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
|
||||
"""Apply drop connect."""
|
||||
if not training:
|
||||
return inputs
|
||||
|
||||
keep_prob = 1 - drop_connect_rate
|
||||
random_tensor = keep_prob + torch.rand(
|
||||
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = inputs.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
self.gate_fn = gate_fn
|
||||
|
||||
def forward(self, x):
|
||||
x_se = x.mean((2, 3), keepdim=True)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x = x * self.gate_fn(x_se)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBnAct(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(ConvBnAct, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
|
||||
self.bn1 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
""" DepthwiseSeparable block
|
||||
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
|
||||
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
|
||||
"""
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
self.conv_dw = select_conv2d(
|
||||
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
|
||||
self.bn1 = norm_layer(in_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
""" Inverted residual block w/ optional SE"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
conv_kwargs=None, drop_connect_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
mid_chs: int = make_divisible(in_chs * exp_ratio)
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Point-wise expansion
|
||||
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Depth-wise convolution
|
||||
self.conv_dw = select_conv2d(
|
||||
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
|
||||
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity() # for jit.script compat
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn3 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class CondConvResidual(InvertedResidual):
|
||||
""" Inverted residual block w/ CondConv routing"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
num_experts=0, drop_connect_rate=0.):
|
||||
|
||||
self.num_experts = num_experts
|
||||
conv_kwargs = dict(num_experts=self.num_experts)
|
||||
|
||||
super(CondConvResidual, self).__init__(
|
||||
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
|
||||
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
|
||||
drop_connect_rate=drop_connect_rate)
|
||||
|
||||
self.routing_fn = nn.Linear(in_chs, self.num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# CondConv routing
|
||||
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
|
||||
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x, routing_weights)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x, routing_weights)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x, routing_weights)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class EdgeResidual(nn.Module):
|
||||
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
||||
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
super(EdgeResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Expansion convolution
|
||||
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if se_ratio is not None and se_ratio > 0.:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
|
||||
self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Expansion convolution
|
||||
x = self.conv_exp(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EfficientNetBuilder:
|
||||
""" Build Trunk Blocks for Efficient/Mobile Networks
|
||||
|
||||
This ended up being somewhat of a cross between
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
||||
and
|
||||
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
pad_type='', act_layer=None, se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.channel_divisor = channel_divisor
|
||||
self.channel_min = channel_min
|
||||
self.pad_type = pad_type
|
||||
self.act_layer = act_layer
|
||||
self.se_kwargs = se_kwargs
|
||||
self.norm_layer = norm_layer
|
||||
self.norm_kwargs = norm_kwargs
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# updated during build
|
||||
self.in_chs = None
|
||||
self.block_idx = 0
|
||||
self.block_count = 0
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
|
||||
def _make_block(self, ba):
|
||||
bt = ba.pop('block_type')
|
||||
ba['in_chs'] = self.in_chs
|
||||
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
||||
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
||||
# FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
|
||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
||||
ba['norm_layer'] = self.norm_layer
|
||||
ba['norm_kwargs'] = self.norm_kwargs
|
||||
ba['pad_type'] = self.pad_type
|
||||
# block act fn overrides the model default
|
||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
||||
assert ba['act_layer'] is not None
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if ba.get('num_experts', 0) > 0:
|
||||
block = CondConvResidual(**ba)
|
||||
else:
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt == 'ds' or bt == 'dsa':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
elif bt == 'er':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
block = EdgeResidual(**ba)
|
||||
elif bt == 'cn':
|
||||
block = ConvBnAct(**ba)
|
||||
else:
|
||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
||||
return block
|
||||
|
||||
def _make_stack(self, stack_args):
|
||||
blocks = []
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for i, ba in enumerate(stack_args):
|
||||
if i >= 1:
|
||||
# only the first block in any stack can have a stride > 1
|
||||
ba['stride'] = 1
|
||||
block = self._make_block(ba)
|
||||
blocks.append(block)
|
||||
self.block_idx += 1 # incr global idx (across all stacks)
|
||||
return nn.Sequential(*blocks)
|
||||
|
||||
def __call__(self, in_chs, block_args):
|
||||
""" Build the blocks
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
self.in_chs = in_chs
|
||||
self.block_count = sum([len(x) for x in block_args])
|
||||
self.block_idx = 0
|
||||
blocks = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for stack_idx, stack in enumerate(block_args):
|
||||
assert isinstance(stack, list)
|
||||
stack = self._make_stack(stack)
|
||||
blocks.append(stack)
|
||||
return blocks
|
||||
|
||||
|
||||
def _parse_ksize(ss):
|
||||
if ss.isdigit():
|
||||
return int(ss)
|
||||
else:
|
||||
return [int(k) for k in ss.split('.')]
|
||||
|
||||
|
||||
def _decode_block_str(block_str):
|
||||
""" Decode block definition string
|
||||
|
||||
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||
|
||||
All args can exist in any order with the exception of the leading string which
|
||||
is assumed to indicate the block type.
|
||||
|
||||
leading string - block type (
|
||||
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
||||
r - number of repeat blocks,
|
||||
k - kernel size,
|
||||
s - strides (1-9),
|
||||
e - expansion ratio,
|
||||
c - output channels,
|
||||
se - squeeze/excitation ratio
|
||||
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||
Args:
|
||||
block_str: a string representation of block arguments.
|
||||
Returns:
|
||||
A list of block args (dicts)
|
||||
Raises:
|
||||
ValueError: if the string def not properly specified (TODO)
|
||||
"""
|
||||
assert isinstance(block_str, str)
|
||||
ops = block_str.split('_')
|
||||
block_type = ops[0] # take the block type off the front
|
||||
ops = ops[1:]
|
||||
options = {}
|
||||
noskip = False
|
||||
for op in ops:
|
||||
# string options being checked on individual basis, combine if they grow
|
||||
if op == 'noskip':
|
||||
noskip = True
|
||||
elif op.startswith('n'):
|
||||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
value = get_act_layer('relu')
|
||||
elif v == 'r6':
|
||||
value = get_act_layer('relu6')
|
||||
elif v == 'hs':
|
||||
value = get_act_layer('hard_swish')
|
||||
elif v == 'sw':
|
||||
value = get_act_layer('swish')
|
||||
else:
|
||||
continue
|
||||
options[key] = value
|
||||
else:
|
||||
# all numeric options
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
# if act_layer is None, the model default (passed to model init) will be used
|
||||
act_layer = options['n'] if 'n' in options else None
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
if block_type == 'ir':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
if 'cc' in options:
|
||||
block_args['num_experts'] = int(options['cc'])
|
||||
elif block_type == 'ds' or block_type == 'dsa':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
pw_act=block_type == 'dsa',
|
||||
noskip=block_type == 'dsa' or noskip,
|
||||
)
|
||||
elif block_type == 'er':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
exp_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type == 'cn':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
kernel_size=int(options['k']),
|
||||
out_chs=int(options['c']),
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
|
||||
ba['num_experts'] *= experts_multiplier
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
|
||||
else:
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
def initialize_weight_goog(m, n='', fix_group_fanout=True):
|
||||
# weight init as per Tensorflow Official impl
|
||||
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
||||
if isinstance(m, CondConv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
init_weight_fn = get_condconv_initializer(
|
||||
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
init_weight_fn(m.weight)
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
fan_out = m.weight.size(0) # fan-out
|
||||
fan_in = 0
|
||||
if 'routing_fn' in n:
|
||||
fan_in = m.weight.size(1)
|
||||
init_range = 1.0 / math.sqrt(fan_in + fan_out)
|
||||
m.weight.data.uniform_(-init_range, init_range)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
def initialize_weight_default(m, n=''):
|
||||
if isinstance(m, CondConv2d):
|
||||
init_fn = get_condconv_initializer(partial(
|
||||
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
||||
init_fn(m.weight)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,71 @@
|
||||
""" Checkpoint loading / state_dict helpers
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
print("=> Loading checkpoint '{}'".format(checkpoint_path))
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint['state_dict'].items():
|
||||
if k.startswith('module'):
|
||||
name = k[7:] # remove `module.`
|
||||
else:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
else:
|
||||
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_pretrained(model, url, filter_fn=None, strict=True):
|
||||
if not url:
|
||||
print("=> Warning: Pretrained model URL is empty, using random initialization.")
|
||||
return
|
||||
|
||||
state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
|
||||
|
||||
input_conv = 'conv_stem'
|
||||
classifier = 'classifier'
|
||||
in_chans = getattr(model, input_conv).weight.shape[1]
|
||||
num_classes = getattr(model, classifier).weight.shape[0]
|
||||
|
||||
input_conv_weight = input_conv + '.weight'
|
||||
pretrained_in_chans = state_dict[input_conv_weight].shape[1]
|
||||
if in_chans != pretrained_in_chans:
|
||||
if in_chans == 1:
|
||||
print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
|
||||
input_conv_weight, pretrained_in_chans))
|
||||
conv1_weight = state_dict[input_conv_weight]
|
||||
state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
|
||||
else:
|
||||
print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
|
||||
input_conv_weight, pretrained_in_chans))
|
||||
del state_dict[input_conv_weight]
|
||||
strict = False
|
||||
|
||||
classifier_weight = classifier + '.weight'
|
||||
pretrained_num_classes = state_dict[classifier_weight].shape[0]
|
||||
if num_classes != pretrained_num_classes:
|
||||
print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
|
||||
del state_dict[classifier_weight]
|
||||
del state_dict[classifier + '.bias']
|
||||
strict = False
|
||||
|
||||
if filter_fn is not None:
|
||||
state_dict = filter_fn(state_dict)
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
@@ -0,0 +1,364 @@
|
||||
""" MobileNet-V3
|
||||
|
||||
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
|
||||
|
||||
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .activations import get_act_fn, get_act_layer, HardSwish
|
||||
from .config import layer_config_kwargs
|
||||
from .conv2d_layers import select_conv2d
|
||||
from .helpers import load_pretrained
|
||||
from .efficientnet_builder import *
|
||||
|
||||
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
|
||||
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
|
||||
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
|
||||
'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
|
||||
|
||||
model_urls = {
|
||||
'mobilenetv3_rw':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
||||
'mobilenetv3_large_075': None,
|
||||
'mobilenetv3_large_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
|
||||
'mobilenetv3_large_minimal_100': None,
|
||||
'mobilenetv3_small_075': None,
|
||||
'mobilenetv3_small_100': None,
|
||||
'mobilenetv3_small_minimal_100': None,
|
||||
'tf_mobilenetv3_large_075':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
|
||||
'tf_mobilenetv3_large_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
|
||||
'tf_mobilenetv3_large_minimal_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
|
||||
'tf_mobilenetv3_small_075':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
|
||||
'tf_mobilenetv3_small_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
|
||||
'tf_mobilenetv3_small_minimal_100':
|
||||
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
|
||||
}
|
||||
|
||||
|
||||
class MobileNetV3(nn.Module):
|
||||
""" MobileNet-V3
|
||||
|
||||
A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
|
||||
head convolution without a final batch-norm layer before the classifier.
|
||||
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
||||
channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
|
||||
super(MobileNetV3, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
stem_size = round_channels(stem_size, channel_multiplier)
|
||||
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
in_chs = stem_size
|
||||
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
|
||||
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
||||
in_chs = builder.in_chs
|
||||
|
||||
self.global_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.classifier = nn.Linear(num_features, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if weight_init == 'goog':
|
||||
initialize_weight_goog(m)
|
||||
else:
|
||||
initialize_weight_default(m)
|
||||
|
||||
def as_sequential(self):
|
||||
layers = [self.conv_stem, self.bn1, self.act1]
|
||||
layers.extend(self.blocks)
|
||||
layers.extend([
|
||||
self.global_pool, self.conv_head, self.act2,
|
||||
nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def features(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
x = self.blocks(x)
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.flatten(1)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, variant, pretrained=False):
|
||||
as_sequential = model_kwargs.pop('as_sequential', False)
|
||||
model = MobileNetV3(**model_kwargs)
|
||||
if pretrained and model_urls[variant]:
|
||||
load_pretrained(model, model_urls[variant])
|
||||
if as_sequential:
|
||||
model = model.as_sequential()
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 model (RW variant).
|
||||
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
|
||||
eventual Tensorflow reference impl but has a few differences:
|
||||
1. This model has no bias on the head convolution
|
||||
2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
|
||||
3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
|
||||
from their parent block
|
||||
4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
|
||||
|
||||
Overall the changes are fairly minor and result in a very small parameter count difference and no
|
||||
top-1/5
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
with layer_config_kwargs(kwargs):
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
head_bias=False, # one of my mistakes
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
|
||||
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, variant, pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 large/small/minimal models.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
if 'small' in variant:
|
||||
num_features = 1024
|
||||
if 'minimal' in variant:
|
||||
act_layer = 'relu'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16'],
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k3_s1_e3_c48'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k3_s2_e6_c96'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'],
|
||||
]
|
||||
else:
|
||||
act_layer = 'hard_swish'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'], # hard-swish
|
||||
]
|
||||
else:
|
||||
num_features = 1280
|
||||
if 'minimal' in variant:
|
||||
act_layer = 'relu'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k3_s2_e3_c40'],
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112'],
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'],
|
||||
]
|
||||
else:
|
||||
act_layer = 'hard_swish'
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
with layer_config_kwargs(kwargs):
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=16,
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=resolve_act_layer(kwargs, act_layer),
|
||||
se_kwargs=dict(
|
||||
act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, variant, pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_rw(pretrained=False, **kwargs):
|
||||
""" MobileNet-V3 RW
|
||||
Attn: See note in gen function for this variant.
|
||||
"""
|
||||
# NOTE for train set drop_rate=0.2
|
||||
if pretrained:
|
||||
# pretrained model trained with non-default BN epsilon
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 0.75"""
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 1.0 """
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large (Minimalistic) 1.0 """
|
||||
# NOTE for train set drop_rate=0.2
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 0.75 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 1.0 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small (Minimalistic) 1.0 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 0.75. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 0.75. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small 1.0. Tensorflow compat variant."""
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
@@ -0,0 +1,27 @@
|
||||
from .config import set_layer_config
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
from .gen_efficientnet import *
|
||||
from .mobilenetv3 import *
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name='mnasnet_100',
|
||||
pretrained=None,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
|
||||
model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)
|
||||
|
||||
if model_name in globals():
|
||||
create_fn = globals()[model_name]
|
||||
model = create_fn(**model_kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path and not pretrained:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '1.0.2'
|
||||
@@ -0,0 +1,84 @@
|
||||
dependencies = ['torch', 'math']
|
||||
|
||||
from geffnet import efficientnet_b0
|
||||
from geffnet import efficientnet_b1
|
||||
from geffnet import efficientnet_b2
|
||||
from geffnet import efficientnet_b3
|
||||
|
||||
from geffnet import efficientnet_es
|
||||
|
||||
from geffnet import efficientnet_lite0
|
||||
|
||||
from geffnet import mixnet_s
|
||||
from geffnet import mixnet_m
|
||||
from geffnet import mixnet_l
|
||||
from geffnet import mixnet_xl
|
||||
|
||||
from geffnet import mobilenetv2_100
|
||||
from geffnet import mobilenetv2_110d
|
||||
from geffnet import mobilenetv2_120d
|
||||
from geffnet import mobilenetv2_140
|
||||
|
||||
from geffnet import mobilenetv3_large_100
|
||||
from geffnet import mobilenetv3_rw
|
||||
from geffnet import mnasnet_a1
|
||||
from geffnet import mnasnet_b1
|
||||
from geffnet import fbnetc_100
|
||||
from geffnet import spnasnet_100
|
||||
|
||||
from geffnet import tf_efficientnet_b0
|
||||
from geffnet import tf_efficientnet_b1
|
||||
from geffnet import tf_efficientnet_b2
|
||||
from geffnet import tf_efficientnet_b3
|
||||
from geffnet import tf_efficientnet_b4
|
||||
from geffnet import tf_efficientnet_b5
|
||||
from geffnet import tf_efficientnet_b6
|
||||
from geffnet import tf_efficientnet_b7
|
||||
from geffnet import tf_efficientnet_b8
|
||||
|
||||
from geffnet import tf_efficientnet_b0_ap
|
||||
from geffnet import tf_efficientnet_b1_ap
|
||||
from geffnet import tf_efficientnet_b2_ap
|
||||
from geffnet import tf_efficientnet_b3_ap
|
||||
from geffnet import tf_efficientnet_b4_ap
|
||||
from geffnet import tf_efficientnet_b5_ap
|
||||
from geffnet import tf_efficientnet_b6_ap
|
||||
from geffnet import tf_efficientnet_b7_ap
|
||||
from geffnet import tf_efficientnet_b8_ap
|
||||
|
||||
from geffnet import tf_efficientnet_b0_ns
|
||||
from geffnet import tf_efficientnet_b1_ns
|
||||
from geffnet import tf_efficientnet_b2_ns
|
||||
from geffnet import tf_efficientnet_b3_ns
|
||||
from geffnet import tf_efficientnet_b4_ns
|
||||
from geffnet import tf_efficientnet_b5_ns
|
||||
from geffnet import tf_efficientnet_b6_ns
|
||||
from geffnet import tf_efficientnet_b7_ns
|
||||
from geffnet import tf_efficientnet_l2_ns_475
|
||||
from geffnet import tf_efficientnet_l2_ns
|
||||
|
||||
from geffnet import tf_efficientnet_es
|
||||
from geffnet import tf_efficientnet_em
|
||||
from geffnet import tf_efficientnet_el
|
||||
|
||||
from geffnet import tf_efficientnet_cc_b0_4e
|
||||
from geffnet import tf_efficientnet_cc_b0_8e
|
||||
from geffnet import tf_efficientnet_cc_b1_8e
|
||||
|
||||
from geffnet import tf_efficientnet_lite0
|
||||
from geffnet import tf_efficientnet_lite1
|
||||
from geffnet import tf_efficientnet_lite2
|
||||
from geffnet import tf_efficientnet_lite3
|
||||
from geffnet import tf_efficientnet_lite4
|
||||
|
||||
from geffnet import tf_mixnet_s
|
||||
from geffnet import tf_mixnet_m
|
||||
from geffnet import tf_mixnet_l
|
||||
|
||||
from geffnet import tf_mobilenetv3_large_075
|
||||
from geffnet import tf_mobilenetv3_large_100
|
||||
from geffnet import tf_mobilenetv3_large_minimal_100
|
||||
from geffnet import tf_mobilenetv3_small_075
|
||||
from geffnet import tf_mobilenetv3_small_100
|
||||
from geffnet import tf_mobilenetv3_small_minimal_100
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
""" ONNX export script
|
||||
|
||||
Export PyTorch models as ONNX graphs.
|
||||
|
||||
This export script originally started as an adaptation of code snippets found at
|
||||
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
|
||||
|
||||
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
|
||||
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
|
||||
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
|
||||
flags are currently required.
|
||||
|
||||
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
|
||||
caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
|
||||
|
||||
Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
|
||||
Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import onnx
|
||||
import geffnet
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('output', metavar='ONNX_FILE',
|
||||
help='output model filename')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
|
||||
help='model architecture (default: mobilenetv3_large_100)')
|
||||
parser.add_argument('--opset', type=int, default=10,
|
||||
help='ONNX opset to use (default: 10)')
|
||||
parser.add_argument('--keep-init', action='store_true', default=False,
|
||||
help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
|
||||
parser.add_argument('--aten-fallback', action='store_true', default=False,
|
||||
help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
|
||||
parser.add_argument('--dynamic-size', action='store_true', default=False,
|
||||
help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.')
|
||||
parser.add_argument('-b', '--batch-size', default=1, type=int,
|
||||
metavar='N', help='mini-batch size (default: 1)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to checkpoint (default: none)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
args.pretrained = True
|
||||
if args.checkpoint:
|
||||
args.pretrained = False
|
||||
|
||||
print("==> Creating PyTorch {} model".format(args.model))
|
||||
# NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
|
||||
# for models using SAME padding
|
||||
model = geffnet.create_model(
|
||||
args.model,
|
||||
num_classes=args.num_classes,
|
||||
in_chans=3,
|
||||
pretrained=args.pretrained,
|
||||
checkpoint_path=args.checkpoint,
|
||||
exportable=True)
|
||||
|
||||
model.eval()
|
||||
|
||||
example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)
|
||||
|
||||
# Run model once before export trace, sets padding for models with Conv2dSameExport. This means
|
||||
# that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
|
||||
# the input img_size specified in this script.
|
||||
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
|
||||
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
|
||||
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
|
||||
model(example_input)
|
||||
|
||||
print("==> Exporting model to ONNX format at '{}'".format(args.output))
|
||||
input_names = ["input0"]
|
||||
output_names = ["output0"]
|
||||
dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
|
||||
if args.dynamic_size:
|
||||
dynamic_axes['input0'][2] = 'height'
|
||||
dynamic_axes['input0'][3] = 'width'
|
||||
if args.aten_fallback:
|
||||
export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
else:
|
||||
export_type = torch.onnx.OperatorExportTypes.ONNX
|
||||
|
||||
torch_out = torch.onnx._export(
|
||||
model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
|
||||
output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
|
||||
opset_version=args.opset, operator_export_type=export_type)
|
||||
|
||||
print("==> Loading and checking exported model from '{}'".format(args.output))
|
||||
onnx_model = onnx.load(args.output)
|
||||
onnx.checker.check_model(onnx_model) # assuming throw on error
|
||||
print("==> Passed")
|
||||
|
||||
if args.keep_init and args.aten_fallback:
|
||||
import caffe2.python.onnx.backend as onnx_caffe2
|
||||
# Caffe2 loading only works properly in newer PyTorch/ONNX combos when
|
||||
# keep_initializers_as_inputs and aten_fallback are set to True.
|
||||
print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output))
|
||||
caffe2_backend = onnx_caffe2.prepare(onnx_model)
|
||||
B = {onnx_model.graph.input[0].name: x.data.numpy()}
|
||||
c2_out = caffe2_backend.run(B)[0]
|
||||
np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
|
||||
print("==> Passed")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,84 @@
|
||||
""" ONNX optimization script
|
||||
|
||||
Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
|
||||
|
||||
NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7),
|
||||
it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline).
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
from onnx import optimizer
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Optimize ONNX model")
|
||||
|
||||
parser.add_argument("model", help="The ONNX model")
|
||||
parser.add_argument("--output", required=True, help="The optimized model output filename")
|
||||
|
||||
|
||||
def traverse_graph(graph, prefix=''):
|
||||
content = []
|
||||
indent = prefix + ' '
|
||||
graphs = []
|
||||
num_nodes = 0
|
||||
for node in graph.node:
|
||||
pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
|
||||
assert isinstance(gs, list)
|
||||
content.append(pn)
|
||||
graphs.extend(gs)
|
||||
num_nodes += 1
|
||||
for g in graphs:
|
||||
g_count, g_str = traverse_graph(g)
|
||||
content.append('\n' + g_str)
|
||||
num_nodes += g_count
|
||||
return num_nodes, '\n'.join(content)
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
onnx_model = onnx.load(args.model)
|
||||
num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
|
||||
|
||||
# Optimizer passes to perform
|
||||
passes = [
|
||||
#'eliminate_deadend',
|
||||
'eliminate_identity',
|
||||
'eliminate_nop_dropout',
|
||||
'eliminate_nop_pad',
|
||||
'eliminate_nop_transpose',
|
||||
'eliminate_unused_initializer',
|
||||
'extract_constant_to_initializer',
|
||||
'fuse_add_bias_into_conv',
|
||||
'fuse_bn_into_conv',
|
||||
'fuse_consecutive_concats',
|
||||
'fuse_consecutive_reduce_unsqueeze',
|
||||
'fuse_consecutive_squeezes',
|
||||
'fuse_consecutive_transposes',
|
||||
#'fuse_matmul_add_bias_into_gemm',
|
||||
'fuse_pad_into_conv',
|
||||
#'fuse_transpose_into_gemm',
|
||||
#'lift_lexical_references',
|
||||
]
|
||||
|
||||
# Apply the optimization on the original serialized model
|
||||
# WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
|
||||
# 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
|
||||
# It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
|
||||
warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX."
|
||||
"Try onnxruntime optimization if this doesn't work.")
|
||||
optimized_model = optimizer.optimize(onnx_model, passes)
|
||||
|
||||
num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph)
|
||||
print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str))
|
||||
print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
|
||||
|
||||
# Save the ONNX model
|
||||
onnx.save(optimized_model, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,27 @@
|
||||
import argparse
|
||||
|
||||
import onnx
|
||||
from caffe2.python.onnx.backend import Caffe2Backend
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2")
|
||||
|
||||
parser.add_argument("model", help="The ONNX model")
|
||||
parser.add_argument("--c2-prefix", required=True,
|
||||
help="The output file prefix for the caffe2 model init and predict file. ")
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
onnx_model = onnx.load(args.model)
|
||||
caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
|
||||
caffe2_init_str = caffe2_init.SerializeToString()
|
||||
with open(args.c2_prefix + '.init.pb', "wb") as f:
|
||||
f.write(caffe2_init_str)
|
||||
caffe2_predict_str = caffe2_predict.SerializeToString()
|
||||
with open(args.c2_prefix + '.predict.pb', "wb") as f:
|
||||
f.write(caffe2_predict_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,112 @@
|
||||
""" ONNX-runtime validation script
|
||||
|
||||
This script was created to verify accuracy and performance of exported ONNX
|
||||
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
|
||||
pipeline for a fair comparison against the originals.
|
||||
|
||||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
from data import create_loader, resolve_data_config, Dataset
|
||||
from utils import AverageMeter
|
||||
import time
|
||||
|
||||
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
|
||||
help='path to onnx model/weights file')
|
||||
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
|
||||
help='path to output optimized onnx graph')
|
||||
parser.add_argument('--profile', action='store_true', default=False,
|
||||
help='Enable profiler output.')
|
||||
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
args.gpu_id = 0
|
||||
|
||||
# Set graph optimization level
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
if args.profile:
|
||||
sess_options.enable_profiling = True
|
||||
if args.onnx_output_opt:
|
||||
sess_options.optimized_model_filepath = args.onnx_output_opt
|
||||
|
||||
session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
|
||||
|
||||
data_config = resolve_data_config(None, args)
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
input_name = session.get_inputs()[0].name
|
||||
|
||||
batch_time = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
end = time.time()
|
||||
for i, (input, target) in enumerate(loader):
|
||||
# run the net and return prediction
|
||||
output = session.run([], {input_name: input.data.numpy()})
|
||||
output = output[0]
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy_np(output, target.numpy())
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
|
||||
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
def accuracy_np(output, target):
|
||||
max_indices = np.argsort(output, axis=1)[:, ::-1]
|
||||
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
|
||||
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
|
||||
return top1, top5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
torch>=1.2.0
|
||||
torchvision>=0.4.0
|
||||
@@ -0,0 +1,47 @@
|
||||
""" Setup
|
||||
"""
|
||||
from setuptools import setup, find_packages
|
||||
from codecs import open
|
||||
from os import path
|
||||
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
|
||||
# Get the long description from the README file
|
||||
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
|
||||
long_description = f.read()
|
||||
|
||||
exec(open('geffnet/version.py').read())
|
||||
setup(
|
||||
name='geffnet',
|
||||
version=__version__,
|
||||
description='(Generic) EfficientNets for PyTorch',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/rwightman/gen-efficientnet-pytorch',
|
||||
author='Ross Wightman',
|
||||
author_email='hello@rwightman.com',
|
||||
classifiers=[
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
|
||||
# Note that this is a string of words separated by whitespace, not a list.
|
||||
keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet',
|
||||
packages=find_packages(exclude=['data']),
|
||||
install_requires=['torch >= 1.4', 'torchvision'],
|
||||
python_requires='>=3.6',
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def get_outdir(path, *paths, inc=False):
|
||||
outdir = os.path.join(path, *paths)
|
||||
if not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
elif inc:
|
||||
count = 1
|
||||
outdir_inc = outdir + '-' + str(count)
|
||||
while os.path.exists(outdir_inc):
|
||||
count = count + 1
|
||||
outdir_inc = outdir + '-' + str(count)
|
||||
assert count < 100
|
||||
outdir = outdir_inc
|
||||
os.makedirs(outdir)
|
||||
return outdir
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
from contextlib import suppress
|
||||
|
||||
import geffnet
|
||||
from data import Dataset, create_loader, resolve_data_config
|
||||
from utils import accuracy, AverageMeter
|
||||
|
||||
has_native_amp = False
|
||||
try:
|
||||
if getattr(torch.cuda.amp, 'autocast') is not None:
|
||||
has_native_amp = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00',
|
||||
help='model architecture (default: dpn92)')
|
||||
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N', help='mini-batch size (default: 256)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
help='Override std deviation of of dataset')
|
||||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
||||
help='Override default crop pct of 0.875')
|
||||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
help='use pre-trained model')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
||||
help='use tensorflow mnasnet preporcessing')
|
||||
parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
|
||||
help='')
|
||||
parser.add_argument('--channels-last', action='store_true', default=False,
|
||||
help='Use channels_last memory layout')
|
||||
parser.add_argument('--amp', action='store_true', default=False,
|
||||
help='Use native Torch AMP mixed precision.')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.checkpoint and not args.pretrained:
|
||||
args.pretrained = True
|
||||
|
||||
amp_autocast = suppress # do nothing
|
||||
if args.amp:
|
||||
if not has_native_amp:
|
||||
print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.")
|
||||
else:
|
||||
amp_autocast = torch.cuda.amp.autocast
|
||||
|
||||
# create model
|
||||
model = geffnet.create_model(
|
||||
args.model,
|
||||
num_classes=args.num_classes,
|
||||
in_chans=3,
|
||||
pretrained=args.pretrained,
|
||||
checkpoint_path=args.checkpoint,
|
||||
scriptable=args.torchscript)
|
||||
|
||||
if args.channels_last:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
|
||||
if args.torchscript:
|
||||
torch.jit.optimized_execution(True)
|
||||
model = torch.jit.script(model)
|
||||
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
data_config = resolve_data_config(model, args)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
if not args.no_cuda:
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
criterion = criterion.cuda()
|
||||
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=not args.no_cuda,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
tensorflow_preprocessing=args.tf_preprocessing)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
model.eval()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for i, (input, target) in enumerate(loader):
|
||||
if not args.no_cuda:
|
||||
target = target.cuda()
|
||||
input = input.cuda()
|
||||
if args.channels_last:
|
||||
input = input.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
# compute output
|
||||
with amp_autocast():
|
||||
output = model(input)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
|
||||
losses.update(loss.item(), input.size(0))
|
||||
top1.update(prec1.item(), input.size(0))
|
||||
top5.update(prec5.item(), input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time,
|
||||
rate_avg=input.size(0) / batch_time.avg,
|
||||
loss=losses, top1=top1, top5=top5))
|
||||
|
||||
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
||||
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
basemodel_name = 'tf_efficientnet_b5_ap'
|
||||
print('Loading base model ()...'.format(basemodel_name), end='')
|
||||
repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
|
||||
basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
|
||||
print('Done.')
|
||||
|
||||
# Remove last layer
|
||||
print('Removing last two layers (global_pool & classifier).')
|
||||
basemodel.global_pool = nn.Identity()
|
||||
basemodel.classifier = nn.Identity()
|
||||
|
||||
self.original_model = basemodel
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for k, v in self.original_model._modules.items():
|
||||
if (k == 'blocks'):
|
||||
for ki, vi in v._modules.items():
|
||||
features.append(vi(features[-1]))
|
||||
else:
|
||||
features.append(v(features[-1]))
|
||||
return features
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
########################################################################################################################
|
||||
|
||||
|
||||
# Upsample + BatchNorm
|
||||
class UpSampleBN(nn.Module):
|
||||
def __init__(self, skip_input, output_features):
|
||||
super(UpSampleBN, self).__init__()
|
||||
|
||||
self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(output_features),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(output_features),
|
||||
nn.LeakyReLU())
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
# Upsample + GroupNorm + Weight Standardization
|
||||
class UpSampleGN(nn.Module):
|
||||
def __init__(self, skip_input, output_features):
|
||||
super(UpSampleGN, self).__init__()
|
||||
|
||||
self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.GroupNorm(8, output_features),
|
||||
nn.LeakyReLU(),
|
||||
Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.GroupNorm(8, output_features),
|
||||
nn.LeakyReLU())
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
# Conv2d with weight standardization
|
||||
class Conv2d(nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight
|
||||
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
|
||||
keepdim=True).mean(dim=3, keepdim=True)
|
||||
weight = weight - weight_mean
|
||||
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
|
||||
weight = weight / std.expand_as(weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
# normalize
|
||||
def norm_normalize(norm_out):
|
||||
min_kappa = 0.01
|
||||
norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
|
||||
norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
|
||||
kappa = F.elu(kappa) + 1.0 + min_kappa
|
||||
final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
|
||||
return final_out
|
||||
|
||||
|
||||
# uncertainty-guided sampling (only used during training)
|
||||
@torch.no_grad()
|
||||
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
|
||||
device = init_normal.device
|
||||
B, _, H, W = init_normal.shape
|
||||
N = int(sampling_ratio * H * W)
|
||||
beta = beta
|
||||
|
||||
# uncertainty map
|
||||
uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
|
||||
|
||||
# gt_invalid_mask (B, H, W)
|
||||
if gt_norm_mask is not None:
|
||||
gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
|
||||
gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
|
||||
uncertainty_map[gt_invalid_mask] = -1e4
|
||||
|
||||
# (B, H*W)
|
||||
_, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
|
||||
|
||||
# importance sampling
|
||||
if int(beta * N) > 0:
|
||||
importance = idx[:, :int(beta * N)] # B, beta*N
|
||||
|
||||
# remaining
|
||||
remaining = idx[:, int(beta * N):] # B, H*W - beta*N
|
||||
|
||||
# coverage
|
||||
num_coverage = N - int(beta * N)
|
||||
|
||||
if num_coverage <= 0:
|
||||
samples = importance
|
||||
else:
|
||||
coverage_list = []
|
||||
for i in range(B):
|
||||
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
|
||||
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
|
||||
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
|
||||
samples = torch.cat((importance, coverage), dim=1) # B, N
|
||||
|
||||
else:
|
||||
# remaining
|
||||
remaining = idx[:, :] # B, H*W
|
||||
|
||||
# coverage
|
||||
num_coverage = N
|
||||
|
||||
coverage_list = []
|
||||
for i in range(B):
|
||||
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
|
||||
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
|
||||
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
|
||||
samples = coverage
|
||||
|
||||
# point coordinates
|
||||
rows_int = samples // W # 0 for first row, H-1 for last row
|
||||
rows_float = rows_int / float(H-1) # 0 to 1.0
|
||||
rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
|
||||
|
||||
cols_int = samples % W # 0 for first column, W-1 for last column
|
||||
cols_float = cols_int / float(W-1) # 0 to 1.0
|
||||
cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
|
||||
|
||||
point_coords = torch.zeros(B, 1, N, 2)
|
||||
point_coords[:, 0, :, 0] = cols_float # x coord
|
||||
point_coords[:, 0, :, 1] = rows_float # y coord
|
||||
point_coords = point_coords.to(device)
|
||||
return point_coords, rows_int, cols_int
|
||||
Reference in New Issue
Block a user