Compare commits

...

106 Commits

Author SHA1 Message Date
Ryan Dick
3ce02be1ea Add links to test models for loha, lokr, ia3. 2024-09-12 15:43:05 +00:00
Ryan Dick
53da78f1ac Update all lycoris layer types to use the new torch.nn.Module base class. 2024-09-12 15:43:05 +00:00
Ryan Dick
7246e007b7 Assume LoRA alpha=8 for FLUX diffusers PEFT LoRAs. 2024-09-12 15:43:05 +00:00
Ryan Dick
5d081aad54 Get diffusers FLUX LoRA working as sidecar patch on quantized model. 2024-09-12 15:43:05 +00:00
Ryan Dick
5b41d74bce WIP - Implement sidecar LoRA layers using functional API. 2024-09-12 15:43:05 +00:00
Ryan Dick
ef6507d9bb Bug fixes to get LoRA sidecar patching working for the first time. 2024-09-12 15:43:05 +00:00
Ryan Dick
bf9a661303 WIP - LoRA sidecar layers. 2024-09-12 15:43:05 +00:00
Ryan Dick
563b9d7713 WIP - adding LoRA sidecar layers 2024-09-12 15:43:05 +00:00
Ryan Dick
ed08c88f78 Add util functions calc_tensor_size(...) and calc_tensors_size(...). 2024-09-12 15:43:05 +00:00
Ryan Dick
d45396b58f Remove unused layer_key property from LoRALayerBase. 2024-09-12 15:43:05 +00:00
Ryan Dick
35aeb81e08 Consolidate all LoRA patching logic in the LoraPatcher. 2024-09-12 15:43:05 +00:00
Ryan Dick
41fe5da2cc Rename peft -> lora in a bunch of places. 2024-09-12 15:43:05 +00:00
Ryan Dick
e33a66c675 Rename lora.py -> lora_model_raw.py. 2024-09-12 15:43:05 +00:00
Ryan Dick
676baf01a9 Rename peft/ -> lora/ 2024-09-12 15:43:05 +00:00
Ryan Dick
f3aeff58d4 Genera cleanup/documentation. 2024-09-12 15:43:05 +00:00
Ryan Dick
4ea392fe82 Add a check that all keys are handled in the FLUX Diffusers LoRA loading code. 2024-09-12 15:43:05 +00:00
Ryan Dick
5519421760 Add model probe support for FLUX LoRA models in Diffusers format. 2024-09-12 15:43:05 +00:00
Ryan Dick
9ca8cedf62 Add utility test function for creating a dummy state_dict. 2024-09-12 15:43:05 +00:00
Ryan Dick
684e4bfdfb Add is_state_dict_likely_in_flux_diffusers_format(...) function with unit test. 2024-09-12 15:43:05 +00:00
Ryan Dick
fa6f6ef733 Add unit test for lora_model_from_flux_diffusers_state_dict(...). 2024-09-12 15:43:05 +00:00
Ryan Dick
95d9a46860 First draft of lora_model_from_flux_diffusers_state_dict(...). 2024-09-12 15:43:05 +00:00
Ryan Dick
b6c0d1f791 (minor) Rename test file. 2024-09-12 15:43:05 +00:00
Ryan Dick
67240d2cec Add ConcatenateLoRALayer class. 2024-09-12 15:43:05 +00:00
Ryan Dick
f2586fb15b WIP on supporting diffusers format FLUX LoRAs. 2024-09-12 15:43:05 +00:00
Ryan Dick
021b523a96 Rename flux_kohya_lora_conversion_utils.py 2024-09-12 15:43:05 +00:00
Ryan Dick
281971c19e Fixup FLUX LoRA unit tests. 2024-09-12 15:43:05 +00:00
Ryan Dick
d9b5b4907b WIP 2024-09-12 15:43:05 +00:00
Ryan Dick
3fceb81ffc WIP - add invocations to support FLUX LORAs. 2024-09-12 15:43:05 +00:00
Ryan Dick
183bec2b5c Get probing of FLUX LoRA kohya models working. 2024-09-12 15:43:05 +00:00
Ryan Dick
77e6f06124 Add utility function for detecting whether a state_dict is in the FLUX kohya LoRA format. 2024-09-12 15:43:05 +00:00
Ryan Dick
d143fc057a Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test. 2024-09-12 15:43:05 +00:00
Ryan Dick
77d9f69264 Move the responsibilities of 1) state_dict loading from file, and 2) SDXL lora key conversions, out of LoRAModelRaw and into LoRALoader. 2024-09-12 15:43:05 +00:00
Ryan Dick
d853fe89d4 Remove unused LoRAModelRaw.name attribute. 2024-09-12 15:43:05 +00:00
Ryan Dick
67b00151ec Fix type errors in sdxl_conversion_utils.py 2024-09-12 15:43:05 +00:00
Ryan Dick
acaec3c242 Start moving SDXL-specific LoRA conversions out of the general-purpose LoRAModelRaw class. 2024-09-12 15:43:05 +00:00
Ryan Dick
8bf2f02e75 Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests. 2024-09-12 15:43:05 +00:00
Ryan Dick
c1d7f797c0 WIP - FLUX LoRA conversion logic. 2024-09-12 15:43:04 +00:00
Ryan Dick
3f7571c709 Add state_dict keys for two FLUX LoRA formats to be used in unit tests. 2024-09-12 15:43:04 +00:00
Ryan Dick
fb4e0ecaed Move lora.py to peft/ subdir. 2024-09-12 15:43:04 +00:00
Ryan Dick
1105833124 Split PEFT layer implementations into separate files. 2024-09-12 15:43:04 +00:00
psychedelicious
2622f7dc02 chore: release v5.0.0.a4 2024-09-13 00:04:07 +10:00
psychedelicious
1a598873de fix(ui): show send to toggle on canvas only 2024-09-12 23:42:21 +10:00
psychedelicious
12cab9fc31 revert(ui): miniviewer
toodles
2024-09-12 23:42:21 +10:00
psychedelicious
6fb8e45761 feat(ui): do not show canvas progress in viewer 2024-09-12 23:42:21 +10:00
psychedelicious
637960d67e fix(ui): remove unused setting, fix missing translation for alerts 2024-09-12 23:42:21 +10:00
psychedelicious
d2ab668fa0 revert(ui): remove post-generation toasts 2024-09-12 23:42:21 +10:00
psychedelicious
82df16d8ce feat(ui): animations for send to alerts 2024-09-12 23:42:21 +10:00
psychedelicious
dd3013d333 feat(ui): alerts display depending on current generation destination 2024-09-12 23:42:21 +10:00
psychedelicious
269db8ae19 feat(ui): remove toasts when toggling send to 2024-09-12 23:42:21 +10:00
psychedelicious
30ea852761 feat(ui): restore viewer
- Remove gallery tab
- Restore viewer
- Add configurable alerts & toasts when user may be lost
2024-09-12 23:42:21 +10:00
psychedelicious
c03f80b19c feat(ui): use <Alert/> for selected entity alerts 2024-09-12 23:42:21 +10:00
psychedelicious
96930055e2 fix(ui): select first image instead of clearing selection fully
Fixes an issue where you end up w/ the no image fallback after pressing escape.
2024-09-12 23:42:21 +10:00
psychedelicious
5fa7f0154f build(ui): bump @invoke-ai/ui-library
This gets us access to the Alert component.
2024-09-12 23:42:21 +10:00
psychedelicious
ab0e9dfcad chore: release v5.0.0.a3 2024-09-12 08:46:17 +10:00
psychedelicious
88dcb388dc feat(ui): pull bbox into functionality for control/ip adapters 2024-09-11 08:12:48 -04:00
psychedelicious
5a89bf841f feat(ui): drop image on layer to replace it 2024-09-11 08:12:48 -04:00
psychedelicious
5b8707a74f feat(ui): entityRasterized action only needs position, not rect
This makes it a bit easier to call the action
2024-09-11 08:12:48 -04:00
psychedelicious
cfb538bdc2 feat(ui): add filter button next to control adapter model 2024-09-11 08:12:48 -04:00
psychedelicious
9f06a9b03c feat(ui): use revised filters
- Add backcompat for cnet model default settings
- Default filter selection based on model type
- Updated UI components to use new filter nodes
- Added handling for failed filter executions, preventing filter from getting stuck in case it failed for some reason
- New translations for all filters & fields
2024-09-11 08:12:48 -04:00
psychedelicious
561db0751b fix(ui): progress bar/queue count race condition 2024-09-11 08:12:48 -04:00
psychedelicious
248e4a81b2 fix(nodes): handle no detected line segments 2024-09-11 08:12:48 -04:00
psychedelicious
b6aba92426 fix(nodes): MLSD needs inputs to be multiples of 64 2024-09-11 08:12:48 -04:00
psychedelicious
7d15f9381d chore(ui): typegen 2024-09-11 08:12:48 -04:00
psychedelicious
4f2fc65257 tidy(nodes): MLSDEdgeDetection -> MLSDDetection
It's a line segment detector, not general edge detector.
2024-09-11 08:12:48 -04:00
psychedelicious
68237d357a feat(ui): hide deprecated nodes from add node menu
They will still be usable if a workflow uses one. You just cannot add them directly.
2024-09-11 08:12:48 -04:00
psychedelicious
bb2db3d6c3 feat(ui): improve typing on CanvasEntityAdapterBase
Use a generic to narrow the `type` field from `string` to a literal. Now you can do e.g. `adapter.type === 'control_layer_adapter'` and TS narrows the type.
2024-09-11 08:12:48 -04:00
psychedelicious
ff94146ee8 chore(ui): typegen 2024-09-11 08:12:48 -04:00
psychedelicious
1d09091a67 feat(nodes): add Classification.Deprecated, deprecated old cnet processors 2024-09-11 08:12:48 -04:00
psychedelicious
ee4c0efbf7 feat(nodes): update pidinet node
Human-readable field names.
2024-09-11 08:12:48 -04:00
psychedelicious
a4250e3ff2 feat(nodes): update mlsd node
Human-readable field names.
2024-09-11 08:12:48 -04:00
psychedelicious
67a234c1bb feat(nodes): update content shuffle node
- Better field names
2024-09-11 08:12:48 -04:00
psychedelicious
420045cb34 feat(nodes): update color map node
- Changed name
- Better field names
2024-09-11 08:12:48 -04:00
psychedelicious
53792fafb3 feat(nodes): add DWOpenposeDetectionInvocation
Similar to the existing node, but without any resizing. The backend logic was consolidated and modified so that it the model loading can be managed by the model manager.

The ONNX Runtime `InferenceSession` class was added to the `AnyModel` union to satisfy the type checker.
2024-09-11 08:12:48 -04:00
psychedelicious
615eddea6f feat(nodes): add PiDiNetEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.

All code related to the invocation now lives in the Invoke repo.
2024-09-11 08:12:48 -04:00
psychedelicious
b3d60bd56a feat(nodes): add NormalMapInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.

All code related to the invocation now lives in the Invoke repo. Unfortunately, this includes a whole git repo for EfficientNet. I believe we could use the package `timm` instead of this, but it's beyond me.
2024-09-11 08:12:48 -04:00
psychedelicious
fd42da5a36 feat(nodes): add MLSDEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.

All code related to the invocation now lives in the Invoke repo.
2024-09-11 08:12:48 -04:00
psychedelicious
bc55791db1 feat(nodes): add MediaPipeFaceDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.

All code related to the invocation now lives in the Invoke repo.
2024-09-11 08:12:48 -04:00
psychedelicious
c5f3297841 feat(nodes): add LineartEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
2024-09-11 08:12:48 -04:00
psychedelicious
cd2c2a7fde feat(nodes): add LineartAnimeEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
2024-09-11 08:12:48 -04:00
psychedelicious
1cffcc02a5 feat(nodes): add HEDEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
2024-09-11 08:12:48 -04:00
psychedelicious
ac9950bdbb feat(nodes): add DepthAnythingDepthEstimationInvocation
Similar to the existing node, but without any resizing and with a revised model loading API.
2024-09-11 08:12:48 -04:00
psychedelicious
059d57f447 feat(nodes): add ContentShuffleInvocation
Similar to the existing node, but without the resolution fields.
2024-09-11 08:12:48 -04:00
psychedelicious
581008b432 feat(nodes): add ColorMapGeneratorInvocation
Similar to the existing node, but without the resolution fields.
2024-09-11 08:12:48 -04:00
psychedelicious
aeaeec9b9d feat(nodes): add CannyEdgeDetectionInvocation
Similar to the existing node, but without the resolution fields.
2024-09-11 08:12:48 -04:00
psychedelicious
301739c4a8 fix(ui): do not reset board search wehn collapsing boards list 2024-09-11 14:15:16 +10:00
psychedelicious
a2e2a31b95 fix(ui): create new resizeObserver when setting stage container
Hopefully this resolves the issue where sometimes the stage misses a resize event and ends up too small until you resize the window again.
2024-09-11 14:15:16 +10:00
psychedelicious
88c276cd09 fix(ui): use default control adapter when converting raster to control layer 2024-09-11 14:15:16 +10:00
psychedelicious
457871af93 chore(ui): lint 2024-09-11 14:15:16 +10:00
psychedelicious
e88d4aa0e8 fix(ui): working panel size persistence 2024-09-11 14:15:16 +10:00
psychedelicious
c8a74f969b feat(ui): make DeleteBoardModal a singleton 2024-09-11 14:15:16 +10:00
psychedelicious
4240817128 fix(ui): invoke button tooltip indicates sendToCanvas 2024-09-11 14:15:16 +10:00
psychedelicious
80877a1f15 fix(ui): disable filter process button when auto-processing 2024-09-11 14:15:16 +10:00
psychedelicious
7fc25e7e01 feat(ui): do not group brush/eraser/rect actions 2024-09-11 14:15:16 +10:00
psychedelicious
9a355c5585 feat(ui): add ctrl+y redo hotkey 2024-09-11 14:15:16 +10:00
psychedelicious
2975ec5467 fix(ui): Layers tab counter only includes active entities
Empty and disabled layers are skipped.
2024-09-11 14:15:16 +10:00
psychedelicious
8ab3b938c1 fix(ui): reset canvas doesn't reset initial inpaint mask fully 2024-09-11 14:15:16 +10:00
psychedelicious
f82640b5df fix(ui): brush size and layer cycle hotkeys conflict
Closes #6829
2024-09-10 09:20:19 -04:00
psychedelicious
e3e50abc5a fix(ui): do not show count on layers tab when no layers 2024-09-10 09:20:19 -04:00
psychedelicious
061bff2814 chore: release v5.0.0.a2 2024-09-10 09:20:19 -04:00
psychedelicious
e5a53be42b feat(ui): add canvas context menu
So far, this includes:
- Save Canvas to Gallery
- Save Bbox to Gallery
- Send Bbox to Regional IP Adapter
- Send Bbox to Global IP Adapter
- Send Bbox to Control Layer
- Send Bbox to Raster Layer
2024-09-10 09:20:19 -04:00
psychedelicious
54c94bd713 chore(ui): bump @invoke-ai/ui-library
Fixes an issue where modifier keys get stuck on when you change tabs or windows.
2024-09-10 09:20:19 -04:00
psychedelicious
8d56becf04 fix(ui): retain global canvas manager instance
To prevent losing all ephemeral canvas stage when switching tabs, we will refrain from destroying the canvas manager instance when its tab unmounts, and use the existing canvas manager instance on mount, if there is one.

One small change required in `CanvasStageModule` - a `setContainer` method to update the konva stage DOM element.
2024-09-10 09:20:19 -04:00
psychedelicious
dc51ccd9a6 feat(ui): simplify canvas component & hook API 2024-09-10 09:20:19 -04:00
psychedelicious
f5eefedc49 feat(ui): add count to layers tab button 2024-09-10 09:20:19 -04:00
psychedelicious
136891ec3d fix(ui): translation string for gallery tab 2024-09-10 09:20:19 -04:00
psychedelicious
c5543e42c7 fix(ui): drag image over tab switches to wrong tab 2024-09-10 09:20:19 -04:00
230 changed files with 16032 additions and 2550 deletions

View File

@@ -60,11 +60,13 @@ class Classification(str, Enum, metaclass=MetaEnum):
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
Deprecated = "deprecated"
class UIConfigBase(BaseModel):

View File

@@ -0,0 +1,34 @@
import cv2
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
@invocation(
"canny_edge_detection",
title="Canny Edge Detection",
tags=["controlnet", "canny"],
category="controlnet",
version="1.0.0",
)
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Geneartes an edge map using a cv2's Canny algorithm."""
image: ImageField = InputField(description="The image to process")
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
)
high_threshold: int = InputField(
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_cv2(image)
edge_map = cv2.Canny(np_img, self.low_threshold, self.high_threshold)
edge_map_pil = cv2_to_pil(edge_map)
image_dto = context.images.save(image=edge_map_pil)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,41 @@
import cv2
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
@invocation(
"color_map",
title="Color Map",
tags=["controlnet"],
category="controlnet",
version="1.0.0",
)
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates a color map from the provided image."""
image: ImageField = InputField(description="The image to process")
tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_image = pil_to_np(image)
height, width = np_image.shape[:2]
width_tile_size = min(self.tile_size, width)
height_tile_size = min(self.tile_size, height)
color_map = cv2.resize(
np_image,
(width // width_tile_size, height // height_tile_size),
interpolation=cv2.INTER_CUBIC,
)
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map_pil = np_to_pil(color_map)
image_dto = context.images.save(image=color_map_pil)
return ImageOutput.build(image_dto)

View File

@@ -19,7 +19,8 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -82,9 +83,10 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
LoraPatcher.apply_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -177,9 +179,9 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
LoraPatcher.apply_lora_patches(
text_encoder,
loras=_lora_loader(),
patches=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),

View File

@@ -0,0 +1,25 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.content_shuffle import content_shuffle
@invocation(
"content_shuffle",
title="Content Shuffle",
tags=["controlnet", "normal"],
category="controlnet",
version="1.0.0",
)
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Shuffles the image, similar to a 'liquify' filter."""
image: ImageField = InputField(description="The image to process")
scale_factor: int = InputField(default=256, ge=0, description="The scale factor used for the shuffle")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
output_image = content_shuffle(input_image=image, scale_factor=self.scale_factor)
image_dto = context.images.save(image=output_image)
return ImageOutput.build(image_dto)

View File

@@ -174,6 +174,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
tags=["controlnet", "canny"],
category="controlnet",
version="1.3.3",
classification=Classification.Deprecated,
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@@ -208,6 +209,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@@ -237,6 +239,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "lineart"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@@ -259,6 +262,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@@ -282,6 +286,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "midas"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@@ -314,6 +319,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@@ -330,7 +336,12 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.3"
"mlsd_image_processor",
title="MLSD Processor",
tags=["controlnet", "mlsd"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@@ -353,7 +364,12 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.3"
"pidi_image_processor",
title="PIDI Processor",
tags=["controlnet", "pidi"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@@ -381,6 +397,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@@ -411,6 +428,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@@ -427,6 +445,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@@ -454,6 +473,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@@ -483,6 +503,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "tile"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@@ -523,6 +544,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
@@ -570,6 +592,7 @@ class SamDetectorReproducibleColors(SamDetector):
tags=["controlnet"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""
@@ -609,6 +632,7 @@ DEPTH_ANYTHING_MODELS = {
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.3",
classification=Classification.Deprecated,
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
@@ -643,6 +667,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
version="1.1.1",
classification=Classification.Deprecated,
)
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Generates an openpose pose from an image using DWPose"""

View File

@@ -36,7 +36,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
LoraPatcher.apply_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
cached_weights=cached_weights,
),
):

View File

@@ -0,0 +1,45 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
@invocation(
"depth_anything_depth_estimation",
title="Depth Anything Depth Estimation",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.0",
)
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates a depth map using a Depth Anything model."""
image: ImageField = InputField(description="The image to process")
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small_v2", description="The size of the depth model to use"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
image_dto = context.images.save(image=depth_map)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,50 @@
import onnxruntime as ort
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
@invocation(
"dw_openpose_detection",
title="DW Openpose Detection",
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
version="1.1.1",
)
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an openpose pose from an image using DWPose"""
image: ImageField = InputField(description="The image to process")
draw_body: bool = InputField(default=True)
draw_face: bool = InputField(default=False)
draw_hands: bool = InputField(default=False)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
loaded_session_det = context.models.load_local_model(
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:
assert isinstance(session_det, ort.InferenceSession)
assert isinstance(session_pose, ort.InferenceSession)
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
detected_image = detector.run(
image,
draw_face=self.draw_face,
draw_hands=self.draw_hands,
draw_body=self.draw_body,
)
image_dto = context.images.save(image=detected_image)
return ImageOutput.build(image_dto)

View File

@@ -1,4 +1,5 @@
from typing import Callable, Optional
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
@@ -29,6 +30,9 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -187,9 +191,40 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with transformer_info as transformer:
with (
transformer_info.model_on_device() as (cached_weights, transformer),
ExitStack() as exit_stack,
):
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoraPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
cached_weights=cached_weights,
)
)
elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LoraPatcher.apply_lora_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
)
)
else:
raise ValueError(f"Unsupported model format: {config.format}")
x = denoise(
model=transformer,
img=x,
@@ -241,6 +276,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()

View File

@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -0,0 +1,33 @@
from builtins import bool
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetector
@invocation(
"hed_edge_detection",
title="HED Edge Detection",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.0.0",
)
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Geneartes an edge map using the HED (softedge) model."""
image: ImageField = InputField(description="The image to process")
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, ControlNetHED_Apache2)
hed_processor = HEDEdgeDetector(model)
edge_map = hed_processor.run(image=image, scribble=self.scribble)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,34 @@
from builtins import bool
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
@invocation(
"lineart_edge_detection",
title="Lineart Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.0.0",
)
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an edge map using the Lineart model."""
image: ImageField = InputField(description="The image to process")
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartEdgeDetector.get_model_url(self.coarse)
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, Generator)
detector = LineartEdgeDetector(model)
edge_map = detector.run(image=image)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,31 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, UnetGenerator
@invocation(
"lineart_anime_edge_detection",
title="Lineart Anime Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.0.0",
)
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Geneartes an edge map using the Lineart model."""
image: ImageField = InputField(description="The image to process")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartAnimeEdgeDetector.get_model_url()
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, UnetGenerator)
detector = LineartAnimeEdgeDetector(model)
edge_map = detector.run(image=image)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,26 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.mediapipe_face import detect_faces
@invocation(
"mediapipe_face_detection",
title="MediaPipe Face Detection",
tags=["controlnet", "face"],
category="controlnet",
version="1.0.0",
)
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Detects faces using MediaPipe."""
image: ImageField = InputField(description="The image to process")
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
detected_faces = detect_faces(image=image, max_faces=self.max_faces, min_confidence=self.min_confidence)
image_dto = context.images.save(image=detected_faces)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,39 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.mlsd import MLSDDetector
from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
@invocation(
"mlsd_detection",
title="MLSD Detection",
tags=["controlnet", "mlsd", "edge"],
category="controlnet",
version="1.0.0",
)
class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an line segment map using MLSD."""
image: ImageField = InputField(description="The image to process")
score_threshold: float = InputField(
default=0.1, ge=0, description="The threshold used to score points when determining line segments"
)
distance_threshold: float = InputField(
default=20.0,
ge=0,
description="Threshold for including a line segment - lines shorter than this distance will be discarded",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
with loaded_model as model:
assert isinstance(model, MobileV2_MLSD_Large)
detector = MLSDDetector(model)
edge_map = detector.run(image, self.score_threshold, self.distance_threshold)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),

View File

@@ -0,0 +1,31 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.normal_bae import NormalMapDetector
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
@invocation(
"normal_map",
title="Normal Map",
tags=["controlnet", "normal"],
category="controlnet",
version="1.0.0",
)
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates a normal map."""
image: ImageField = InputField(description="The image to process")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
with loaded_model as model:
assert isinstance(model, NNET)
detector = NormalMapDetector(model)
normal_map = detector.run(image=image)
image_dto = context.images.save(image=normal_map)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,33 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.pidi import PIDINetDetector
from invokeai.backend.image_util.pidi.model import PiDiNet
@invocation(
"pidi_edge_detection",
title="PiDiNet Edge Detection",
tags=["controlnet", "edge"],
category="controlnet",
version="1.0.0",
)
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an edge map using PiDiNet."""
image: ImageField = InputField(description="The image to process")
quantize_edges: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
with loaded_model as model:
assert isinstance(model, PiDiNet)
detector = PIDINetDetector(model)
edge_map = detector.run(image=image, quantize_edges=self.quantize_edges, scribble=self.scribble)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoraPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:

View File

@@ -0,0 +1,40 @@
# Adapted from https://github.com/huggingface/controlnet_aux
import cv2
import numpy as np
from PIL import Image
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
def make_noise_disk(H, W, C, F):
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = noise[F : F + H, F : F + W]
noise -= np.min(noise)
noise /= np.max(noise)
if C == 1:
noise = noise[:, :, None]
return noise
def content_shuffle(input_image: Image.Image, scale_factor: int | None = None) -> Image.Image:
"""Shuffles the content of an image using a disk noise pattern, similar to a 'liquify' effect."""
np_img = pil_to_np(input_image)
height, width, _channels = np_img.shape
if scale_factor is None:
scale_factor = 256
x = make_noise_disk(height, width, 1, scale_factor) * float(width - 1)
y = make_noise_disk(height, width, 1, scale_factor) * float(height - 1)
flow = np.concatenate([x, y], axis=2).astype(np.float32)
shuffled_img = cv2.remap(np_img, flow, None, cv2.INTER_LINEAR)
output_img = np_to_pil(shuffled_img)
return output_img

View File

@@ -1,7 +1,9 @@
import pathlib
from typing import Optional
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.backend.raw_model import RawModel
@@ -29,3 +31,11 @@ class DepthAnythingPipeline(RawModel):
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)
@classmethod
def load_model(cls, model_path: pathlib.Path):
"""Load the model from the given path and return a DepthAnythingPipeline instance."""
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return cls(depth_anything_pipeline)

View File

@@ -1,13 +1,19 @@
from pathlib import Path
from typing import Dict
import huggingface_hub
import numpy as np
import onnxruntime as ort
import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
from invokeai.backend.image_util.util import np_to_pil
from invokeai.backend.util.devices import TorchDevice
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
@@ -109,4 +115,142 @@ class DWOpenposeDetector:
)
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
class DWOpenposeDetector2:
"""
Code from the original implementation of the DW Openpose Detector.
Credits: https://github.com/IDEA-Research/DWPose
This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
and managed by the model manager.
"""
hf_repo_id = "yzd-v/DWPose"
hf_filename_onnx_det = "yolox_l.onnx"
hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
@classmethod
def get_model_url_det(cls) -> str:
"""Returns the URL for the detection model."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_onnx_det)
@classmethod
def get_model_url_pose(cls) -> str:
"""Returns the URL for the pose model."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_onnx_pose)
@staticmethod
def create_onnx_inference_session(model_path: Path) -> ort.InferenceSession:
"""Creates an ONNX Inference Session for the given model path, using the appropriate execution provider based on
the device type."""
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
return ort.InferenceSession(path_or_bytes=model_path, providers=providers)
def __init__(self, session_det: ort.InferenceSession, session_pose: ort.InferenceSession):
self.session_det = session_det
self.session_pose = session_pose
def pose_estimation(self, np_image: np.ndarray):
"""Does the pose estimation on the given image and returns the keypoints and scores."""
det_result = inference_detector(self.session_det, np_image)
keypoints, scores = inference_pose(self.session_pose, det_result, np_image)
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
# compute neck joint
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
# neck score when visualizing pred
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
keypoints_info = new_keypoints_info
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
return keypoints, scores
def run(
self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
) -> Image.Image:
"""Detects the pose in the given image and returns an solid black image with pose drawn on top, suitable for
use with a ControlNet."""
np_image = np.array(image)
H, W, C = np_image.shape
with torch.no_grad():
candidate, subset = self.pose_estimation(np_image)
nums, keys, locs = candidate.shape
candidate[..., 0] /= float(W)
candidate[..., 1] /= float(H)
body = candidate[:, :18].copy()
body = body.reshape(nums * 18, locs)
score = subset[:, :18]
for i in range(len(score)):
for j in range(len(score[i])):
if score[i][j] > 0.3:
score[i][j] = int(18 * i + j)
else:
score[i][j] = -1
un_visible = subset < 0.3
candidate[un_visible] = -1
# foot = candidate[:, 18:24]
faces = candidate[:, 24:92]
hands = candidate[:, 92:113]
hands = np.vstack([hands, candidate[:, 113:]])
bodies = {"candidate": body, "subset": score}
pose = {"bodies": bodies, "hands": hands, "faces": faces}
return DWOpenposeDetector2.draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
)
@staticmethod
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
) -> Image.Image:
"""Draws the pose on a black image and returns it as a PIL Image."""
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
assert isinstance(hands, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image = np_to_pil(canvas)
return dwpose_image

View File

@@ -1,6 +1,9 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
# Adapted from https://github.com/huggingface/controlnet_aux
import pathlib
import cv2
import huggingface_hub
import numpy as np
import torch
from einops import rearrange
@@ -140,3 +143,74 @@ class HEDProcessor:
detected_map[detected_map < 255] = 0
return np_to_pil(detected_map)
class HEDEdgeDetector:
"""Simple wrapper around the HED model for detecting edges in an image."""
hf_repo_id = "lllyasviel/Annotators"
hf_filename = "ControlNetHED.pth"
def __init__(self, model: ControlNetHED_Apache2):
self.model = model
@classmethod
def get_model_url(cls) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> ControlNetHED_Apache2:
"""Load the model from a file."""
model = ControlNetHED_Apache2()
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.float().eval()
return model
def to(self, device: torch.device):
self.model.to(device)
return self
def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> Image.Image:
"""Processes an image and returns the detected edges.
Args:
image: The input image.
safe: Whether to apply safe step to the detected edges.
scribble: Whether to apply non-maximum suppression and Gaussian blur to the detected edges.
Returns:
The detected edges.
"""
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)
height, width, _channels = np_image.shape
with torch.no_grad():
image_hed = torch.from_numpy(np_image.copy()).float().to(device)
image_hed = rearrange(image_hed, "h w c -> 1 c h w")
edges = self.model(image_hed)
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
edges = [cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) for e in edges]
edges = np.stack(edges, axis=2)
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
if safe:
edge = safe_step(edge)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
detected_map = edge
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
if scribble:
detected_map = nms(detected_map, 127, 3.0)
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
detected_map[detected_map > 4] = 255
detected_map[detected_map < 255] = 0
output = np_to_pil(detected_map)
return output

View File

@@ -1,6 +1,9 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import pathlib
import cv2
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
@@ -156,3 +159,63 @@ class LineartProcessor:
detected_map = 255 - detected_map
return np_to_pil(detected_map)
class LineartEdgeDetector:
"""Simple wrapper around the fine and coarse lineart models for detecting edges in an image."""
hf_repo_id = "lllyasviel/Annotators"
hf_filename_fine = "sk_model.pth"
hf_filename_coarse = "sk_model2.pth"
@classmethod
def get_model_url(cls, coarse: bool = False) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
if coarse:
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_coarse)
else:
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_fine)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> Generator:
"""Load the model from a file."""
model = Generator(3, 1, 3)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.float().eval()
return model
def __init__(self, model: Generator) -> None:
self.model = model
def to(self, device: torch.device):
self.model.to(device)
return self
def run(self, image: Image.Image) -> Image.Image:
"""Detects edges in the input image with the selected lineart model.
Args:
input: The input image.
coarse: Whether to use the coarse model.
Returns:
The detected edges.
"""
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)
with torch.no_grad():
np_image = torch.from_numpy(np_image).float().to(device)
np_image = np_image / 255.0
np_image = rearrange(np_image, "h w c -> 1 c h w")
line = self.model(np_image)[0][0]
line = line.cpu().numpy()
line = (line * 255.0).clip(0, 255).astype(np.uint8)
detected_map = line
detected_map = 255 - detected_map
return np_to_pil(detected_map)

View File

@@ -1,9 +1,11 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import functools
import pathlib
from typing import Optional
import cv2
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
@@ -201,3 +203,65 @@ class LineartAnimeProcessor:
detected_map = 255 - detected_map
return np_to_pil(detected_map)
class LineartAnimeEdgeDetector:
"""Simple wrapper around the Lineart Anime model for detecting edges in an image."""
hf_repo_id = "lllyasviel/Annotators"
hf_filename = "netG.pth"
@classmethod
def get_model_url(cls) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> UnetGenerator:
"""Load the model from a file."""
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
ckpt = torch.load(model_path)
for key in list(ckpt.keys()):
if "module." in key:
ckpt[key.replace("module.", "")] = ckpt[key]
del ckpt[key]
model.load_state_dict(ckpt)
model.eval()
return model
def __init__(self, model: UnetGenerator) -> None:
self.model = model
def to(self, device: torch.device):
self.model.to(device)
return self
def run(self, image: Image.Image) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)
height, width, _channels = np_image.shape
new_height = 256 * int(np.ceil(float(height) / 256.0))
new_width = 256 * int(np.ceil(float(width) / 256.0))
resized_img = cv2.resize(np_image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
with torch.no_grad():
image_feed = torch.from_numpy(resized_img).float().to(device)
image_feed = image_feed / 127.5 - 1.0
image_feed = rearrange(image_feed, "h w c -> 1 c h w")
line = self.model(image_feed)[0, 0] * 127.5 + 127.5
line = line.cpu().numpy()
line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
line = line.clip(0, 255).astype(np.uint8)
detected_map = line
detected_map = 255 - detected_map
output = np_to_pil(detected_map)
return output

View File

@@ -0,0 +1,15 @@
# Adapted from https://github.com/huggingface/controlnet_aux
from PIL import Image
from invokeai.backend.image_util.mediapipe_face.mediapipe_face_common import generate_annotation
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
def detect_faces(image: Image.Image, max_faces: int = 1, min_confidence: float = 0.5) -> Image.Image:
"""Detects faces in an image using MediaPipe."""
np_img = pil_to_np(image)
detected_map = generate_annotation(np_img, max_faces, min_confidence)
detected_map_pil = np_to_pil(detected_map)
return detected_map_pil

View File

@@ -0,0 +1,149 @@
from typing import Mapping
import mediapipe as mp
import numpy
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_face_detection = mp.solutions.face_detection # Only for counting faces.
mp_face_mesh = mp.solutions.face_mesh
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
min_face_size_pixels: int = 64
f_thick = 2
f_rad = 1
right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
face_connection_spec = {}
for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
face_connection_spec[edge] = head_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
face_connection_spec[edge] = left_eye_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
face_connection_spec[edge] = left_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
# face_connection_spec[edge] = left_iris_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
face_connection_spec[edge] = right_eye_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
face_connection_spec[edge] = right_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
# face_connection_spec[edge] = right_iris_draw
for edge in mp_face_mesh.FACEMESH_LIPS:
face_connection_spec[edge] = mouth_draw
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
"""We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
landmarks. Until our PR is merged into mediapipe, we need this separate method."""
if len(image.shape) != 3:
raise ValueError("Input image must be H,W,C.")
image_rows, image_cols, image_channels = image.shape
if image_channels != 3: # BGR channels
raise ValueError("Input image must contain three channel bgr data.")
for idx, landmark in enumerate(landmark_list.landmark):
if (landmark.HasField("visibility") and landmark.visibility < 0.9) or (
landmark.HasField("presence") and landmark.presence < 0.5
):
continue
if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
continue
image_x = int(image_cols * landmark.x)
image_y = int(image_rows * landmark.y)
draw_color = None
if isinstance(drawing_spec, Mapping):
if drawing_spec.get(idx) is None:
continue
else:
draw_color = drawing_spec[idx].color
elif isinstance(drawing_spec, DrawingSpec):
draw_color = drawing_spec.color
image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color
def reverse_channels(image):
"""Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
# im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
# im[:,:,::[2,1,0]] would also work but makes a copy of the data.
return image[:, :, ::-1]
def generate_annotation(img_rgb, max_faces: int, min_confidence: float):
"""
Find up to 'max_faces' inside the provided input image.
If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
pixels in the image.
"""
with mp_face_mesh.FaceMesh(
static_image_mode=True,
max_num_faces=max_faces,
refine_landmarks=True,
min_detection_confidence=min_confidence,
) as facemesh:
img_height, img_width, img_channels = img_rgb.shape
assert img_channels == 3
results = facemesh.process(img_rgb).multi_face_landmarks
if results is None:
print("No faces detected in controlnet image for Mediapipe face annotator.")
return numpy.zeros_like(img_rgb)
# Filter faces that are too small
filtered_landmarks = []
for lm in results:
landmarks = lm.landmark
face_rect = [
landmarks[0].x,
landmarks[0].y,
landmarks[0].x,
landmarks[0].y,
] # Left, up, right, down.
for i in range(len(landmarks)):
face_rect[0] = min(face_rect[0], landmarks[i].x)
face_rect[1] = min(face_rect[1], landmarks[i].y)
face_rect[2] = max(face_rect[2], landmarks[i].x)
face_rect[3] = max(face_rect[3], landmarks[i].y)
if min_face_size_pixels > 0:
face_width = abs(face_rect[2] - face_rect[0])
face_height = abs(face_rect[3] - face_rect[1])
face_width_pixels = face_width * img_width
face_height_pixels = face_height * img_height
face_size = min(face_width_pixels, face_height_pixels)
if face_size >= min_face_size_pixels:
filtered_landmarks.append(lm)
else:
filtered_landmarks.append(lm)
# Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
empty = numpy.zeros_like(img_rgb)
# Draw detected faces:
for face_landmarks in filtered_landmarks:
mp_drawing.draw_landmarks(
empty,
face_landmarks,
connections=face_connection_spec.keys(),
landmark_drawing_spec=None,
connection_drawing_spec=face_connection_spec,
)
draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
# Flip BGR back to RGB.
empty = reverse_channels(empty).copy()
return empty

View File

@@ -0,0 +1,66 @@
# Adapted from https://github.com/huggingface/controlnet_aux
import pathlib
import cv2
import huggingface_hub
import numpy as np
import torch
from PIL import Image
from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
from invokeai.backend.image_util.mlsd.utils import pred_lines
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
class MLSDDetector:
"""Simple wrapper around a MLSD model for detecting edges as line segments in an image."""
hf_repo_id = "lllyasviel/ControlNet"
hf_filename = "annotator/ckpts/mlsd_large_512_fp32.pth"
@classmethod
def get_model_url(cls) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> MobileV2_MLSD_Large:
"""Load the model from a file."""
model = MobileV2_MLSD_Large()
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
return model
def __init__(self, model: MobileV2_MLSD_Large) -> None:
self.model = model
def to(self, device: torch.device):
self.model.to(device)
return self
def run(self, image: Image.Image, score_threshold: float = 0.1, distance_threshold: float = 20.0) -> Image.Image:
"""Processes an image and returns the detected edges."""
np_img = pil_to_np(image)
height, width, _channels = np_img.shape
# This model requires the input image to have a resolution that is a multiple of 64
np_img = resize_to_multiple(np_img, 64)
img_output = np.zeros_like(np_img)
with torch.no_grad():
lines = pred_lines(np_img, self.model, [np_img.shape[0], np_img.shape[1]], score_threshold, distance_threshold)
for line in lines:
x_start, y_start, x_end, y_end = [int(val) for val in line]
cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
detected_map = img_output[:, :, 0]
# Back to the original size
output_image = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
return np_to_pil(output_image)

View File

@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
class BlockTypeA(nn.Module):
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
super(BlockTypeA, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c2, out_c2, kernel_size=1),
nn.BatchNorm2d(out_c2),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c1, out_c1, kernel_size=1),
nn.BatchNorm2d(out_c1),
nn.ReLU(inplace=True)
)
self.upscale = upscale
def forward(self, a, b):
b = self.conv1(b)
a = self.conv2(a)
if self.upscale:
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
return torch.cat((a, b), dim=1)
class BlockTypeB(nn.Module):
def __init__(self, in_c, out_c):
super(BlockTypeB, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU()
)
def forward(self, x):
x = self.conv1(x) + x
x = self.conv2(x)
return x
class BlockTypeC(nn.Module):
def __init__(self, in_c, out_c):
super(BlockTypeC, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
self.channel_pad = out_planes - in_planes
self.stride = stride
#padding = (kernel_size - 1) // 2
# TFLite uses slightly different padding than PyTorch
if stride == 2:
padding = 0
else:
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
def forward(self, x):
# TFLite uses different padding
if self.stride == 2:
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
#print(x.shape)
for module in self:
if not isinstance(module, nn.MaxPool2d):
x = module(x)
return x
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, pretrained=True):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
width_mult = 1.0
round_nearest = 8
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
#[6, 160, 3, 2],
#[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(4, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
self.features = nn.Sequential(*features)
self.fpn_selected = [1, 3, 6, 10, 13]
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
if pretrained:
self._load_pretrained_model()
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
fpn_features = []
for i, f in enumerate(self.features):
if i > self.fpn_selected[-1]:
break
x = f(x)
if i in self.fpn_selected:
fpn_features.append(x)
c1, c2, c3, c4, c5 = fpn_features
return c1, c2, c3, c4, c5
def forward(self, x):
return self._forward_impl(x)
def _load_pretrained_model(self):
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)
class MobileV2_MLSD_Large(nn.Module):
def __init__(self):
super(MobileV2_MLSD_Large, self).__init__()
self.backbone = MobileNetV2(pretrained=False)
## A, B
self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
out_c1= 64, out_c2=64,
upscale=False)
self.block16 = BlockTypeB(128, 64)
## A, B
self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
out_c1= 64, out_c2= 64)
self.block18 = BlockTypeB(128, 64)
## A, B
self.block19 = BlockTypeA(in_c1=24, in_c2=64,
out_c1=64, out_c2=64)
self.block20 = BlockTypeB(128, 64)
## A, B, C
self.block21 = BlockTypeA(in_c1=16, in_c2=64,
out_c1=64, out_c2=64)
self.block22 = BlockTypeB(128, 64)
self.block23 = BlockTypeC(64, 16)
def forward(self, x):
c1, c2, c3, c4, c5 = self.backbone(x)
x = self.block15(c4, c5)
x = self.block16(x)
x = self.block17(c3, x)
x = self.block18(x)
x = self.block19(c2, x)
x = self.block20(x)
x = self.block21(c1, x)
x = self.block22(x)
x = self.block23(x)
x = x[:, 7:, :, :]
return x

View File

@@ -0,0 +1,273 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
class BlockTypeA(nn.Module):
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
super(BlockTypeA, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c2, out_c2, kernel_size=1),
nn.BatchNorm2d(out_c2),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c1, out_c1, kernel_size=1),
nn.BatchNorm2d(out_c1),
nn.ReLU(inplace=True)
)
self.upscale = upscale
def forward(self, a, b):
b = self.conv1(b)
a = self.conv2(a)
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
return torch.cat((a, b), dim=1)
class BlockTypeB(nn.Module):
def __init__(self, in_c, out_c):
super(BlockTypeB, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU()
)
def forward(self, x):
x = self.conv1(x) + x
x = self.conv2(x)
return x
class BlockTypeC(nn.Module):
def __init__(self, in_c, out_c):
super(BlockTypeC, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
nn.BatchNorm2d(in_c),
nn.ReLU()
)
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
self.channel_pad = out_planes - in_planes
self.stride = stride
#padding = (kernel_size - 1) // 2
# TFLite uses slightly different padding than PyTorch
if stride == 2:
padding = 0
else:
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
def forward(self, x):
# TFLite uses different padding
if self.stride == 2:
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
#print(x.shape)
for module in self:
if not isinstance(module, nn.MaxPool2d):
x = module(x)
return x
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, pretrained=True):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
width_mult = 1.0
round_nearest = 8
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
#[6, 96, 3, 1],
#[6, 160, 3, 2],
#[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(4, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
self.features = nn.Sequential(*features)
self.fpn_selected = [3, 6, 10]
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
#if pretrained:
# self._load_pretrained_model()
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
fpn_features = []
for i, f in enumerate(self.features):
if i > self.fpn_selected[-1]:
break
x = f(x)
if i in self.fpn_selected:
fpn_features.append(x)
c2, c3, c4 = fpn_features
return c2, c3, c4
def forward(self, x):
return self._forward_impl(x)
def _load_pretrained_model(self):
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)
class MobileV2_MLSD_Tiny(nn.Module):
def __init__(self):
super(MobileV2_MLSD_Tiny, self).__init__()
self.backbone = MobileNetV2(pretrained=True)
self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
out_c1= 64, out_c2=64)
self.block13 = BlockTypeB(128, 64)
self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
out_c1= 32, out_c2= 32)
self.block15 = BlockTypeB(64, 64)
self.block16 = BlockTypeC(64, 16)
def forward(self, x):
c2, c3, c4 = self.backbone(x)
x = self.block12(c3, c4)
x = self.block13(x)
x = self.block14(c2, x)
x = self.block15(x)
x = self.block16(x)
x = x[:, 7:, :, :]
#print(x.shape)
x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
return x

View File

@@ -0,0 +1,587 @@
'''
modified by lihaoweicv
pytorch version
'''
'''
M-LSD
Copyright 2021-present NAVER Corp.
Apache License v2.0
'''
import cv2
import numpy as np
import torch
from torch.nn import functional as F
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
'''
tpMap:
center: tpMap[1, 0, :, :]
displacement: tpMap[1, 1:5, :, :]
'''
b, c, h, w = tpMap.shape
assert b==1, 'only support bsize==1'
displacement = tpMap[:, 1:5, :, :][0]
center = tpMap[:, 0, :, :]
heat = torch.sigmoid(center)
hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
keep = (hmax == heat).float()
heat = heat * keep
heat = heat.reshape(-1, )
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
yy = torch.floor_divide(indices, w).unsqueeze(-1)
xx = torch.fmod(indices, w).unsqueeze(-1)
ptss = torch.cat((yy, xx),dim=-1)
ptss = ptss.detach().cpu().numpy()
scores = scores.detach().cpu().numpy()
displacement = displacement.detach().cpu().numpy()
displacement = displacement.transpose((1,2,0))
return ptss, scores, displacement
def pred_lines(image, model,
input_shape=[512, 512],
score_thr=0.10,
dist_thr=20.0):
h, w, _ = image.shape
device = next(iter(model.parameters())).device
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
resized_image = resized_image.transpose((2,0,1))
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
batch_image = (batch_image / 127.5) - 1.0
batch_image = torch.from_numpy(batch_image).float()
batch_image = batch_image.to(device)
outputs = model(batch_image)
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
start = vmap[:, :, :2]
end = vmap[:, :, 2:]
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
segments_list = []
for center, score in zip(pts, pts_score, strict=False):
y, x = center
distance = dist_map[y, x]
if score > score_thr and distance > dist_thr:
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
x_start = x + disp_x_start
y_start = y + disp_y_start
x_end = x + disp_x_end
y_end = y + disp_y_end
segments_list.append([x_start, y_start, x_end, y_end])
if segments_list:
lines = 2 * np.array(segments_list) # 256 > 512
lines[:, 0] = lines[:, 0] * w_ratio
lines[:, 1] = lines[:, 1] * h_ratio
lines[:, 2] = lines[:, 2] * w_ratio
lines[:, 3] = lines[:, 3] * h_ratio
else:
# No segments detected - return empty array
lines = np.array([])
return lines
def pred_squares(image,
model,
input_shape=[512, 512],
params={'score': 0.06,
'outside_ratio': 0.28,
'inside_ratio': 0.45,
'w_overlap': 0.0,
'w_degree': 1.95,
'w_length': 0.0,
'w_area': 1.86,
'w_center': 0.14}):
'''
shape = [height, width]
'''
h, w, _ = image.shape
original_shape = [h, w]
device = next(iter(model.parameters())).device
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
resized_image = resized_image.transpose((2, 0, 1))
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
batch_image = (batch_image / 127.5) - 1.0
batch_image = torch.from_numpy(batch_image).float().to(device)
outputs = model(batch_image)
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
start = vmap[:, :, :2] # (x, y)
end = vmap[:, :, 2:] # (x, y)
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
junc_list = []
segments_list = []
for junc, score in zip(pts, pts_score, strict=False):
y, x = junc
distance = dist_map[y, x]
if score > params['score'] and distance > 20.0:
junc_list.append([x, y])
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
d_arrow = 1.0
x_start = x + d_arrow * disp_x_start
y_start = y + d_arrow * disp_y_start
x_end = x + d_arrow * disp_x_end
y_end = y + d_arrow * disp_y_end
segments_list.append([x_start, y_start, x_end, y_end])
segments = np.array(segments_list)
####### post processing for squares
# 1. get unique lines
point = np.array([[0, 0]])
point = point[0]
start = segments[:, :2]
end = segments[:, 2:]
diff = start - end
a = diff[:, 1]
b = -diff[:, 0]
c = a * start[:, 0] + b * start[:, 1]
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
theta[theta < 0.0] += 180
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
d_quant = 1
theta_quant = 2
hough[:, 0] //= d_quant
hough[:, 1] //= theta_quant
_, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
yx_indices = hough[indices, :].astype('int32')
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
acc_map_np = acc_map
# acc_map = acc_map[None, :, :, None]
#
# ### fast suppression using tensorflow op
# acc_map = tf.constant(acc_map, dtype=tf.float32)
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
# _, h, w, _ = acc_map.shape
# y = tf.expand_dims(topk_indices // w, axis=-1)
# x = tf.expand_dims(topk_indices % w, axis=-1)
# yx = tf.concat([y, x], axis=-1)
### fast suppression using pytorch op
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
_,_, h, w = acc_map.shape
max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
acc_map = acc_map * ( (acc_map == max_acc_map).float() )
flatten_acc_map = acc_map.reshape([-1, ])
scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
xx = torch.fmod(indices, w).unsqueeze(-1)
yx = torch.cat((yy, xx), dim=-1)
yx = yx.detach().cpu().numpy()
topk_values = scores.detach().cpu().numpy()
indices = idx_map[yx[:, 0], yx[:, 1]]
basis = 5 // 2
merged_segments = []
for yx_pt, max_indice, value in zip(yx, indices, topk_values, strict=False):
y, x = yx_pt
if max_indice == -1 or value == 0:
continue
segment_list = []
for y_offset in range(-basis, basis + 1):
for x_offset in range(-basis, basis + 1):
indice = idx_map[y + y_offset, x + x_offset]
cnt = int(acc_map_np[y + y_offset, x + x_offset])
if indice != -1:
segment_list.append(segments[indice])
if cnt > 1:
check_cnt = 1
current_hough = hough[indice]
for new_indice, new_hough in enumerate(hough):
if (current_hough == new_hough).all() and indice != new_indice:
segment_list.append(segments[new_indice])
check_cnt += 1
if check_cnt == cnt:
break
group_segments = np.array(segment_list).reshape([-1, 2])
sorted_group_segments = np.sort(group_segments, axis=0)
x_min, y_min = sorted_group_segments[0, :]
x_max, y_max = sorted_group_segments[-1, :]
deg = theta[max_indice]
if deg >= 90:
merged_segments.append([x_min, y_max, x_max, y_min])
else:
merged_segments.append([x_min, y_min, x_max, y_max])
# 2. get intersections
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
start = new_segments[:, :2] # (x1, y1)
end = new_segments[:, 2:] # (x2, y2)
new_centers = (start + end) / 2.0
diff = start - end
dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
# ax + by = c
a = diff[:, 1]
b = -diff[:, 0]
c = a * start[:, 0] + b * start[:, 1]
pre_det = a[:, None] * b[None, :]
det = pre_det - np.transpose(pre_det)
pre_inter_y = a[:, None] * c[None, :]
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
pre_inter_x = c[:, None] * b[None, :]
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
# 3. get corner information
# 3.1 get distance
'''
dist_segments:
| dist(0), dist(1), dist(2), ...|
dist_inter_to_segment1:
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
...
dist_inter_to_semgnet2:
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
...
'''
dist_inter_to_segment1_start = np.sqrt(
np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
dist_inter_to_segment1_end = np.sqrt(
np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
dist_inter_to_segment2_start = np.sqrt(
np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
dist_inter_to_segment2_end = np.sqrt(
np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
# sort ascending
dist_inter_to_segment1 = np.sort(
np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
axis=-1) # [n_batch, n_batch, 2]
dist_inter_to_segment2 = np.sort(
np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
axis=-1) # [n_batch, n_batch, 2]
# 3.2 get degree
inter_to_start = new_centers[:, None, :] - inter_pts
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
inter_to_end = new_centers[None, :, :] - inter_pts
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
'''
B -- G
| |
C -- R
B : blue / G: green / C: cyan / R: red
0 -- 1
| |
3 -- 2
'''
# rename variables
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
# sort deg ascending
deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
deg_diff_map = np.abs(deg1_map - deg2_map)
# we only consider the smallest degree of intersect
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
# define available degree range
deg_range = [60, 120]
corner_dict = {corner_info: [] for corner_info in range(4)}
inter_points = []
for i in range(inter_pts.shape[0]):
for j in range(i + 1, inter_pts.shape[1]):
# i, j > line index, always i < j
x, y = inter_pts[i, j, :]
deg1, deg2 = deg_sort[i, j, :]
deg_diff = deg_diff_map[i, j]
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
if check_degree and check_distance:
corner_info = None
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
corner_info, color_info = 0, 'blue'
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
corner_info, color_info = 1, 'green'
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
corner_info, color_info = 2, 'black'
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
corner_info, color_info = 3, 'cyan'
else:
corner_info, color_info = 4, 'red' # we don't use it
continue
corner_dict[corner_info].append([x, y, i, j])
inter_points.append([x, y])
square_list = []
connect_list = []
segments_list = []
for corner0 in corner_dict[0]:
for corner1 in corner_dict[1]:
connect01 = False
for corner0_line in corner0[2:]:
if corner0_line in corner1[2:]:
connect01 = True
break
if connect01:
for corner2 in corner_dict[2]:
connect12 = False
for corner1_line in corner1[2:]:
if corner1_line in corner2[2:]:
connect12 = True
break
if connect12:
for corner3 in corner_dict[3]:
connect23 = False
for corner2_line in corner2[2:]:
if corner2_line in corner3[2:]:
connect23 = True
break
if connect23:
for corner3_line in corner3[2:]:
if corner3_line in corner0[2:]:
# SQUARE!!!
'''
0 -- 1
| |
3 -- 2
square_list:
order: 0 > 1 > 2 > 3
| x0, y0, x1, y1, x2, y2, x3, y3 |
| x0, y0, x1, y1, x2, y2, x3, y3 |
...
connect_list:
order: 01 > 12 > 23 > 30
| line_idx01, line_idx12, line_idx23, line_idx30 |
| line_idx01, line_idx12, line_idx23, line_idx30 |
...
segments_list:
order: 0 > 1 > 2 > 3
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
...
'''
square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
def check_outside_inside(segments_info, connect_idx):
# return 'outside or inside', min distance, cover_param, peri_param
if connect_idx == segments_info[0]:
check_dist_mat = dist_inter_to_segment1
else:
check_dist_mat = dist_inter_to_segment2
i, j = segments_info
min_dist, max_dist = check_dist_mat[i, j, :]
connect_dist = dist_segments[connect_idx]
if max_dist > connect_dist:
return 'outside', min_dist, 0, 1
else:
return 'inside', min_dist, -1, -1
top_square = None
try:
map_size = input_shape[0] / 2
squares = np.array(square_list).reshape([-1, 4, 2])
score_array = []
connect_array = np.array(connect_list)
segments_array = np.array(segments_list).reshape([-1, 4, 2])
# get degree of corners:
squares_rollup = np.roll(squares, 1, axis=1)
squares_rolldown = np.roll(squares, -1, axis=1)
vec1 = squares_rollup - squares
normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
vec2 = squares_rolldown - squares
normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
# get square score
overlap_scores = []
degree_scores = []
length_scores = []
for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree, strict=False):
'''
0 -- 1
| |
3 -- 2
# segments: [4, 2]
# connects: [4]
'''
###################################### OVERLAP SCORES
cover = 0
perimeter = 0
# check 0 > 1 > 2 > 3
square_length = []
for start_idx in range(4):
end_idx = (start_idx + 1) % 4
connect_idx = connects[start_idx] # segment idx of segment01
start_segments = segments[start_idx]
end_segments = segments[end_idx]
start_point = square[start_idx]
end_point = square[end_idx]
# check whether outside or inside
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
connect_idx)
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
square_length.append(
dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
overlap_scores.append(cover / perimeter)
######################################
###################################### DEGREE SCORES
'''
deg0 vs deg2
deg1 vs deg3
'''
deg0, deg1, deg2, deg3 = degree
deg_ratio1 = deg0 / deg2
if deg_ratio1 > 1.0:
deg_ratio1 = 1 / deg_ratio1
deg_ratio2 = deg1 / deg3
if deg_ratio2 > 1.0:
deg_ratio2 = 1 / deg_ratio2
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
######################################
###################################### LENGTH SCORES
'''
len0 vs len2
len1 vs len3
'''
len0, len1, len2, len3 = square_length
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
length_scores.append((len_ratio1 + len_ratio2) / 2)
######################################
overlap_scores = np.array(overlap_scores)
overlap_scores /= np.max(overlap_scores)
degree_scores = np.array(degree_scores)
# degree_scores /= np.max(degree_scores)
length_scores = np.array(length_scores)
###################################### AREA SCORES
area_scores = np.reshape(squares, [-1, 4, 2])
area_x = area_scores[:, :, 0]
area_y = area_scores[:, :, 1]
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
area_scores = 0.5 * np.abs(area_scores + correction)
area_scores /= (map_size * map_size) # np.max(area_scores)
######################################
###################################### CENTER SCORES
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
# squares: [n, 4, 2]
square_centers = np.mean(squares, axis=1) # [n, 2]
center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
center_scores = center2center / (map_size / np.sqrt(2.0))
'''
score_w = [overlap, degree, area, center, length]
'''
score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
score_array = params['w_overlap'] * overlap_scores \
+ params['w_degree'] * degree_scores \
+ params['w_area'] * area_scores \
- params['w_center'] * center_scores \
+ params['w_length'] * length_scores
best_square = []
sorted_idx = np.argsort(score_array)[::-1]
score_array = score_array[sorted_idx]
squares = squares[sorted_idx]
except Exception:
pass
'''return list
merged_lines, squares, scores
'''
try:
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
except:
new_segments = []
try:
squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
except:
squares = []
score_array = []
try:
inter_points = np.array(inter_points)
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
except:
inter_points = []
return new_segments, squares, score_array, inter_points

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 Caroline Chan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,93 @@
# Adapted from https://github.com/huggingface/controlnet_aux
import pathlib
import types
import cv2
import huggingface_hub
import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
class NormalMapDetector:
"""Simple wrapper around the Normal BAE model for normal map generation."""
hf_repo_id = "lllyasviel/Annotators"
hf_filename = "scannet.pt"
@classmethod
def get_model_url(cls) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> NNET:
"""Load the model from a file."""
args = types.SimpleNamespace()
args.mode = "client"
args.architecture = "BN"
args.pretrained = "scannet"
args.sampling_ratio = 0.4
args.importance_ratio = 0.7
model = NNET(args)
ckpt = torch.load(model_path, map_location="cpu")["model"]
load_dict = {}
for k, v in ckpt.items():
if k.startswith("module."):
k_ = k.replace("module.", "")
load_dict[k_] = v
else:
load_dict[k] = v
model.load_state_dict(load_dict)
model.eval()
return model
def __init__(self, model: NNET) -> None:
self.model = model
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def to(self, device: torch.device):
self.model.to(device)
return self
def run(self, image: Image.Image):
"""Processes an image and returns the detected normal map."""
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)
height, width, _channels = np_image.shape
# The model requires the image to be a multiple of 8
np_image = resize_to_multiple(np_image, 8)
image_normal = np_image
with torch.no_grad():
image_normal = torch.from_numpy(image_normal).float().to(device)
image_normal = image_normal / 255.0
image_normal = rearrange(image_normal, "h w c -> 1 c h w")
image_normal = self.norm(image_normal)
normal = self.model(image_normal)
normal = normal[0][-1][:, :3]
normal = ((normal + 1) * 0.5).clip(0, 1)
normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
# Back to the original size
output_image = cv2.resize(normal_image, (width, height), interpolation=cv2.INTER_LINEAR)
return np_to_pil(output_image)

View File

@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodules.encoder import Encoder
from .submodules.decoder import Decoder
class NNET(nn.Module):
def __init__(self, args):
super(NNET, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(args)
def get_1x_lr_params(self): # lr/10 learning rate
return self.encoder.parameters()
def get_10x_lr_params(self): # lr learning rate
return self.decoder.parameters()
def forward(self, img, **kwargs):
return self.decoder(self.encoder(img), **kwargs)

View File

@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodules.submodules import UpSampleBN, norm_normalize
# This is the baseline encoder-decoder we used in the ablation study
class NNET(nn.Module):
def __init__(self, args=None):
super(NNET, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(num_classes=4)
def forward(self, x, **kwargs):
out = self.decoder(self.encoder(x), **kwargs)
# Bilinearly upsample the output to match the input resolution
up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
# L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
up_out = norm_normalize(up_out)
return up_out
def get_1x_lr_params(self): # lr/10 learning rate
return self.encoder.parameters()
def get_10x_lr_params(self): # lr learning rate
modules = [self.decoder]
for m in modules:
yield from m.parameters()
# Encoder
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
basemodel_name = 'tf_efficientnet_b5_ap'
basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
# Remove last layer
basemodel.global_pool = nn.Identity()
basemodel.classifier = nn.Identity()
self.original_model = basemodel
def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items():
if (k == 'blocks'):
for ki, vi in v._modules.items():
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
return features
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
class Decoder(nn.Module):
def __init__(self, num_classes=4):
super(Decoder, self).__init__()
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
def forward(self, features):
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
x_d0 = self.conv2(x_block4)
x_d1 = self.up1(x_d0, x_block3)
x_d2 = self.up2(x_d1, x_block2)
x_d3 = self.up3(x_d2, x_block1)
x_d4 = self.up4(x_d3, x_block0)
out = self.conv3(x_d4)
return out
if __name__ == '__main__':
model = Baseline()
x = torch.rand(2, 3, 480, 640)
out = model(x)
print(out.shape)

View File

@@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
class Decoder(nn.Module):
def __init__(self, args):
super(Decoder, self).__init__()
# hyper-parameter for sampling
self.sampling_ratio = args.sampling_ratio
self.importance_ratio = args.importance_ratio
# feature-map
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
if args.architecture == 'BN':
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
elif args.architecture == 'GN':
self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
else:
raise Exception('invalid architecture')
# produces 1/8 res output
self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
# produces 1/4 res output
self.out_conv_res4 = nn.Sequential(
nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
# produces 1/2 res output
self.out_conv_res2 = nn.Sequential(
nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
# produces 1/1 res output
self.out_conv_res1 = nn.Sequential(
nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
def forward(self, features, gt_norm_mask=None, mode='test'):
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
# generate feature-map
x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
# 1/8 res output
out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
################################################################################################################
# out_res4
################################################################################################################
if mode == 'train':
# upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
B, _, H, W = out_res8_res4.shape
# samples: [B, 1, N, 2]
point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
sampling_ratio=self.sampling_ratio,
beta=self.importance_ratio)
# output (needed for evaluation / visualization)
out_res4 = out_res8_res4
# grid_sample feature-map
feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
# prediction (needed to compute loss)
samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
for i in range(B):
out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
else:
# grid_sample feature-map
feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
# try all pixels
out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
out_res4 = out_res4.view(B, 4, H, W)
samples_pred_res4 = point_coords_res4 = None
################################################################################################################
# out_res2
################################################################################################################
if mode == 'train':
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
B, _, H, W = out_res4_res2.shape
# samples: [B, 1, N, 2]
point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
sampling_ratio=self.sampling_ratio,
beta=self.importance_ratio)
# output (needed for evaluation / visualization)
out_res2 = out_res4_res2
# grid_sample feature-map
feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
# prediction (needed to compute loss)
samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
for i in range(B):
out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
else:
# grid_sample feature-map
feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
out_res2 = out_res2.view(B, 4, H, W)
samples_pred_res2 = point_coords_res2 = None
################################################################################################################
# out_res1
################################################################################################################
if mode == 'train':
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
B, _, H, W = out_res2_res1.shape
# samples: [B, 1, N, 2]
point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
sampling_ratio=self.sampling_ratio,
beta=self.importance_ratio)
# output (needed for evaluation / visualization)
out_res1 = out_res2_res1
# grid_sample feature-map
feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
# prediction (needed to compute loss)
samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
for i in range(B):
out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
else:
# grid_sample feature-map
feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
out_res1 = out_res1.view(B, 4, H, W)
samples_pred_res1 = point_coords_res1 = None
return [out_res8, out_res4, out_res2, out_res1], \
[out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
[None, point_coords_res4, point_coords_res2, point_coords_res1]

View File

@@ -0,0 +1,109 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# pytorch stuff
*.pth
*.onnx
*.pb
trained_models/
.fuse_hidden*

View File

@@ -0,0 +1,555 @@
# Model Performance Benchmarks
All benchmarks run as per:
```
python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx
python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx
python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3
python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt
python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb
python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb
```
## EfficientNet-B0
### Unoptimized
```
Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897
Time per operator type:
29.7378 ms. 60.5145%. Conv
12.1785 ms. 24.7824%. Sigmoid
3.62811 ms. 7.38297%. SpatialBN
2.98444 ms. 6.07314%. Mul
0.326902 ms. 0.665225%. AveragePool
0.197317 ms. 0.401528%. FC
0.0852877 ms. 0.173555%. Add
0.0032607 ms. 0.00663532%. Squeeze
49.1416 ms in Total
FLOP per operator type:
0.76907 GFLOP. 95.2696%. Conv
0.0269508 GFLOP. 3.33857%. SpatialBN
0.00846444 GFLOP. 1.04855%. Mul
0.002561 GFLOP. 0.317248%. FC
0.000210112 GFLOP. 0.0260279%. Add
0.807256 GFLOP in Total
Feature Memory Read per operator type:
58.5253 MB. 43.0891%. Mul
43.2015 MB. 31.807%. Conv
27.2869 MB. 20.0899%. SpatialBN
5.12912 MB. 3.77631%. FC
1.6809 MB. 1.23756%. Add
135.824 MB in Total
Feature Memory Written per operator type:
33.8578 MB. 38.1965%. Mul
26.9881 MB. 30.4465%. Conv
26.9508 MB. 30.4044%. SpatialBN
0.840448 MB. 0.948147%. Add
0.004 MB. 0.00451258%. FC
88.6412 MB in Total
Parameter Memory per operator type:
15.8248 MB. 74.9391%. Conv
5.124 MB. 24.265%. FC
0.168064 MB. 0.795877%. SpatialBN
0 MB. 0%. Add
0 MB. 0%. Mul
21.1168 MB in Total
```
### Optimized
```
Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996
Time per operator type:
29.776 ms. 65.002%. Conv
12.2803 ms. 26.8084%. Sigmoid
3.15073 ms. 6.87815%. Mul
0.328651 ms. 0.717456%. AveragePool
0.186237 ms. 0.406563%. FC
0.0832429 ms. 0.181722%. Add
0.0026184 ms. 0.00571606%. Squeeze
45.8078 ms in Total
FLOP per operator type:
0.76907 GFLOP. 98.5601%. Conv
0.00846444 GFLOP. 1.08476%. Mul
0.002561 GFLOP. 0.328205%. FC
0.000210112 GFLOP. 0.0269269%. Add
0.780305 GFLOP in Total
Feature Memory Read per operator type:
58.5253 MB. 53.8803%. Mul
43.2855 MB. 39.8501%. Conv
5.12912 MB. 4.72204%. FC
1.6809 MB. 1.54749%. Add
108.621 MB in Total
Feature Memory Written per operator type:
33.8578 MB. 54.8834%. Mul
26.9881 MB. 43.7477%. Conv
0.840448 MB. 1.36237%. Add
0.004 MB. 0.00648399%. FC
61.6904 MB in Total
Parameter Memory per operator type:
15.8248 MB. 75.5403%. Conv
5.124 MB. 24.4597%. FC
0 MB. 0%. Add
0 MB. 0%. Mul
20.9488 MB in Total
```
## EfficientNet-B1
### Optimized
```
Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256
Time per operator type:
45.7915 ms. 66.3206%. Conv
17.8718 ms. 25.8841%. Sigmoid
4.44132 ms. 6.43244%. Mul
0.51001 ms. 0.738658%. AveragePool
0.233283 ms. 0.337868%. Add
0.194986 ms. 0.282402%. FC
0.00268255 ms. 0.00388519%. Squeeze
69.0456 ms in Total
FLOP per operator type:
1.37105 GFLOP. 98.7673%. Conv
0.0138759 GFLOP. 0.99959%. Mul
0.002561 GFLOP. 0.184489%. FC
0.000674432 GFLOP. 0.0485847%. Add
1.38816 GFLOP in Total
Feature Memory Read per operator type:
94.624 MB. 54.0789%. Mul
69.8255 MB. 39.9062%. Conv
5.39546 MB. 3.08357%. Add
5.12912 MB. 2.93136%. FC
174.974 MB in Total
Feature Memory Written per operator type:
55.5035 MB. 54.555%. Mul
43.5333 MB. 42.7894%. Conv
2.69773 MB. 2.65163%. Add
0.004 MB. 0.00393165%. FC
101.739 MB in Total
Parameter Memory per operator type:
25.7479 MB. 83.4024%. Conv
5.124 MB. 16.5976%. FC
0 MB. 0%. Add
0 MB. 0%. Mul
30.8719 MB in Total
```
## EfficientNet-B2
### Optimized
```
Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366
Time per operator type:
61.4627 ms. 67.5845%. Conv
22.7458 ms. 25.0113%. Sigmoid
5.59931 ms. 6.15701%. Mul
0.642567 ms. 0.706568%. AveragePool
0.272795 ms. 0.299965%. Add
0.216178 ms. 0.237709%. FC
0.00268895 ms. 0.00295677%. Squeeze
90.942 ms in Total
FLOP per operator type:
1.98431 GFLOP. 98.9343%. Conv
0.0177039 GFLOP. 0.882686%. Mul
0.002817 GFLOP. 0.140451%. FC
0.000853984 GFLOP. 0.0425782%. Add
2.00568 GFLOP in Total
Feature Memory Read per operator type:
120.609 MB. 54.9637%. Mul
86.3512 MB. 39.3519%. Conv
6.83187 MB. 3.11341%. Add
5.64163 MB. 2.571%. FC
219.433 MB in Total
Feature Memory Written per operator type:
70.8155 MB. 54.6573%. Mul
55.3273 MB. 42.7031%. Conv
3.41594 MB. 2.63651%. Add
0.004 MB. 0.00308731%. FC
129.563 MB in Total
Parameter Memory per operator type:
30.4721 MB. 84.3913%. Conv
5.636 MB. 15.6087%. FC
0 MB. 0%. Add
0 MB. 0%. Mul
36.1081 MB in Total
```
## MixNet-M
### Optimized
```
Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448
Time per operator type:
48.1139 ms. 75.2052%. Conv
7.1341 ms. 11.1511%. Sigmoid
2.63706 ms. 4.12189%. SpatialBN
1.73186 ms. 2.70701%. Mul
1.38707 ms. 2.16809%. Split
1.29322 ms. 2.02139%. Concat
1.00093 ms. 1.56452%. Relu
0.235309 ms. 0.367803%. Add
0.221579 ms. 0.346343%. FC
0.219315 ms. 0.342803%. AveragePool
0.00250145 ms. 0.00390993%. Squeeze
63.9768 ms in Total
FLOP per operator type:
0.675273 GFLOP. 95.5827%. Conv
0.0221072 GFLOP. 3.12921%. SpatialBN
0.00538445 GFLOP. 0.762152%. Mul
0.003073 GFLOP. 0.434973%. FC
0.000642488 GFLOP. 0.0909421%. Add
0 GFLOP. 0%. Concat
0 GFLOP. 0%. Relu
0.70648 GFLOP in Total
Feature Memory Read per operator type:
46.8424 MB. 30.502%. Conv
36.8626 MB. 24.0036%. Mul
22.3152 MB. 14.5309%. SpatialBN
22.1074 MB. 14.3955%. Concat
14.1496 MB. 9.21372%. Relu
6.15414 MB. 4.00735%. FC
5.1399 MB. 3.34692%. Add
153.571 MB in Total
Feature Memory Written per operator type:
32.7672 MB. 28.4331%. Conv
22.1072 MB. 19.1831%. Concat
22.1072 MB. 19.1831%. SpatialBN
21.5378 MB. 18.689%. Mul
14.1496 MB. 12.2781%. Relu
2.56995 MB. 2.23003%. Add
0.004 MB. 0.00347092%. FC
115.243 MB in Total
Parameter Memory per operator type:
13.7059 MB. 68.674%. Conv
6.148 MB. 30.8049%. FC
0.104 MB. 0.521097%. SpatialBN
0 MB. 0%. Add
0 MB. 0%. Concat
0 MB. 0%. Mul
0 MB. 0%. Relu
19.9579 MB in Total
```
## TF MobileNet-V3 Large 1.0
### Optimized
```
Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525
Time per operator type:
17.437 ms. 80.0087%. Conv
1.27662 ms. 5.8577%. Add
1.12759 ms. 5.17387%. Div
0.701155 ms. 3.21721%. Mul
0.562654 ms. 2.58171%. Relu
0.431144 ms. 1.97828%. Clip
0.156902 ms. 0.719936%. FC
0.0996858 ms. 0.457402%. AveragePool
0.00112455 ms. 0.00515993%. Flatten
21.7939 ms in Total
FLOP per operator type:
0.43062 GFLOP. 98.1484%. Conv
0.002561 GFLOP. 0.583713%. FC
0.00210867 GFLOP. 0.480616%. Mul
0.00193868 GFLOP. 0.441871%. Add
0.00151532 GFLOP. 0.345377%. Div
0 GFLOP. 0%. Relu
0.438743 GFLOP in Total
Feature Memory Read per operator type:
34.7967 MB. 43.9391%. Conv
14.496 MB. 18.3046%. Mul
9.44828 MB. 11.9307%. Add
9.26157 MB. 11.6949%. Relu
6.0614 MB. 7.65395%. Div
5.12912 MB. 6.47673%. FC
79.193 MB in Total
Feature Memory Written per operator type:
17.6247 MB. 35.8656%. Conv
9.26157 MB. 18.847%. Relu
8.43469 MB. 17.1643%. Mul
7.75472 MB. 15.7806%. Add
6.06128 MB. 12.3345%. Div
0.004 MB. 0.00813985%. FC
49.1409 MB in Total
Parameter Memory per operator type:
16.6851 MB. 76.5052%. Conv
5.124 MB. 23.4948%. FC
0 MB. 0%. Add
0 MB. 0%. Div
0 MB. 0%. Mul
0 MB. 0%. Relu
21.8091 MB in Total
```
## MobileNet-V3 (RW)
### Unoptimized
```
Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712
Time per operator type:
15.9266 ms. 69.2624%. Conv
2.36551 ms. 10.2873%. SpatialBN
1.39102 ms. 6.04936%. Add
1.30327 ms. 5.66773%. Div
0.737014 ms. 3.20517%. Mul
0.639697 ms. 2.78195%. Relu
0.375681 ms. 1.63378%. Clip
0.153126 ms. 0.665921%. FC
0.0993787 ms. 0.432184%. AveragePool
0.0032632 ms. 0.0141912%. Squeeze
22.9946 ms in Total
FLOP per operator type:
0.430616 GFLOP. 94.4041%. Conv
0.0175992 GFLOP. 3.85829%. SpatialBN
0.002561 GFLOP. 0.561449%. FC
0.00210961 GFLOP. 0.46249%. Mul
0.00173891 GFLOP. 0.381223%. Add
0.00151626 GFLOP. 0.33241%. Div
0 GFLOP. 0%. Relu
0.456141 GFLOP in Total
Feature Memory Read per operator type:
34.7354 MB. 36.4363%. Conv
17.7944 MB. 18.6658%. SpatialBN
14.5035 MB. 15.2137%. Mul
9.25778 MB. 9.71113%. Relu
7.84641 MB. 8.23064%. Add
6.06516 MB. 6.36216%. Div
5.12912 MB. 5.38029%. FC
95.3317 MB in Total
Feature Memory Written per operator type:
17.6246 MB. 26.7264%. Conv
17.5992 MB. 26.6878%. SpatialBN
9.25778 MB. 14.0387%. Relu
8.43843 MB. 12.7962%. Mul
6.95565 MB. 10.5477%. Add
6.06502 MB. 9.19713%. Div
0.004 MB. 0.00606568%. FC
65.9447 MB in Total
Parameter Memory per operator type:
16.6778 MB. 76.1564%. Conv
5.124 MB. 23.3979%. FC
0.0976 MB. 0.445674%. SpatialBN
0 MB. 0%. Add
0 MB. 0%. Div
0 MB. 0%. Mul
0 MB. 0%. Relu
21.8994 MB in Total
```
### Optimized
```
Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527
Time per operator type:
17.146 ms. 78.8965%. Conv
1.38453 ms. 6.37084%. Add
1.30991 ms. 6.02749%. Div
0.685417 ms. 3.15391%. Mul
0.532589 ms. 2.45068%. Relu
0.418263 ms. 1.92461%. Clip
0.15128 ms. 0.696106%. FC
0.102065 ms. 0.469648%. AveragePool
0.0022143 ms. 0.010189%. Squeeze
21.7323 ms in Total
FLOP per operator type:
0.430616 GFLOP. 98.1927%. Conv
0.002561 GFLOP. 0.583981%. FC
0.00210961 GFLOP. 0.481051%. Mul
0.00173891 GFLOP. 0.396522%. Add
0.00151626 GFLOP. 0.34575%. Div
0 GFLOP. 0%. Relu
0.438542 GFLOP in Total
Feature Memory Read per operator type:
34.7842 MB. 44.833%. Conv
14.5035 MB. 18.6934%. Mul
9.25778 MB. 11.9323%. Relu
7.84641 MB. 10.1132%. Add
6.06516 MB. 7.81733%. Div
5.12912 MB. 6.61087%. FC
77.5861 MB in Total
Feature Memory Written per operator type:
17.6246 MB. 36.4556%. Conv
9.25778 MB. 19.1492%. Relu
8.43843 MB. 17.4544%. Mul
6.95565 MB. 14.3874%. Add
6.06502 MB. 12.5452%. Div
0.004 MB. 0.00827378%. FC
48.3455 MB in Total
Parameter Memory per operator type:
16.6778 MB. 76.4973%. Conv
5.124 MB. 23.5027%. FC
0 MB. 0%. Add
0 MB. 0%. Div
0 MB. 0%. Mul
0 MB. 0%. Relu
21.8018 MB in Total
```
## MnasNet-A1
### Unoptimized
```
Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345
Time per operator type:
24.4656 ms. 79.0905%. Conv
4.14958 ms. 13.4144%. SpatialBN
1.60598 ms. 5.19169%. Relu
0.295219 ms. 0.95436%. Mul
0.187609 ms. 0.606486%. FC
0.120556 ms. 0.389724%. AveragePool
0.09036 ms. 0.292109%. Add
0.015727 ms. 0.050841%. Sigmoid
0.00306205 ms. 0.00989875%. Squeeze
30.9337 ms in Total
FLOP per operator type:
0.620598 GFLOP. 95.6434%. Conv
0.0248873 GFLOP. 3.8355%. SpatialBN
0.002561 GFLOP. 0.394688%. FC
0.000597408 GFLOP. 0.0920695%. Mul
0.000222656 GFLOP. 0.0343146%. Add
0 GFLOP. 0%. Relu
0.648867 GFLOP in Total
Feature Memory Read per operator type:
35.5457 MB. 38.4109%. Conv
25.1552 MB. 27.1829%. SpatialBN
22.5235 MB. 24.339%. Relu
5.12912 MB. 5.54256%. FC
2.40586 MB. 2.59978%. Mul
1.78125 MB. 1.92483%. Add
92.5406 MB in Total
Feature Memory Written per operator type:
24.9042 MB. 32.9424%. Conv
24.8873 MB. 32.92%. SpatialBN
22.5235 MB. 29.7932%. Relu
2.38963 MB. 3.16092%. Mul
0.890624 MB. 1.17809%. Add
0.004 MB. 0.00529106%. FC
75.5993 MB in Total
Parameter Memory per operator type:
10.2732 MB. 66.1459%. Conv
5.124 MB. 32.9917%. FC
0.133952 MB. 0.86247%. SpatialBN
0 MB. 0%. Add
0 MB. 0%. Mul
0 MB. 0%. Relu
15.5312 MB in Total
```
### Optimized
```
Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597
Time per operator type:
22.0547 ms. 91.1375%. Conv
1.49096 ms. 6.16116%. Relu
0.253417 ms. 1.0472%. Mul
0.18506 ms. 0.76473%. FC
0.112942 ms. 0.466717%. AveragePool
0.086769 ms. 0.358559%. Add
0.0127889 ms. 0.0528479%. Sigmoid
0.0027346 ms. 0.0113003%. Squeeze
24.1994 ms in Total
FLOP per operator type:
0.620598 GFLOP. 99.4581%. Conv
0.002561 GFLOP. 0.41043%. FC
0.000597408 GFLOP. 0.0957417%. Mul
0.000222656 GFLOP. 0.0356832%. Add
0 GFLOP. 0%. Relu
0.623979 GFLOP in Total
Feature Memory Read per operator type:
35.6127 MB. 52.7968%. Conv
22.5235 MB. 33.3917%. Relu
5.12912 MB. 7.60406%. FC
2.40586 MB. 3.56675%. Mul
1.78125 MB. 2.64075%. Add
67.4524 MB in Total
Feature Memory Written per operator type:
24.9042 MB. 49.1092%. Conv
22.5235 MB. 44.4145%. Relu
2.38963 MB. 4.71216%. Mul
0.890624 MB. 1.75624%. Add
0.004 MB. 0.00788768%. FC
50.712 MB in Total
Parameter Memory per operator type:
10.2732 MB. 66.7213%. Conv
5.124 MB. 33.2787%. FC
0 MB. 0%. Add
0 MB. 0%. Mul
0 MB. 0%. Relu
15.3972 MB in Total
```
## MnasNet-B1
### Unoptimized
```
Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322
Time per operator type:
29.1121 ms. 83.3081%. Conv
4.14959 ms. 11.8746%. SpatialBN
1.35823 ms. 3.88675%. Relu
0.186188 ms. 0.532802%. FC
0.116244 ms. 0.332647%. Add
0.018641 ms. 0.0533437%. AveragePool
0.0040904 ms. 0.0117052%. Squeeze
34.9451 ms in Total
FLOP per operator type:
0.626272 GFLOP. 96.2088%. Conv
0.0218266 GFLOP. 3.35303%. SpatialBN
0.002561 GFLOP. 0.393424%. FC
0.000291648 GFLOP. 0.0448034%. Add
0 GFLOP. 0%. Relu
0.650951 GFLOP in Total
Feature Memory Read per operator type:
34.4354 MB. 41.3788%. Conv
22.1299 MB. 26.5921%. SpatialBN
19.1923 MB. 23.0622%. Relu
5.12912 MB. 6.16333%. FC
2.33318 MB. 2.80364%. Add
83.2199 MB in Total
Feature Memory Written per operator type:
21.8266 MB. 34.0955%. Conv
21.8266 MB. 34.0955%. SpatialBN
19.1923 MB. 29.9805%. Relu
1.16659 MB. 1.82234%. Add
0.004 MB. 0.00624844%. FC
64.016 MB in Total
Parameter Memory per operator type:
12.2576 MB. 69.9104%. Conv
5.124 MB. 29.2245%. FC
0.15168 MB. 0.865099%. SpatialBN
0 MB. 0%. Add
0 MB. 0%. Relu
17.5332 MB in Total
```
### Optimized
```
Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426
Time per operator type:
24.9888 ms. 94.0962%. Conv
1.26147 ms. 4.75011%. Relu
0.176234 ms. 0.663619%. FC
0.113309 ms. 0.426672%. Add
0.0138708 ms. 0.0522311%. AveragePool
0.00295685 ms. 0.0111341%. Squeeze
26.5566 ms in Total
FLOP per operator type:
0.626272 GFLOP. 99.5466%. Conv
0.002561 GFLOP. 0.407074%. FC
0.000291648 GFLOP. 0.0463578%. Add
0 GFLOP. 0%. Relu
0.629124 GFLOP in Total
Feature Memory Read per operator type:
34.5112 MB. 56.4224%. Conv
19.1923 MB. 31.3775%. Relu
5.12912 MB. 8.3856%. FC
2.33318 MB. 3.81452%. Add
61.1658 MB in Total
Feature Memory Written per operator type:
21.8266 MB. 51.7346%. Conv
19.1923 MB. 45.4908%. Relu
1.16659 MB. 2.76513%. Add
0.004 MB. 0.00948104%. FC
42.1895 MB in Total
Parameter Memory per operator type:
12.2576 MB. 70.5205%. Conv
5.124 MB. 29.4795%. FC
0 MB. 0%. Add
0 MB. 0%. Relu
17.3816 MB in Total
```

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 Ross Wightman
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,323 @@
# (Generic) EfficientNets for PyTorch
A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.
All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py))
## What's New
### Aug 19, 2020
* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
* ONNX runtime based validation script added
* activations (mostly) brought in sync with `timm` equivalents
### April 5, 2020
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
* 3.5M param MobileNet-V2 100 @ 73%
* 4.5M param MobileNet-V2 110d @ 75%
* 6.1M param MobileNet-V2 140 @ 76.5%
* 5.8M param MobileNet-V2 120d @ 77.3%
### March 23, 2020
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
### Feb 12, 2020
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
* Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)
### Jan 22, 2020
* Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models)
* Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
* Test models, torchscript, onnx export with PyTorch 1.4 -- no issues
### Nov 22, 2019
* New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different
preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.
### Nov 15, 2019
* Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
* Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine
### Oct 30, 2019
* Many of the models will now work with torch.jit.script, MixNet being the biggest exception
* Improved interface for enabling torchscript or ONNX export compatible modes (via config)
* Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn
* Activation factory to select best version of activation by name or override one globally
* Add pretrained checkpoint load helper that handles input conv and classifier changes
### Oct 27, 2019
* Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
* Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
* Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
* Switch activations and global pooling to modules
* Add memory-efficient Swish/Mish impl
* Add as_sequential() method to all models and allow as an argument in entrypoint fns
* Move MobileNetV3 into own file since it has a different head
* Remove ChamNet, MobileNet V2/V1 since they will likely never be used here
## Models
Implemented models include:
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* MixNet (https://arxiv.org/abs/1907.09595)
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
* FBNet-C (https://arxiv.org/abs/1812.03443)
* Single-Path NAS (https://arxiv.org/abs/1904.02877)
I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.
## Pretrained
I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models
|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop |
|---|---|---|---|---|---|---|---|
| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 |
| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 |
| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 |
| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 |
| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 |
| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 |
| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 |
| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 |
| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 |
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 |
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 |
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 |
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 |
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 |
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 |
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 |
| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 |
More pretrained models to come...
## Ported Weights
The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
**IMPORTANT:**
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
To run validation for tf_efficientnet_b5:
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic`
To run validation w/ TF preprocessing for tf_efficientnet_b5:
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing`
To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp:
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5`
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop |
|---|---|---|---|---|---|---|
| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A |
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 |
| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 |
| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A |
| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A |
| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A |
| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A |
| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A |
| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A |
| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A |
| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A |
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A |
| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 |
| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 |
| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A |
| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 |
| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A |
| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 |
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A |
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 |
| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 |
| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A |
| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A |
| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 |
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A |
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 |
| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A |
| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 |
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A |
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 |
| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A |
| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 |
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 |
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A |
| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A |
| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 |
| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 |
| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A |
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A |
| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A |
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A |
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 |
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 |
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 |
| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 |
| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 |
| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 |
| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A |
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |
| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 |
| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A |
| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 |
| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A |
| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 |
| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A |
| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 |
*tfp models validated with `tf-preprocessing` pipeline
Google tf and tflite weights ported from official Tensorflow repositories
* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
## Usage
### Environment
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.
Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.
PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
```
conda create -n torch-env
conda activate torch-env
conda install -c pytorch pytorch torchvision cudatoolkit=10.2
```
### PyTorch Hub
Models can be accessed via the PyTorch Hub API
```
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
['efficientnet_b0', ...]
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
>>> model.eval()
>>> output = model(torch.randn(1,3,224,224))
```
### Pip
This package can be installed via pip.
Install (after conda env/install):
```
pip install geffnet
```
Eval use:
```
>>> import geffnet
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```
Train use:
```
>>> import geffnet
>>> # models can also be created by using the entrypoint directly
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
>>> m.train()
```
Create in a nn.Sequential container, for fast.ai, etc:
```
>>> import geffnet
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
```
### Exporting
Scripts are included to
* export models to ONNX (`onnx_export.py`)
* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
* validate with ONNX runtime (`onnx_validate.py`)
* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
* validate in Caffe2 (`caffe2_validate.py`)
* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)
As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
```
python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
```
These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible
export now requires additional args mentioned in the export script (not needed in earlier versions).
#### Export Notes
1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization.
3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.

View File

@@ -0,0 +1,65 @@
""" Caffe2 validation script
This script runs Caffe2 benchmark on exported ONNX model.
It is a useful tool for reporting model FLOPS.
Copyright 2020 Ross Wightman
"""
import argparse
from caffe2.python import core, workspace, model_helper
from caffe2.proto import caffe2_pb2
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
help='caffe2 model pb name prefix')
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
help='caffe2 model init .pb')
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
help='caffe2 model predict .pb')
parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
def main():
args = parser.parse_args()
args.gpu_id = 0
if args.c2_prefix:
args.c2_init = args.c2_prefix + '.init.pb'
args.c2_predict = args.c2_prefix + '.predict.pb'
model = model_helper.ModelHelper(name="le_net", init_params=False)
# Bring in the init net from init_net.pb
init_net_proto = caffe2_pb2.NetDef()
with open(args.c2_init, "rb") as f:
init_net_proto.ParseFromString(f.read())
model.param_init_net = core.Net(init_net_proto)
# bring in the predict net from predict_net.pb
predict_net_proto = caffe2_pb2.NetDef()
with open(args.c2_predict, "rb") as f:
predict_net_proto.ParseFromString(f.read())
model.net = core.Net(predict_net_proto)
# CUDA performance not impressive
#device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
#model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
#model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
input_blob = model.net.external_inputs[0]
model.param_init_net.GaussianFill(
[],
input_blob.GetUnscopedName(),
shape=(args.batch_size, 3, args.img_size, args.img_size),
mean=0.0,
std=1.0)
workspace.RunNetOnce(model.param_init_net)
workspace.CreateNet(model.net, overwrite=True)
workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,138 @@
""" Caffe2 validation script
This script is created to verify exported ONNX models running in Caffe2
It utilizes the same PyTorch dataloader/processing pipeline for a
fair comparison against the originals.
Copyright 2020 Ross Wightman
"""
import argparse
import numpy as np
from caffe2.python import core, workspace, model_helper
from caffe2.proto import caffe2_pb2
from data import create_loader, resolve_data_config, Dataset
from utils import AverageMeter
import time
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
help='caffe2 model pb name prefix')
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
help='caffe2 model init .pb')
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
help='caffe2 model predict .pb')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
help='use tensorflow mnasnet preporcessing')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
def main():
args = parser.parse_args()
args.gpu_id = 0
if args.c2_prefix:
args.c2_init = args.c2_prefix + '.init.pb'
args.c2_predict = args.c2_prefix + '.predict.pb'
model = model_helper.ModelHelper(name="validation_net", init_params=False)
# Bring in the init net from init_net.pb
init_net_proto = caffe2_pb2.NetDef()
with open(args.c2_init, "rb") as f:
init_net_proto.ParseFromString(f.read())
model.param_init_net = core.Net(init_net_proto)
# bring in the predict net from predict_net.pb
predict_net_proto = caffe2_pb2.NetDef()
with open(args.c2_predict, "rb") as f:
predict_net_proto.ParseFromString(f.read())
model.net = core.Net(predict_net_proto)
data_config = resolve_data_config(None, args)
loader = create_loader(
Dataset(args.data, load_bytes=args.tf_preprocessing),
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=data_config['crop_pct'],
tensorflow_preprocessing=args.tf_preprocessing)
# this is so obvious, wonderful interface </sarcasm>
input_blob = model.net.external_inputs[0]
output_blob = model.net.external_outputs[0]
if True:
device_opts = None
else:
# CUDA is crashing, no idea why, awesome error message, give it a try for kicks
device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
model.param_init_net.GaussianFill(
[], input_blob.GetUnscopedName(),
shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
workspace.RunNetOnce(model.param_init_net)
workspace.CreateNet(model.net, overwrite=True)
batch_time = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()
for i, (input, target) in enumerate(loader):
# run the net and return prediction
caffe2_in = input.data.numpy()
workspace.FeedBlob(input_blob, caffe2_in, device_opts)
workspace.RunNet(model.net, num_iter=1)
output = workspace.FetchBlob(output_blob)
# measure accuracy and record loss
prec1, prec5 = accuracy_np(output.data, target.numpy())
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
def accuracy_np(output, target):
max_indices = np.argsort(output, axis=1)[:, ::-1]
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
return top1, top5
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,5 @@
from .gen_efficientnet import *
from .mobilenetv3 import *
from .model_factory import create_model
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
from .activations import *

View File

@@ -0,0 +1,137 @@
from geffnet import config
from geffnet.activations.activations_me import *
from geffnet.activations.activations_jit import *
from geffnet.activations.activations import *
import torch
_has_silu = 'silu' in dir(torch.nn.functional)
_ACT_FN_DEFAULT = dict(
silu=F.silu if _has_silu else swish,
swish=F.silu if _has_silu else swish,
mish=mish,
relu=F.relu,
relu6=F.relu6,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
hard_swish=hard_swish,
)
_ACT_FN_JIT = dict(
silu=F.silu if _has_silu else swish_jit,
swish=F.silu if _has_silu else swish_jit,
mish=mish_jit,
)
_ACT_FN_ME = dict(
silu=F.silu if _has_silu else swish_me,
swish=F.silu if _has_silu else swish_me,
mish=mish_me,
hard_swish=hard_swish_me,
hard_sigmoid_jit=hard_sigmoid_me,
)
_ACT_LAYER_DEFAULT = dict(
silu=nn.SiLU if _has_silu else Swish,
swish=nn.SiLU if _has_silu else Swish,
mish=Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
)
_ACT_LAYER_JIT = dict(
silu=nn.SiLU if _has_silu else SwishJit,
swish=nn.SiLU if _has_silu else SwishJit,
mish=MishJit,
)
_ACT_LAYER_ME = dict(
silu=nn.SiLU if _has_silu else SwishMe,
swish=nn.SiLU if _has_silu else SwishMe,
mish=MishMe,
hard_swish=HardSwishMe,
hard_sigmoid=HardSigmoidMe
)
_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()
def add_override_act_fn(name, fn):
global _OVERRIDE_FN
_OVERRIDE_FN[name] = fn
def update_override_act_fn(overrides):
assert isinstance(overrides, dict)
global _OVERRIDE_FN
_OVERRIDE_FN.update(overrides)
def clear_override_act_fn():
global _OVERRIDE_FN
_OVERRIDE_FN = dict()
def add_override_act_layer(name, fn):
_OVERRIDE_LAYER[name] = fn
def update_override_act_layer(overrides):
assert isinstance(overrides, dict)
global _OVERRIDE_LAYER
_OVERRIDE_LAYER.update(overrides)
def clear_override_act_layer():
global _OVERRIDE_LAYER
_OVERRIDE_LAYER = dict()
def get_act_fn(name='relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if name in _OVERRIDE_FN:
return _OVERRIDE_FN[name]
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
if use_me and name in _ACT_FN_ME:
# If not exporting or scripting the model, first look for a memory optimized version
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
return _ACT_FN_ME[name]
if config.is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return swish
use_jit = not (config.is_exportable() or config.is_no_jit())
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
return _ACT_FN_JIT[name]
return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if name in _OVERRIDE_LAYER:
return _OVERRIDE_LAYER[name]
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
if use_me and name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
if config.is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return Swish
use_jit = not (config.is_exportable() or config.is_no_jit())
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
return _ACT_LAYER_JIT[name]
return _ACT_LAYER_DEFAULT[name]

View File

@@ -0,0 +1,102 @@
""" Activations
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from torch.nn import functional as F
def swish(x, inplace: bool = False):
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
and also as Swish (https://arxiv.org/abs/1710.05941).
TODO Rename to SiLU with addition to PyTorch
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
class Swish(nn.Module):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
return swish(x, self.inplace)
def mish(x, inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class Mish(nn.Module):
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
self.inplace = inplace
def forward(self, x):
return mish(x, self.inplace)
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argmument interface
class Sigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.sigmoid_() if self.inplace else x.sigmoid()
def tanh(x, inplace: bool = False):
return x.tanh_() if inplace else x.tanh()
# PyTorch has this, but not with a consistent inplace argmument interface
class Tanh(nn.Module):
def __init__(self, inplace: bool = False):
super(Tanh, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.tanh_() if self.inplace else x.tanh()
def hard_swish(x, inplace: bool = False):
inner = F.relu6(x + 3.).div_(6.)
return x.mul_(inner) if inplace else x.mul(inner)
class HardSwish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
return F.relu6(x + 3.) / 6.
class HardSigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_sigmoid(x, self.inplace)

View File

@@ -0,0 +1,79 @@
""" Activations (jit)
A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.
Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
@torch.jit.script
def swish_jit(x, inplace: bool = False):
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
and also as Swish (https://arxiv.org/abs/1710.05941).
TODO Rename to SiLU with addition to PyTorch
"""
return x.mul(x.sigmoid())
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class SwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishJit, self).__init__()
def forward(self, x):
return swish_jit(x)
class MishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(MishJit, self).__init__()
def forward(self, x):
return mish_jit(x)
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
# return F.relu6(x + 3.) / 6.
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSigmoidJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidJit, self).__init__()
def forward(self, x):
return hard_sigmoid_jit(x)
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
# return x * (F.relu6(x + 3.) / 6)
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishJit, self).__init__()
def forward(self, x):
return hard_swish_jit(x)

View File

@@ -0,0 +1,174 @@
""" Activations (memory-efficient w/ custom autograd)
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.
Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
and also as Swish (https://arxiv.org/abs/1710.05941).
TODO Rename to SiLU with addition to PyTorch
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish_me(x, inplace=False):
return SwishJitAutoFn.apply(x)
class SwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishMe, self).__init__()
def forward(self, x):
return SwishJitAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
A memory efficient, jit scripted variant of Mish
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def mish_me(x, inplace=False):
return MishJitAutoFn.apply(x)
class MishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(MishMe, self).__init__()
def forward(self, x):
return MishJitAutoFn.apply(x)
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
return (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
return grad_output * m
class HardSigmoidJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_sigmoid_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_sigmoid_jit_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
return HardSigmoidJitAutoFn.apply(x)
class HardSigmoidMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidMe, self).__init__()
def forward(self, x):
return HardSigmoidJitAutoFn.apply(x)
@torch.jit.script
def hard_swish_jit_fwd(x):
return x * (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= 3.)
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
return grad_output * m
class HardSwishJitAutoFn(torch.autograd.Function):
"""A memory efficient, jit-scripted HardSwish activation"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_swish_jit_bwd(x, grad_output)
def hard_swish_me(x, inplace=False):
return HardSwishJitAutoFn.apply(x)
class HardSwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishMe, self).__init__()
def forward(self, x):
return HardSwishJitAutoFn.apply(x)

View File

@@ -0,0 +1,123 @@
""" Global layer config state
"""
from typing import Any, Optional
__all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]
# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False
# Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False
def is_no_jit():
return _NO_JIT
class set_no_jit:
def __init__(self, mode: bool) -> None:
global _NO_JIT
self.prev = _NO_JIT
_NO_JIT = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _NO_JIT
_NO_JIT = self.prev
return False
def is_exportable():
return _EXPORTABLE
class set_exportable:
def __init__(self, mode: bool) -> None:
global _EXPORTABLE
self.prev = _EXPORTABLE
_EXPORTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _EXPORTABLE
_EXPORTABLE = self.prev
return False
def is_scriptable():
return _SCRIPTABLE
class set_scriptable:
def __init__(self, mode: bool) -> None:
global _SCRIPTABLE
self.prev = _SCRIPTABLE
_SCRIPTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
_SCRIPTABLE = self.prev
return False
class set_layer_config:
""" Layer config context manager that allows setting all layer config flags at once.
If a flag arg is None, it will not change the current value.
"""
def __init__(
self,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
no_activation_jit: Optional[bool] = None):
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
if scriptable is not None:
_SCRIPTABLE = scriptable
if exportable is not None:
_EXPORTABLE = exportable
if no_jit is not None:
_NO_JIT = no_jit
if no_activation_jit is not None:
_NO_ACTIVATION_JIT = no_activation_jit
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
return False
def layer_config_kwargs(kwargs):
""" Consume config kwargs and return contextmgr obj """
return set_layer_config(
scriptable=kwargs.pop('scriptable', None),
exportable=kwargs.pop('exportable', None),
no_jit=kwargs.pop('no_jit', None))

View File

@@ -0,0 +1,304 @@
""" Conv2D w/ SAME padding, CondConv, MixedConv
A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
Copyright 2020 Ross Wightman
"""
import collections.abc
import math
from functools import partial
from itertools import repeat
from typing import Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import *
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _get_padding(kernel_size, stride=1, dilation=1, **_):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _calc_same_pad(i: int, k: int, s: int, d: int):
return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
def _same_pad_arg(input_size, kernel_size, stride, dilation):
ih, iw = input_size
kh, kw = kernel_size
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2dSameExport(nn.Conv2d):
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
NOTE: This does not currently work with torch.jit.script
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSameExport, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.pad = None
self.pad_input_size = (0, 0)
def forward(self, x):
input_size = x.size()[-2:]
if self.pad is None:
pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
self.pad = nn.ZeroPad2d(pad_arg)
self.pad_input_size = input_size
if self.pad is not None:
x = self.pad(x)
return F.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs):
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
else:
# dynamic padding
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = _get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
if is_exportable():
assert not is_scriptable()
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
else:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = _pair(padding_val)
self.dilation = _pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m

View File

@@ -0,0 +1,683 @@
""" EfficientNet / MobileNetV3 Blocks and Builder
Copyright 2020 Ross Wightman
"""
import re
from copy import deepcopy
from .conv2d_layers import *
from geffnet.activations import *
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
]
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
#
# PyTorch defaults are momentum = .1, eps = 1e-5
#
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
def get_bn_args_tf():
return _BN_ARGS_TF.copy()
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
_SE_ARGS_DEFAULT = dict(
gate_fn=sigmoid,
act_layer=None, # None == use containing block's activation layer
reduce_mid=False,
divisor=1)
def resolve_se_args(kwargs, in_chs, act_layer=None):
se_kwargs = kwargs.copy() if kwargs is not None else {}
# fill in args that aren't specified with the defaults
for k, v in _SE_ARGS_DEFAULT.items():
se_kwargs.setdefault(k, v)
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
if not se_kwargs.pop('reduce_mid'):
se_kwargs['reduced_base_chs'] = in_chs
# act_layer override, if it remains None, the containing block's act_layer will be used
if se_kwargs['act_layer'] is None:
assert act_layer is not None
se_kwargs['act_layer'] = act_layer
return se_kwargs
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
new_v += divisor
return new_v
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
return make_divisible(channels, divisor, channel_min)
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
"""Apply drop connect."""
if not training:
return inputs
keep_prob = 1 - drop_connect_rate
random_tensor = keep_prob + torch.rand(
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
random_tensor.floor_() # binarize
output = inputs.div(keep_prob) * random_tensor
return output
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
super(SqueezeExcite, self).__init__()
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.gate_fn = gate_fn
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
assert stride in [1, 2]
norm_kwargs = norm_kwargs or {}
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
return x
class DepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
norm_kwargs = norm_kwargs or {}
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.drop_connect_rate = drop_connect_rate
self.conv_dw = select_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity()
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
def forward(self, x):
residual = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act1(x)
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class InvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs: int = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Point-wise expansion
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = select_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity() # for jit.script compat
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_connect_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_connect_rate=drop_connect_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
def forward(self, x):
residual = x
# CondConv routing
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
# Point-wise expansion
x = self.conv_pw(x, routing_weights)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x, routing_weights)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x, routing_weights)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class EdgeResidual(nn.Module):
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Expansion convolution
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity()
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
def forward(self, x):
residual = x
# Expansion convolution
x = self.conv_exp(x)
x = self.bn1(x)
x = self.act1(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class EfficientNetBuilder:
""" Build Trunk Blocks for Efficient/Mobile Networks
This ended up being somewhat of a cross between
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
and
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.pad_type = pad_type
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_connect_rate = drop_connect_rate
# updated during build
self.in_chs = None
self.block_idx = 0
self.block_count = 0
def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba):
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
block = EdgeResidual(**ba)
elif bt == 'cn':
block = ConvBnAct(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def _make_stack(self, stack_args):
blocks = []
# each stack (stage) contains a list of block arguments
for i, ba in enumerate(stack_args):
if i >= 1:
# only the first block in any stack can have a stride > 1
ba['stride'] = 1
block = self._make_block(ba)
blocks.append(block)
self.block_idx += 1 # incr global idx (across all stacks)
return nn.Sequential(*blocks)
def __call__(self, in_chs, block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(block_args):
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
return blocks
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
value = get_act_layer('relu')
elif v == 'r6':
value = get_act_layer('relu6')
elif v == 'hs':
value = get_act_layer('hard_swish')
elif v == 'sw':
value = get_act_layer('swish')
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# if act_layer is None, the model default (passed to model init) will be used
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
ba['num_experts'] *= experts_multiplier
stack_args.append(ba)
repeats.append(rep)
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
def initialize_weight_goog(m, n='', fix_group_fanout=True):
# weight init as per Tensorflow Official impl
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
def initialize_weight_default(m, n=''):
if isinstance(m, CondConv2d):
init_fn = get_condconv_initializer(partial(
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
init_fn(m.weight)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')

View File

@@ -0,0 +1,71 @@
""" Checkpoint loading / state_dict helpers
Copyright 2020 Ross Wightman
"""
import torch
import os
from collections import OrderedDict
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def load_checkpoint(model, checkpoint_path):
if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_pretrained(model, url, filter_fn=None, strict=True):
if not url:
print("=> Warning: Pretrained model URL is empty, using random initialization.")
return
state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
input_conv = 'conv_stem'
classifier = 'classifier'
in_chans = getattr(model, input_conv).weight.shape[1]
num_classes = getattr(model, classifier).weight.shape[0]
input_conv_weight = input_conv + '.weight'
pretrained_in_chans = state_dict[input_conv_weight].shape[1]
if in_chans != pretrained_in_chans:
if in_chans == 1:
print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
input_conv_weight, pretrained_in_chans))
conv1_weight = state_dict[input_conv_weight]
state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
else:
print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
input_conv_weight, pretrained_in_chans))
del state_dict[input_conv_weight]
strict = False
classifier_weight = classifier + '.weight'
pretrained_num_classes = state_dict[classifier_weight].shape[0]
if num_classes != pretrained_num_classes:
print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
del state_dict[classifier_weight]
del state_dict[classifier + '.bias']
strict = False
if filter_fn is not None:
state_dict = filter_fn(state_dict)
model.load_state_dict(state_dict, strict=strict)

View File

@@ -0,0 +1,364 @@
""" MobileNet-V3
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .activations import get_act_fn, get_act_layer, HardSwish
from .config import layer_config_kwargs
from .conv2d_layers import select_conv2d
from .helpers import load_pretrained
from .efficientnet_builder import *
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
model_urls = {
'mobilenetv3_rw':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
'mobilenetv3_large_075': None,
'mobilenetv3_large_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
'mobilenetv3_large_minimal_100': None,
'mobilenetv3_small_075': None,
'mobilenetv3_small_100': None,
'mobilenetv3_small_minimal_100': None,
'tf_mobilenetv3_large_075':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
'tf_mobilenetv3_large_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
'tf_mobilenetv3_large_minimal_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
'tf_mobilenetv3_small_075':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
'tf_mobilenetv3_small_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
'tf_mobilenetv3_small_minimal_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
}
class MobileNetV3(nn.Module):
""" MobileNet-V3
A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
head convolution without a final batch-norm layer before the classifier.
Paper: https://arxiv.org/abs/1905.02244
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
super(MobileNetV3, self).__init__()
self.drop_rate = drop_rate
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
in_chs = stem_size
builder = EfficientNetBuilder(
channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.classifier = nn.Linear(num_features, num_classes)
for m in self.modules():
if weight_init == 'goog':
initialize_weight_goog(m)
else:
initialize_weight_default(m)
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
layers.extend([
self.global_pool, self.conv_head, self.act2,
nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
return x
def forward(self, x):
x = self.features(x)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def _create_model(model_kwargs, variant, pretrained=False):
as_sequential = model_kwargs.pop('as_sequential', False)
model = MobileNetV3(**model_kwargs)
if pretrained and model_urls[variant]:
load_pretrained(model, model_urls[variant])
if as_sequential:
model = model.as_sequential()
return model
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model (RW variant).
Paper: https://arxiv.org/abs/1905.02244
This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
eventual Tensorflow reference impl but has a few differences:
1. This model has no bias on the head convolution
2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
from their parent block
4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
Overall the changes are fairly minor and result in a very small parameter count difference and no
top-1/5
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
with layer_config_kwargs(kwargs):
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False, # one of my mistakes
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 large/small/minimal models.
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = 'relu'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
# stage 2, 28x28 in
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
# stage 3, 14x14 in
['ir_r2_k3_s1_e3_c48'],
# stage 4, 14x14in
['ir_r3_k3_s2_e6_c96'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'],
]
else:
act_layer = 'hard_swish'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
if 'minimal' in variant:
act_layer = 'relu'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k3_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112'],
# stage 5, 14x14in
['ir_r3_k3_s2_e6_c160'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'],
]
else:
act_layer = 'hard_swish'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
with layer_config_kwargs(kwargs):
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, act_layer),
se_kwargs=dict(
act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet-V3 RW
Attn: See note in gen function for this variant.
"""
# NOTE for train set drop_rate=0.2
if pretrained:
# pretrained model trained with non-default BN epsilon
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 Large 0.75"""
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 Large 1.0 """
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Large (Minimalistic) 1.0 """
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 Small 0.75 """
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 Small 1.0 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Small (Minimalistic) 1.0 """
model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 Large 0.75. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 Large 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 Small 0.75. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 Small 1.0. Tensorflow compat variant."""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model

View File

@@ -0,0 +1,27 @@
from .config import set_layer_config
from .helpers import load_checkpoint
from .gen_efficientnet import *
from .mobilenetv3 import *
def create_model(
model_name='mnasnet_100',
pretrained=None,
num_classes=1000,
in_chans=3,
checkpoint_path='',
**kwargs):
model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)
if model_name in globals():
create_fn = globals()[model_name]
model = create_fn(**model_kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
if checkpoint_path and not pretrained:
load_checkpoint(model, checkpoint_path)
return model

View File

@@ -0,0 +1 @@
__version__ = '1.0.2'

View File

@@ -0,0 +1,84 @@
dependencies = ['torch', 'math']
from geffnet import efficientnet_b0
from geffnet import efficientnet_b1
from geffnet import efficientnet_b2
from geffnet import efficientnet_b3
from geffnet import efficientnet_es
from geffnet import efficientnet_lite0
from geffnet import mixnet_s
from geffnet import mixnet_m
from geffnet import mixnet_l
from geffnet import mixnet_xl
from geffnet import mobilenetv2_100
from geffnet import mobilenetv2_110d
from geffnet import mobilenetv2_120d
from geffnet import mobilenetv2_140
from geffnet import mobilenetv3_large_100
from geffnet import mobilenetv3_rw
from geffnet import mnasnet_a1
from geffnet import mnasnet_b1
from geffnet import fbnetc_100
from geffnet import spnasnet_100
from geffnet import tf_efficientnet_b0
from geffnet import tf_efficientnet_b1
from geffnet import tf_efficientnet_b2
from geffnet import tf_efficientnet_b3
from geffnet import tf_efficientnet_b4
from geffnet import tf_efficientnet_b5
from geffnet import tf_efficientnet_b6
from geffnet import tf_efficientnet_b7
from geffnet import tf_efficientnet_b8
from geffnet import tf_efficientnet_b0_ap
from geffnet import tf_efficientnet_b1_ap
from geffnet import tf_efficientnet_b2_ap
from geffnet import tf_efficientnet_b3_ap
from geffnet import tf_efficientnet_b4_ap
from geffnet import tf_efficientnet_b5_ap
from geffnet import tf_efficientnet_b6_ap
from geffnet import tf_efficientnet_b7_ap
from geffnet import tf_efficientnet_b8_ap
from geffnet import tf_efficientnet_b0_ns
from geffnet import tf_efficientnet_b1_ns
from geffnet import tf_efficientnet_b2_ns
from geffnet import tf_efficientnet_b3_ns
from geffnet import tf_efficientnet_b4_ns
from geffnet import tf_efficientnet_b5_ns
from geffnet import tf_efficientnet_b6_ns
from geffnet import tf_efficientnet_b7_ns
from geffnet import tf_efficientnet_l2_ns_475
from geffnet import tf_efficientnet_l2_ns
from geffnet import tf_efficientnet_es
from geffnet import tf_efficientnet_em
from geffnet import tf_efficientnet_el
from geffnet import tf_efficientnet_cc_b0_4e
from geffnet import tf_efficientnet_cc_b0_8e
from geffnet import tf_efficientnet_cc_b1_8e
from geffnet import tf_efficientnet_lite0
from geffnet import tf_efficientnet_lite1
from geffnet import tf_efficientnet_lite2
from geffnet import tf_efficientnet_lite3
from geffnet import tf_efficientnet_lite4
from geffnet import tf_mixnet_s
from geffnet import tf_mixnet_m
from geffnet import tf_mixnet_l
from geffnet import tf_mobilenetv3_large_075
from geffnet import tf_mobilenetv3_large_100
from geffnet import tf_mobilenetv3_large_minimal_100
from geffnet import tf_mobilenetv3_small_075
from geffnet import tf_mobilenetv3_small_100
from geffnet import tf_mobilenetv3_small_minimal_100

View File

@@ -0,0 +1,120 @@
""" ONNX export script
Export PyTorch models as ONNX graphs.
This export script originally started as an adaptation of code snippets found at
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
flags are currently required.
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks.
Copyright 2020 Ross Wightman
"""
import argparse
import torch
import numpy as np
import onnx
import geffnet
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('output', metavar='ONNX_FILE',
help='output model filename')
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
help='model architecture (default: mobilenetv3_large_100)')
parser.add_argument('--opset', type=int, default=10,
help='ONNX opset to use (default: 10)')
parser.add_argument('--keep-init', action='store_true', default=False,
help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
parser.add_argument('--aten-fallback', action='store_true', default=False,
help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
parser.add_argument('--dynamic-size', action='store_true', default=False,
help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.')
parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to checkpoint (default: none)')
def main():
args = parser.parse_args()
args.pretrained = True
if args.checkpoint:
args.pretrained = False
print("==> Creating PyTorch {} model".format(args.model))
# NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
# for models using SAME padding
model = geffnet.create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
exportable=True)
model.eval()
example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)
# Run model once before export trace, sets padding for models with Conv2dSameExport. This means
# that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
# the input img_size specified in this script.
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
model(example_input)
print("==> Exporting model to ONNX format at '{}'".format(args.output))
input_names = ["input0"]
output_names = ["output0"]
dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
if args.dynamic_size:
dynamic_axes['input0'][2] = 'height'
dynamic_axes['input0'][3] = 'width'
if args.aten_fallback:
export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
export_type = torch.onnx.OperatorExportTypes.ONNX
torch_out = torch.onnx._export(
model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
opset_version=args.opset, operator_export_type=export_type)
print("==> Loading and checking exported model from '{}'".format(args.output))
onnx_model = onnx.load(args.output)
onnx.checker.check_model(onnx_model) # assuming throw on error
print("==> Passed")
if args.keep_init and args.aten_fallback:
import caffe2.python.onnx.backend as onnx_caffe2
# Caffe2 loading only works properly in newer PyTorch/ONNX combos when
# keep_initializers_as_inputs and aten_fallback are set to True.
print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output))
caffe2_backend = onnx_caffe2.prepare(onnx_model)
B = {onnx_model.graph.input[0].name: x.data.numpy()}
c2_out = caffe2_backend.run(B)[0]
np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
print("==> Passed")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,84 @@
""" ONNX optimization script
Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7),
it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline).
Copyright 2020 Ross Wightman
"""
import argparse
import warnings
import onnx
from onnx import optimizer
parser = argparse.ArgumentParser(description="Optimize ONNX model")
parser.add_argument("model", help="The ONNX model")
parser.add_argument("--output", required=True, help="The optimized model output filename")
def traverse_graph(graph, prefix=''):
content = []
indent = prefix + ' '
graphs = []
num_nodes = 0
for node in graph.node:
pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
assert isinstance(gs, list)
content.append(pn)
graphs.extend(gs)
num_nodes += 1
for g in graphs:
g_count, g_str = traverse_graph(g)
content.append('\n' + g_str)
num_nodes += g_count
return num_nodes, '\n'.join(content)
def main():
args = parser.parse_args()
onnx_model = onnx.load(args.model)
num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
# Optimizer passes to perform
passes = [
#'eliminate_deadend',
'eliminate_identity',
'eliminate_nop_dropout',
'eliminate_nop_pad',
'eliminate_nop_transpose',
'eliminate_unused_initializer',
'extract_constant_to_initializer',
'fuse_add_bias_into_conv',
'fuse_bn_into_conv',
'fuse_consecutive_concats',
'fuse_consecutive_reduce_unsqueeze',
'fuse_consecutive_squeezes',
'fuse_consecutive_transposes',
#'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv',
#'fuse_transpose_into_gemm',
#'lift_lexical_references',
]
# Apply the optimization on the original serialized model
# WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
# 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
# It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX."
"Try onnxruntime optimization if this doesn't work.")
optimized_model = optimizer.optimize(onnx_model, passes)
num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph)
print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str))
print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
# Save the ONNX model
onnx.save(optimized_model, args.output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,27 @@
import argparse
import onnx
from caffe2.python.onnx.backend import Caffe2Backend
parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2")
parser.add_argument("model", help="The ONNX model")
parser.add_argument("--c2-prefix", required=True,
help="The output file prefix for the caffe2 model init and predict file. ")
def main():
args = parser.parse_args()
onnx_model = onnx.load(args.model)
caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
caffe2_init_str = caffe2_init.SerializeToString()
with open(args.c2_prefix + '.init.pb', "wb") as f:
f.write(caffe2_init_str)
caffe2_predict_str = caffe2_predict.SerializeToString()
with open(args.c2_prefix + '.predict.pb', "wb") as f:
f.write(caffe2_predict_str)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,112 @@
""" ONNX-runtime validation script
This script was created to verify accuracy and performance of exported ONNX
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
pipeline for a fair comparison against the originals.
Copyright 2020 Ross Wightman
"""
import argparse
import numpy as np
import onnxruntime
from data import create_loader, resolve_data_config, Dataset
from utils import AverageMeter
import time
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
help='path to onnx model/weights file')
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
help='path to output optimized onnx graph')
parser.add_argument('--profile', action='store_true', default=False,
help='Enable profiler output.')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
help='use tensorflow mnasnet preporcessing')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
def main():
args = parser.parse_args()
args.gpu_id = 0
# Set graph optimization level
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
if args.profile:
sess_options.enable_profiling = True
if args.onnx_output_opt:
sess_options.optimized_model_filepath = args.onnx_output_opt
session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
data_config = resolve_data_config(None, args)
loader = create_loader(
Dataset(args.data, load_bytes=args.tf_preprocessing),
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=data_config['crop_pct'],
tensorflow_preprocessing=args.tf_preprocessing)
input_name = session.get_inputs()[0].name
batch_time = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()
for i, (input, target) in enumerate(loader):
# run the net and return prediction
output = session.run([], {input_name: input.data.numpy()})
output = output[0]
# measure accuracy and record loss
prec1, prec5 = accuracy_np(output, target.numpy())
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
def accuracy_np(output, target):
max_indices = np.argsort(output, axis=1)[:, ::-1]
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
return top1, top5
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,2 @@
torch>=1.2.0
torchvision>=0.4.0

View File

@@ -0,0 +1,47 @@
""" Setup
"""
from setuptools import setup, find_packages
from codecs import open
from os import path
here = path.abspath(path.dirname(__file__))
# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
long_description = f.read()
exec(open('geffnet/version.py').read())
setup(
name='geffnet',
version=__version__,
description='(Generic) EfficientNets for PyTorch',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/rwightman/gen-efficientnet-pytorch',
author='Ross Wightman',
author_email='hello@rwightman.com',
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
# Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet',
packages=find_packages(exclude=['data']),
install_requires=['torch >= 1.4', 'torchvision'],
python_requires='>=3.6',
)

View File

@@ -0,0 +1,52 @@
import os
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
os.makedirs(outdir)
elif inc:
count = 1
outdir_inc = outdir + '-' + str(count)
while os.path.exists(outdir_inc):
count = count + 1
outdir_inc = outdir + '-' + str(count)
assert count < 100
outdir = outdir_inc
os.makedirs(outdir)
return outdir

View File

@@ -0,0 +1,166 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.parallel
from contextlib import suppress
import geffnet
from data import Dataset, create_loader, resolve_data_config
from utils import accuracy, AverageMeter
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
help='use tensorflow mnasnet preporcessing')
parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
help='')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use native Torch AMP mixed precision.')
def main():
args = parser.parse_args()
if not args.checkpoint and not args.pretrained:
args.pretrained = True
amp_autocast = suppress # do nothing
if args.amp:
if not has_native_amp:
print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.")
else:
amp_autocast = torch.cuda.amp.autocast
# create model
model = geffnet.create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
scriptable=args.torchscript)
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args)
criterion = nn.CrossEntropyLoss()
if not args.no_cuda:
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model = model.cuda()
criterion = criterion.cuda()
loader = create_loader(
Dataset(args.data, load_bytes=args.tf_preprocessing),
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=not args.no_cuda,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=data_config['crop_pct'],
tensorflow_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.eval()
end = time.time()
with torch.no_grad():
for i, (input, target) in enumerate(loader):
if not args.no_cuda:
target = target.cuda()
input = input.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
with amp_autocast():
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,34 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
basemodel_name = 'tf_efficientnet_b5_ap'
print('Loading base model ()...'.format(basemodel_name), end='')
repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
print('Done.')
# Remove last layer
print('Removing last two layers (global_pool & classifier).')
basemodel.global_pool = nn.Identity()
basemodel.classifier = nn.Identity()
self.original_model = basemodel
def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items():
if (k == 'blocks'):
for ki, vi in v._modules.items():
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
return features

View File

@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
########################################################################################################################
# Upsample + BatchNorm
class UpSampleBN(nn.Module):
def __init__(self, skip_input, output_features):
super(UpSampleBN, self).__init__()
self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_features),
nn.LeakyReLU(),
nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_features),
nn.LeakyReLU())
def forward(self, x, concat_with):
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
# Upsample + GroupNorm + Weight Standardization
class UpSampleGN(nn.Module):
def __init__(self, skip_input, output_features):
super(UpSampleGN, self).__init__()
self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
nn.GroupNorm(8, output_features),
nn.LeakyReLU(),
Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
nn.GroupNorm(8, output_features),
nn.LeakyReLU())
def forward(self, x, concat_with):
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
# Conv2d with weight standardization
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
keepdim=True).mean(dim=3, keepdim=True)
weight = weight - weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
weight = weight / std.expand_as(weight)
return F.conv2d(x, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
# normalize
def norm_normalize(norm_out):
min_kappa = 0.01
norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
kappa = F.elu(kappa) + 1.0 + min_kappa
final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
return final_out
# uncertainty-guided sampling (only used during training)
@torch.no_grad()
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
device = init_normal.device
B, _, H, W = init_normal.shape
N = int(sampling_ratio * H * W)
beta = beta
# uncertainty map
uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
# gt_invalid_mask (B, H, W)
if gt_norm_mask is not None:
gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
uncertainty_map[gt_invalid_mask] = -1e4
# (B, H*W)
_, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
# importance sampling
if int(beta * N) > 0:
importance = idx[:, :int(beta * N)] # B, beta*N
# remaining
remaining = idx[:, int(beta * N):] # B, H*W - beta*N
# coverage
num_coverage = N - int(beta * N)
if num_coverage <= 0:
samples = importance
else:
coverage_list = []
for i in range(B):
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
samples = torch.cat((importance, coverage), dim=1) # B, N
else:
# remaining
remaining = idx[:, :] # B, H*W
# coverage
num_coverage = N
coverage_list = []
for i in range(B):
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
samples = coverage
# point coordinates
rows_int = samples // W # 0 for first row, H-1 for last row
rows_float = rows_int / float(H-1) # 0 to 1.0
rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
cols_int = samples % W # 0 for first column, W-1 for last column
cols_float = cols_int / float(W-1) # 0 to 1.0
cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
point_coords = torch.zeros(B, 1, N, 2)
point_coords[:, 0, :, 0] = cols_float # x coord
point_coords[:, 0, :, 1] = rows_float # y coord
point_coords = point_coords.to(device)
return point_coords, rows_int, cols_int

View File

@@ -0,0 +1,79 @@
# Adapted from https://github.com/huggingface/controlnet_aux
import pathlib
import cv2
import huggingface_hub
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
class PIDINetDetector:
"""Simple wrapper around a PiDiNet model for edge detection."""
hf_repo_id = "lllyasviel/Annotators"
hf_filename = "table5_pidinet.pth"
@classmethod
def get_model_url(cls) -> str:
"""Get the URL to download the model from the Hugging Face Hub."""
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
@classmethod
def load_model(cls, model_path: pathlib.Path) -> PiDiNet:
"""Load the model from a file."""
model = pidinet()
model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(model_path)["state_dict"].items()})
model.eval()
return model
def __init__(self, model: PiDiNet) -> None:
self.model = model
def to(self, device: torch.device):
self.model.to(device)
return self
def run(
self, image: Image.Image, quantize_edges: bool = False, scribble: bool = False, apply_filter: bool = False
) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = next(iter(self.model.parameters())).device
np_img = pil_to_np(image)
np_img = normalize_image_channel_count(np_img)
assert np_img.ndim == 3
bgr_img = np_img[:, :, ::-1].copy()
with torch.no_grad():
image_pidi = torch.from_numpy(bgr_img).float().to(device)
image_pidi = image_pidi / 255.0
image_pidi = rearrange(image_pidi, "h w c -> 1 c h w")
edge = self.model(image_pidi)[-1]
edge = edge.cpu().numpy()
if apply_filter:
edge = edge > 0.5
if quantize_edges:
edge = safe_step(edge)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
detected_map = edge[0, 0]
if scribble:
detected_map = nms(detected_map, 127, 3.0)
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
detected_map[detected_map > 4] = 255
detected_map[detected_map < 255] = 0
output_img = np_to_pil(detected_map)
return output_img

View File

@@ -0,0 +1,681 @@
"""
Author: Zhuo Su, Wenzhe Liu
Date: Feb 18, 2021
"""
import math
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
nets = {
'baseline': {
'layer0': 'cv',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'c-v15': {
'layer0': 'cd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'a-v15': {
'layer0': 'ad',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'r-v15': {
'layer0': 'rd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'cvvv4': {
'layer0': 'cd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'avvv4': {
'layer0': 'ad',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'ad',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'ad',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'ad',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'rvvv4': {
'layer0': 'rd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'rd',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'rd',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'rd',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'cccv4': {
'layer0': 'cd',
'layer1': 'cd',
'layer2': 'cd',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'cd',
'layer6': 'cd',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'cd',
'layer10': 'cd',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'cd',
'layer14': 'cd',
'layer15': 'cv',
},
'aaav4': {
'layer0': 'ad',
'layer1': 'ad',
'layer2': 'ad',
'layer3': 'cv',
'layer4': 'ad',
'layer5': 'ad',
'layer6': 'ad',
'layer7': 'cv',
'layer8': 'ad',
'layer9': 'ad',
'layer10': 'ad',
'layer11': 'cv',
'layer12': 'ad',
'layer13': 'ad',
'layer14': 'ad',
'layer15': 'cv',
},
'rrrv4': {
'layer0': 'rd',
'layer1': 'rd',
'layer2': 'rd',
'layer3': 'cv',
'layer4': 'rd',
'layer5': 'rd',
'layer6': 'rd',
'layer7': 'cv',
'layer8': 'rd',
'layer9': 'rd',
'layer10': 'rd',
'layer11': 'cv',
'layer12': 'rd',
'layer13': 'rd',
'layer14': 'rd',
'layer15': 'cv',
},
'c16': {
'layer0': 'cd',
'layer1': 'cd',
'layer2': 'cd',
'layer3': 'cd',
'layer4': 'cd',
'layer5': 'cd',
'layer6': 'cd',
'layer7': 'cd',
'layer8': 'cd',
'layer9': 'cd',
'layer10': 'cd',
'layer11': 'cd',
'layer12': 'cd',
'layer13': 'cd',
'layer14': 'cd',
'layer15': 'cd',
},
'a16': {
'layer0': 'ad',
'layer1': 'ad',
'layer2': 'ad',
'layer3': 'ad',
'layer4': 'ad',
'layer5': 'ad',
'layer6': 'ad',
'layer7': 'ad',
'layer8': 'ad',
'layer9': 'ad',
'layer10': 'ad',
'layer11': 'ad',
'layer12': 'ad',
'layer13': 'ad',
'layer14': 'ad',
'layer15': 'ad',
},
'r16': {
'layer0': 'rd',
'layer1': 'rd',
'layer2': 'rd',
'layer3': 'rd',
'layer4': 'rd',
'layer5': 'rd',
'layer6': 'rd',
'layer7': 'rd',
'layer8': 'rd',
'layer9': 'rd',
'layer10': 'rd',
'layer11': 'rd',
'layer12': 'rd',
'layer13': 'rd',
'layer14': 'rd',
'layer15': 'rd',
},
'carv4': {
'layer0': 'cd',
'layer1': 'ad',
'layer2': 'rd',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'ad',
'layer6': 'rd',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'ad',
'layer10': 'rd',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'ad',
'layer14': 'rd',
'layer15': 'cv',
},
}
def createConvFunc(op_type):
assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
if op_type == 'cv':
return F.conv2d
if op_type == 'cd':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
assert padding == dilation, 'padding for cd_conv set wrong'
weights_c = weights.sum(dim=[2, 3], keepdim=True)
yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y - yc
return func
elif op_type == 'ad':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
assert padding == dilation, 'padding for ad_conv set wrong'
shape = weights.shape
weights = weights.view(shape[0], shape[1], -1)
weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y
return func
elif op_type == 'rd':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
padding = 2 * dilation
shape = weights.shape
if weights.is_cuda:
buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
else:
buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device)
weights = weights.view(shape[0], shape[1], -1)
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
buffer[:, :, 12] = 0
buffer = buffer.view(shape[0], shape[1], 5, 5)
y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y
return func
else:
print('impossible to be here unless you force that')
return None
class Conv2d(nn.Module):
def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
super(Conv2d, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.pdc = pdc
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class CSAM(nn.Module):
"""
Compact Spatial Attention Module
"""
def __init__(self, channels):
super(CSAM, self).__init__()
mid_channels = 4
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
self.sigmoid = nn.Sigmoid()
nn.init.constant_(self.conv1.bias, 0)
def forward(self, x):
y = self.relu1(x)
y = self.conv1(y)
y = self.conv2(y)
y = self.sigmoid(y)
return x * y
class CDCM(nn.Module):
"""
Compact Dilation Convolution based Module
"""
def __init__(self, in_channels, out_channels):
super(CDCM, self).__init__()
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
nn.init.constant_(self.conv1.bias, 0)
def forward(self, x):
x = self.relu1(x)
x = self.conv1(x)
x1 = self.conv2_1(x)
x2 = self.conv2_2(x)
x3 = self.conv2_3(x)
x4 = self.conv2_4(x)
return x1 + x2 + x3 + x4
class MapReduce(nn.Module):
"""
Reduce feature maps into a single edge map
"""
def __init__(self, channels):
super(MapReduce, self).__init__()
self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
nn.init.constant_(self.conv.bias, 0)
def forward(self, x):
return self.conv(x)
class PDCBlock(nn.Module):
def __init__(self, pdc, inplane, ouplane, stride=1):
super(PDCBlock, self).__init__()
self.stride=stride
self.stride=stride
if self.stride > 1:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.stride > 1:
x = self.pool(x)
y = self.conv1(x)
y = self.relu2(y)
y = self.conv2(y)
if self.stride > 1:
x = self.shortcut(x)
y = y + x
return y
class PDCBlock_converted(nn.Module):
"""
CPDC, APDC can be converted to vanilla 3x3 convolution
RPDC can be converted to vanilla 5x5 convolution
"""
def __init__(self, pdc, inplane, ouplane, stride=1):
super(PDCBlock_converted, self).__init__()
self.stride=stride
if self.stride > 1:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
if pdc == 'rd':
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
else:
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.stride > 1:
x = self.pool(x)
y = self.conv1(x)
y = self.relu2(y)
y = self.conv2(y)
if self.stride > 1:
x = self.shortcut(x)
y = y + x
return y
class PiDiNet(nn.Module):
def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
super(PiDiNet, self).__init__()
self.sa = sa
if dil is not None:
assert isinstance(dil, int), 'dil should be an int'
self.dil = dil
self.fuseplanes = []
self.inplane = inplane
if convert:
if pdcs[0] == 'rd':
init_kernel_size = 5
init_padding = 2
else:
init_kernel_size = 3
init_padding = 1
self.init_block = nn.Conv2d(3, self.inplane,
kernel_size=init_kernel_size, padding=init_padding, bias=False)
block_class = PDCBlock_converted
else:
self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
block_class = PDCBlock
self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # C
inplane = self.inplane
self.inplane = self.inplane * 2
self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 2C
inplane = self.inplane
self.inplane = self.inplane * 2
self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 4C
self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 4C
self.conv_reduces = nn.ModuleList()
if self.sa and self.dil is not None:
self.attentions = nn.ModuleList()
self.dilations = nn.ModuleList()
for i in range(4):
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
self.attentions.append(CSAM(self.dil))
self.conv_reduces.append(MapReduce(self.dil))
elif self.sa:
self.attentions = nn.ModuleList()
for i in range(4):
self.attentions.append(CSAM(self.fuseplanes[i]))
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
elif self.dil is not None:
self.dilations = nn.ModuleList()
for i in range(4):
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
self.conv_reduces.append(MapReduce(self.dil))
else:
for i in range(4):
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
nn.init.constant_(self.classifier.weight, 0.25)
nn.init.constant_(self.classifier.bias, 0)
# print('initialization done')
def get_weights(self):
conv_weights = []
bn_weights = []
relu_weights = []
for pname, p in self.named_parameters():
if 'bn' in pname:
bn_weights.append(p)
elif 'relu' in pname:
relu_weights.append(p)
else:
conv_weights.append(p)
return conv_weights, bn_weights, relu_weights
def forward(self, x):
H, W = x.size()[2:]
x = self.init_block(x)
x1 = self.block1_1(x)
x1 = self.block1_2(x1)
x1 = self.block1_3(x1)
x2 = self.block2_1(x1)
x2 = self.block2_2(x2)
x2 = self.block2_3(x2)
x2 = self.block2_4(x2)
x3 = self.block3_1(x2)
x3 = self.block3_2(x3)
x3 = self.block3_3(x3)
x3 = self.block3_4(x3)
x4 = self.block4_1(x3)
x4 = self.block4_2(x4)
x4 = self.block4_3(x4)
x4 = self.block4_4(x4)
x_fuses = []
if self.sa and self.dil is not None:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.attentions[i](self.dilations[i](xi)))
elif self.sa:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.attentions[i](xi))
elif self.dil is not None:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.dilations[i](xi))
else:
x_fuses = [x1, x2, x3, x4]
e1 = self.conv_reduces[0](x_fuses[0])
e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
e2 = self.conv_reduces[1](x_fuses[1])
e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
e3 = self.conv_reduces[2](x_fuses[2])
e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
e4 = self.conv_reduces[3](x_fuses[3])
e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
outputs = [e1, e2, e3, e4]
output = self.classifier(torch.cat(outputs, dim=1))
#if not self.training:
# return torch.sigmoid(output)
outputs.append(output)
outputs = [torch.sigmoid(r) for r in outputs]
return outputs
def config_model(model):
model_options = list(nets.keys())
assert model in model_options, \
'unrecognized model, please choose from %s' % str(model_options)
# print(str(nets[model]))
pdcs = []
for i in range(16):
layer_name = 'layer%d' % i
op = nets[model][layer_name]
pdcs.append(createConvFunc(op))
return pdcs
def pidinet():
pdcs = config_model('carv4')
dil = 24 #if args.dil else None
return PiDiNet(60, pdcs, dil=dil, sa=True)
if __name__ == '__main__':
model = pidinet()
ckp = torch.load('table5_pidinet.pth')['state_dict']
model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
im = cv2.imread('examples/test_my/cat_v4.png')
im = img2tensor(im).unsqueeze(0)/255.
res = model(im)[-1]
res = res>0.5
res = res.float()
res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
print(res.shape)
cv2.imwrite('edge.png', res)

View File

@@ -86,12 +86,20 @@ def np_to_pil(image: np.ndarray) -> Image.Image:
def pil_to_cv2(image: Image.Image) -> np.ndarray:
"""Converts a PIL image to a CV2 image."""
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
if image.mode == "RGBA":
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGBA2BGRA)
else:
return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
def cv2_to_pil(image: np.ndarray) -> Image.Image:
"""Converts a CV2 image to a PIL image."""
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
if image.ndim == 3 and image.shape[2] == 4:
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
else:
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def normalize_image_channel_count(image: np.ndarray) -> np.ndarray:
@@ -217,3 +225,23 @@ def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)
return y
def resize_to_multiple(image: np.ndarray, multiple: int) -> np.ndarray:
"""Resize an image to make its dimensions multiples of the given number."""
# Get the original dimensions
height, width = image.shape[:2]
# Calculate the scaling factor to make the dimensions multiples of the given number
new_width = (width // multiple) * multiple
new_height = int((new_width / width) * height)
# If new_height is not a multiple, adjust it
if new_height % multiple != 0:
new_height = (new_height // multiple) * multiple
# Resize the image
resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
return resized_image

View File

@@ -1,672 +0,0 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

View File

@@ -0,0 +1,209 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
all_expected_keys_present = all(k in state_dict for k in expected_keys)
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
layers: dict[str, AnyLoRALayer] = {}
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
layers[dst_key] = LoRALayer.from_state_dict_values(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayerBase] = []
for src_layer_dict in src_layer_dicts:
sub_layers.append(
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
)
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
# time_text_embed.text_embedder -> vector_in.
add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
# x_embedder -> img_in.
add_lora_layer_if_present("x_embedder", "img_in")
# Double transformer blocks.
for i in range(num_double_layers):
# norms.
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
# Q, K, V
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.to_q",
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.add_q_proj",
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
f"double_blocks.{i}.txt_attn.qkv",
)
# ff img_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.0.proj",
f"double_blocks.{i}.img_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.2",
f"double_blocks.{i}.img_mlp.2",
)
# ff txt_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.0.proj",
f"double_blocks.{i}.txt_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.2",
f"double_blocks.{i}.txt_mlp.2",
)
# output projections.
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_out.0",
f"double_blocks.{i}.img_attn.proj",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_add_out",
f"double_blocks.{i}.txt_attn.proj",
)
# Single transformer blocks.
for i in range(num_single_layers):
# norms
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.norm.linear",
f"single_blocks.{i}.modulation.lin",
)
# Q, K, V, mlp
add_qkv_lora_layer_if_present(
[
f"single_transformer_blocks.{i}.attn.to_q",
f"single_transformer_blocks.{i}.attn.to_k",
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
f"single_blocks.{i}.linear1",
)
# Output projections.
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.proj_out",
f"single_blocks.{i}.linear2",
)
# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
return LoRAModelRaw(layers=layers)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Groups the keys in the state dict by layer."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
for key in state_dict:
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])
if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]
return layer_dict

View File

@@ -0,0 +1,80 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -0,0 +1,29 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}

View File

View File

@@ -0,0 +1,11 @@
from typing import Union
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]

View File

@@ -0,0 +1,46 @@
from typing import List, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = torch.nn.ModuleList(lora_layers)
self.concat_axis = concat_axis
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original weight tensor here.
# Note that we must apply the sub-layer scales here.
layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self.concat_axis)
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
# Note that we must apply the sub-layer scales here.
layer_biases: list[torch.Tensor] = []
for lora_layer in self.lora_layers:
layer_bias = lora_layer.get_bias(None)
if layer_bias is not None:
layer_biases.append(layer_bias * lora_layer.scale())
if len(layer_biases) == 0:
return None
assert len(layer_biases) == len(self.lora_layers)
return torch.cat(layer_biases, dim=self.concat_axis)
def calc_size(self) -> int:
return sum(lora_layer.calc_size() for lora_layer in self.lora_layers)

View File

@@ -0,0 +1,26 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class FullLayer(LoRALayerBase):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight)
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
layer = cls(weight=values["diff"], bias=values.get("diff_b", None))
cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
return layer
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight

View File

@@ -0,0 +1,53 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class IA3Layer(LoRALayerBase):
"""IA3 Layer
Example model for testing this layer type: https://civitai.com/models/123930/gwendolyn-tennyson-ben-10-ia3
"""
def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight)
self.on_input = torch.nn.Parameter(on_input)
def rank(self) -> int | None:
return None
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
weight=values["weight"],
on_input=values["on_input"],
bias=bias,
)
cls.warn_on_unhandled_keys(
values=values,
handled_keys={
# Default keys.
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"weight",
"on_input",
},
)
return layer
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight

View File

@@ -0,0 +1,85 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class LoHALayer(LoRALayerBase):
"""LoHA LyCoris layer.
Example model for testing this layer type: https://civitai.com/models/27397/loha-renoir-the-dappled-light-style
"""
def __init__(
self,
w1_a: torch.Tensor,
w1_b: torch.Tensor,
w2_a: torch.Tensor,
w2_b: torch.Tensor,
t1: torch.Tensor | None,
t2: torch.Tensor | None,
alpha: float | None,
bias: torch.Tensor | None,
):
super().__init__(alpha=alpha, bias=bias)
self.w1_a = torch.nn.Parameter(w1_a)
self.w1_b = torch.nn.Parameter(w1_b)
self.w2_a = torch.nn.Parameter(w2_a)
self.w2_b = torch.nn.Parameter(w2_b)
self.t1 = torch.nn.Parameter(t1) if t1 is not None else None
self.t2 = torch.nn.Parameter(t2) if t2 is not None else None
assert (self.t1 is None) == (self.t2 is None)
def rank(self) -> int | None:
return self.w1_b.shape[0]
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
alpha = cls._parse_alpha(values.get("alpha", None))
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
w1_a=values["hada_w1_a"],
w1_b=values["hada_w1_b"],
w2_a=values["hada_w2_a"],
w2_b=values["hada_w2_b"],
t1=values.get("hada_t1", None),
t2=values.get("hada_t2", None),
alpha=alpha,
bias=bias,
)
cls.warn_on_unhandled_keys(
values=values,
handled_keys={
# Default keys.
"alpha",
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
return layer
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight

View File

@@ -0,0 +1,110 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class LoKRLayer(LoRALayerBase):
"""LoKR LyCoris layer.
Example model for testing this layer type: https://civitai.com/models/346747/lokrnekopara-allgirl-for-jru2
"""
def __init__(
self,
w1: torch.Tensor | None,
w1_a: torch.Tensor | None,
w1_b: torch.Tensor | None,
w2: torch.Tensor | None,
w2_a: torch.Tensor | None,
w2_b: torch.Tensor | None,
t2: torch.Tensor | None,
alpha: float | None,
bias: torch.Tensor | None,
):
super().__init__(alpha=alpha, bias=bias)
self.w1 = torch.nn.Parameter(w1) if w1 is not None else None
self.w1_a = torch.nn.Parameter(w1_a) if w1_a is not None else None
self.w1_b = torch.nn.Parameter(w1_b) if w1_b is not None else None
self.w2 = torch.nn.Parameter(w2) if w2 is not None else None
self.w2_a = torch.nn.Parameter(w2_a) if w2_a is not None else None
self.w2_b = torch.nn.Parameter(w2_b) if w2_b is not None else None
self.t2 = torch.nn.Parameter(t2) if t2 is not None else None
# Validate parameters.
assert (self.w1 is None) != (self.w1_a is None)
assert (self.w1_a is None) == (self.w1_b is None)
assert (self.w2 is None) != (self.w2_a is None)
def rank(self) -> int | None:
if self.w1_b is not None:
return self.w1_b.shape[0]
elif self.w2_b is not None:
return self.w2_b.shape[0]
else:
return None
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
alpha = cls._parse_alpha(values.get("alpha", None))
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
w1=values.get("lokr_w1", None),
w1_a=values.get("lokr_w1_a", None),
w1_b=values.get("lokr_w1_b", None),
w2=values.get("lokr_w2", None),
w2_a=values.get("lokr_w2_a", None),
w2_b=values.get("lokr_w2_b", None),
t2=values.get("lokr_t2", None),
alpha=alpha,
bias=bias,
)
cls.warn_on_unhandled_keys(
values,
{
# Default keys.
"alpha",
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
return layer
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1 = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight

View File

@@ -0,0 +1,69 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class LoRALayer(LoRALayerBase):
def __init__(
self,
up: torch.Tensor,
mid: Optional[torch.Tensor],
down: torch.Tensor,
alpha: float | None,
bias: Optional[torch.Tensor],
):
super().__init__(alpha, bias)
self.up = torch.nn.Parameter(up)
self.mid = torch.nn.Parameter(mid) if mid is not None else None
self.down = torch.nn.Parameter(down)
self.bias = torch.nn.Parameter(bias) if bias is not None else None
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
alpha = cls._parse_alpha(values.get("alpha", None))
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
up=values["lora_up.weight"],
down=values["lora_down.weight"],
mid=values.get("lora_mid.weight", None),
alpha=alpha,
bias=bias,
)
cls.warn_on_unhandled_keys(
values=values,
handled_keys={
# Default keys.
"alpha",
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
return layer
def rank(self) -> int:
return self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight

View File

@@ -0,0 +1,68 @@
from typing import Dict, Optional, Set
import torch
import invokeai.backend.util.logging as logger
class LoRALayerBase(torch.nn.Module):
"""Base class for all LoRA-like patching layers."""
def __init__(self, alpha: float | None, bias: torch.Tensor | None):
super().__init__()
self._alpha = alpha
self.bias = torch.nn.Parameter(bias) if bias is not None else None
@classmethod
def _parse_bias(
cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
) -> torch.Tensor | None:
assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
bias = None
if bias_indices is not None:
bias = torch.sparse_coo_tensor(bias_indices, bias_values, tuple(bias_size))
return bias
@classmethod
def _parse_alpha(
cls,
alpha: torch.Tensor | None,
) -> float | None:
return alpha.item() if alpha is not None else None
def rank(self) -> int | None:
raise NotImplementedError()
def scale(self) -> float:
if self._alpha is None or self.rank() is None:
return 1.0
return self._alpha / self.rank()
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
@classmethod
def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
unknown_keys = set(values.keys()) - handled_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
)
def calc_size(self) -> int:
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self)

View File

@@ -0,0 +1,26 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class NormLayer(LoRALayerBase):
def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight)
@classmethod
def from_state_dict_values(
cls,
values: Dict[str, torch.Tensor],
):
layer = cls(weight=values["w_norm"], bias=values.get("b_norm", None))
cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
return layer
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight

View File

@@ -0,0 +1,33 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer.from_state_dict_values(state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer.from_state_dict_values(state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer.from_state_dict_values(state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer.from_state_dict_values(state_dict)
elif "on_input" in state_dict:
return IA3Layer.from_state_dict_values(state_dict)
elif "w_norm" in state_dict:
return NormLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
def __init__(self, layers: Dict[str, AnyLoRALayer]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size

View File

@@ -0,0 +1,274 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_conv_sidecar_layer import (
LoRAConv1dSidecarLayer,
LoRAConv2dSidecarLayer,
LoRAConv3dSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoraPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.
:param model: The model to patch.
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
that the LoRA patches do not need to be loaded into memory all at once.
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoraPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@staticmethod
@torch.no_grad()
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply a single LoRA patch to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoraPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.scale()
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
):
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoraPatcher._apply_lora_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = module_key.rsplit(".", 1)
parent_module = model.get_submodule(module_parent_key)
LoraPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
):
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoraPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoraPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
# HACK(ryand): Set the dtype properly here. We want to set it to the *compute* dtype of the original module.
# In the case of quantized layers, this may be different than the weight dtype.
lora_sidecar_layer.to(device=module.weight.device, dtype=torch.bfloat16)
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
module.add_lora_layer(lora_sidecar_layer)
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [lora_sidecar_layer])
original_modules[module_key] = module
module_parent_key, module_name = module_key.rsplit(".", 1)
module_parent = model.get_submodule(module_parent_key)
LoraPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
if isinstance(orig_layer, torch.nn.Linear):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv1d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv1dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
else:
raise ValueError(f"Unsupported Conv1D LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv2d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv2dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
else:
raise ValueError(f"Unsupported Conv2D LoRA layer type: {type(lora_layer)}")
elif isinstance(orig_layer, torch.nn.Conv3d):
if isinstance(lora_layer, LoRALayer):
return LoRAConv3dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
else:
raise ValueError(f"Unsupported Conv3D LoRA layer type: {type(lora_layer)}")
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:
submodule_index = int(module_name)
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
parent_module[submodule_index] = submodule
except ValueError:
# If the module name is not an integer, then we use the setattr method to set the submodule.
setattr(parent_module, module_name, submodule)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
:param model: The model to search.
:param layer_key: The layer key to search for.
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
searching, but some legacy code still uses flattened keys.
:return: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
# Handle flattened keys.
assert "." not in layer_key
module = model
module_key = ""
key_parts = layer_key.split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return module_key, module

View File

@@ -0,0 +1,32 @@
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
concatenated_lora_layer: ConcatenatedLoRALayer,
weight: float,
):
super().__init__()
self._concatenated_lora_layer = concatenated_lora_layer
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 3
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= self._weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert self._concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x

View File

@@ -0,0 +1,140 @@
import typing
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRAConvSidecarLayer(torch.nn.Module):
"""An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
(https://arxiv.org/pdf/2106.09685.pdf)
"""
@property
def conv_module(self) -> type[torch.nn.Conv1d | torch.nn.Conv2d | torch.nn.Conv3d]:
"""The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
raise NotImplementedError(
"LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead."
)
def __init__(
self,
in_channels: int,
out_channels: int,
include_mid: bool,
rank: int,
alpha: float,
weight: float,
kernel_size: typing.Union[int, tuple[int]] = 1,
stride: typing.Union[int, tuple[int]] = 1,
padding: typing.Union[str, int, tuple[int]] = 0,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
if rank > min(in_channels, out_channels):
raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")
self._down = self.conv_module(
in_channels,
rank,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
device=device,
dtype=dtype,
)
self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
self._mid = None
if include_mid:
self._mid = self.conv_module(rank, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
# Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))
self._weight = weight
self._rank = rank
@classmethod
def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
# Initialize the LoRA layer.
with torch.device("meta"):
model = cls.from_orig_layer(
orig_layer,
include_mid=lora_layer.mid is not None,
rank=lora_layer.rank,
# TODO(ryand): Is this the right default in case of missing alpha?
alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
weight=weight,
)
# TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?
# Inject weight into the LoRA layer.
assert model._up.weight.shape == lora_layer.up.shape
assert model._down.weight.shape == lora_layer.down.shape
model._up.weight.data = lora_layer.up
model._down.weight.data = lora_layer.down
if lora_layer.mid is not None:
assert model._mid is not None
assert model._mid.weight.shape == lora_layer.mid.shape
model._mid.weight.data = lora_layer.mid
return model
@classmethod
def from_orig_layer(
cls,
layer: torch.nn.Module,
include_mid: bool,
rank: int,
alpha: float,
weight: float,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
if not isinstance(layer, cls.conv_module):
raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")
return cls(
in_channels=layer.in_channels,
out_channels=layer.out_channels,
include_mid=include_mid,
weight=weight,
kernel_size=layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
rank=rank,
alpha=alpha,
device=layer.weight.device if device is None else device,
dtype=layer.weight.dtype if dtype is None else dtype,
)
def forward(self, x: torch.Tensor):
x = self._down(x)
if self._mid is not None:
x = self._mid(x)
x = self._up(x)
x *= self._weight * self.alpha / self._rank
return x
class LoRAConv1dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv1d
class LoRAConv2dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv2d
class LoRAConv3dSidecarLayer(LoRAConvSidecarLayer):
@property
def conv_module(self):
return torch.nn.Conv3d

Some files were not shown because too many files have changed in this diff Show More