Compare commits

...

14 Commits

Author SHA1 Message Date
blessedcoolant
3fdcdc7247 refactor: Simplify the entire ImageAlphaToOutline node
Remove general line width and line mode entirely. Instead now use independent sliders to control inner bleed and outer bleed allowing users to pick the perfect mask shape they want for inpainting.
2024-10-29 08:25:20 +05:30
blessedcoolant
2a3d819113 feat(nodes)(ui): Add line mode to ImageAlphaToOutline
Allows the user to pick where the outlines are drawn .. both inside and outside of the alpha region
2024-10-27 10:36:45 +05:30
blessedcoolant
50bda4c889 fix(nodes)(ui): Update Alpha to Outline node to work better with any image size 2024-10-27 10:07:12 +05:30
psychedelicious
f3015d22e5 feat(ui): add alpha to outline filter 2024-10-25 21:00:34 +10:00
psychedelicious
dd1936afd4 chore(ui): typegen 2024-10-25 21:00:25 +10:00
psychedelicious
01a220b1c9 feat(nodes): add img_alpha_to_outline node 2024-10-25 20:57:47 +10:00
psychedelicious
1f898430e9 feat(ui): use PiPlayFill for process buttons for filter & select object 2024-10-25 16:08:56 +10:00
psychedelicious
a7f3eda99b feat(ui): use fill style icons for Filter 2024-10-25 16:08:56 +10:00
psychedelicious
5ef1cb0067 feat(ui): use PiShapesFill icon for Select Object 2024-10-25 16:08:56 +10:00
psychedelicious
e7ecc4942e feat(ui): make canvas layer toolbar icons a bit larger 2024-10-25 16:08:56 +10:00
psychedelicious
948b482d1b feat(ui): "Auto-Mask" -> "Select Object" 2024-10-25 16:08:56 +10:00
psychedelicious
863af69138 feat(ui): support inverted selection in auto-mask 2024-10-25 16:08:56 +10:00
psychedelicious
4a8f0b874c chore(ui): typegen 2024-10-25 16:08:56 +10:00
psychedelicious
13d64efcfa feat(nodes): add invert to apply_tensor_mask_to_image 2024-10-25 16:08:56 +10:00
25 changed files with 394 additions and 81 deletions

View File

@@ -6,11 +6,7 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
@@ -1055,3 +1051,78 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=generated_image)
return ImageOutput.build(image_dto)
@invocation(
"img_alpha_to_outline",
title="Image Alpha to Outline",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
)
class ImageAlphaToOutlineInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Finds the outline of the alpha channel of an image, expands it and returns just the outline."""
image: ImageField = InputField(description="The input image. It should have some transparency.")
inner_line_width_percent: float = InputField(
default=5,
ge=0,
le=100,
description="The width of the inner outline as a percentage of image dimension",
)
outer_line_width_percent: float = InputField(
default=5,
ge=0,
le=100,
description="The width of the outer outline as a percentage of image dimension",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
img_pil = context.images.get_pil(self.image.image_name, mode="RGBA")
# Create a binary mask from the alpha channel
alpha = numpy.array(img_pil.split()[-1], dtype=numpy.uint8)
_, binary_mask = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY)
# Find contours in the binary mask - effectively the outline of the alpha channel
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Calculate line widths based on smaller image dimension
smaller_dim = min(img_pil.size)
inner_line_width = int(smaller_dim * (self.inner_line_width_percent / 100))
outer_line_width = int(smaller_dim * (self.outer_line_width_percent / 100))
# Create inner and outer contour masks
contour_mask = numpy.zeros_like(binary_mask)
if self.inner_line_width_percent > 0:
cv2.drawContours(
image=contour_mask,
contours=contours,
contourIdx=-1,
color=(255,),
thickness=inner_line_width,
lineType=cv2.LINE_8,
)
contour_mask = numpy.minimum(contour_mask, binary_mask)
if self.outer_line_width_percent > 0:
outer_contour_mask = numpy.zeros_like(binary_mask)
cv2.drawContours(
image=outer_contour_mask,
contours=contours,
contourIdx=-1,
color=(255,),
thickness=outer_line_width,
lineType=cv2.LINE_8,
)
outer_contour_mask = cv2.bitwise_and(outer_contour_mask, cv2.bitwise_not(binary_mask))
contour_mask = cv2.bitwise_or(contour_mask, outer_contour_mask)
# Create result image with contour mask as alpha channel
result_rgba = numpy.zeros((contour_mask.shape[0], contour_mask.shape[1], 4), dtype=numpy.uint8)
result_rgba[..., 3] = contour_mask
result_img = Image.fromarray(result_rgba, "RGBA")
image_dto = context.images.save(image=result_img)
return ImageOutput.build(image_dto)

View File

@@ -165,6 +165,7 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
mask: TensorField = InputField(description="The mask tensor to apply.")
image: ImageField = InputField(description="The image to apply the mask to.")
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
@@ -179,6 +180,9 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)
if self.invert:
mask_np = 255 - mask_np
# Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
# image's transparency - else the transparent regions would end up as opaque black.

View File

@@ -1812,6 +1812,12 @@
"low_threshold": "Low Threshold",
"high_threshold": "High Threshold"
},
"alpha_to_outline": {
"label": "Alpha to Outline",
"description": "Outlines the opaque regions of the image with a line of the given width.",
"inner_line_width_percent": "Inner Line Width Percent",
"outer_line_width_percent": "Outer Line Width Percent"
},
"color_map": {
"label": "Color Map",
"description": "Create a color map from the selected layer.",
@@ -1886,9 +1892,10 @@
"apply": "Apply",
"cancel": "Cancel"
},
"segment": {
"autoMask": "Auto Mask",
"selectObject": {
"selectObject": "Select Object",
"pointType": "Point Type",
"invertSelection": "Invert Selection",
"include": "Include",
"exclude": "Exclude",
"neutral": "Neutral",
@@ -1896,8 +1903,9 @@
"saveAs": "Save As",
"cancel": "Cancel",
"process": "Process",
"help1": "Auto-mask creates a mask for a single target object. Add <Bold>Include</Bold> and <Bold>Exclude</Bold> points to indicate which parts of the layer are part of the target object.",
"help2": "Start with one <Bold>Include</Bold> point within the target object. Add more points to refine the mask. Fewer points typically produce better results.",
"help1": "Select a single target object. Add <Bold>Include</Bold> and <Bold>Exclude</Bold> points to indicate which parts of the layer are part of the target object.",
"help2": "Start with one <Bold>Include</Bold> point within the target object. Add more points to refine the selection. Fewer points typically produce better results.",
"help3": "Invert the selection to select everything except the target object.",
"clickToAdd": "Click on the layer to add a point",
"dragToMove": "Drag a point to move it",
"clickToRemove": "Click on a point to remove it"

View File

@@ -29,7 +29,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<Menu>
<MenuButton
as={IconButton}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
tooltip={t('controlLayers.addLayer')}

View File

@@ -1,10 +1,10 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu';
import { EntityListSelectedEntityActionBarAutoMaskButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarAutoMaskButton';
import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarDuplicateButton';
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
import { EntityListSelectedEntityActionBarOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarOpacity';
import { EntityListSelectedEntityActionBarSelectObjectButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarSelectObjectButton';
import { EntityListSelectedEntityActionBarTransformButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarTransformButton';
import { memo } from 'react';
@@ -17,7 +17,7 @@ export const EntityListSelectedEntityActionBar = memo(() => {
<Spacer />
<EntityListSelectedEntityActionBarFill />
<Flex h="full">
<EntityListSelectedEntityActionBarAutoMaskButton />
<EntityListSelectedEntityActionBarSelectObjectButton />
<EntityListSelectedEntityActionBarFilterButton />
<EntityListSelectedEntityActionBarTransformButton />
<EntityListSelectedEntityActionBarSaveToAssetsButton />

View File

@@ -23,7 +23,7 @@ export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.duplicate')}

View File

@@ -5,7 +5,7 @@ import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/sel
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShootingStarBold } from 'react-icons/pi';
import { PiShootingStarFill } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
const { t } = useTranslation();
@@ -24,12 +24,12 @@ export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
<IconButton
onClick={filter.start}
isDisabled={filter.isDisabled}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarBold />}
icon={<PiShootingStarFill />}
/>
);
});

View File

@@ -31,7 +31,7 @@ export const EntityListSelectedEntityActionBarSaveToAssetsButton = memo(() => {
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.saveLayerToAssets')}

View File

@@ -5,9 +5,9 @@ import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/sel
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';
import { PiShapesFill } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarAutoMaskButton = memo(() => {
export const EntityListSelectedEntityActionBarSelectObjectButton = memo(() => {
const { t } = useTranslation();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const segment = useEntitySegmentAnything(selectedEntityIdentifier);
@@ -24,14 +24,14 @@ export const EntityListSelectedEntityActionBarAutoMaskButton = memo(() => {
<IconButton
onClick={segment.start}
isDisabled={segment.isDisabled}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.segment.autoMask')}
tooltip={t('controlLayers.segment.autoMask')}
icon={<PiMaskHappyBold />}
aria-label={t('controlLayers.selectObject.selectObject')}
tooltip={t('controlLayers.selectObject.selectObject')}
icon={<PiShapesFill />}
/>
);
});
EntityListSelectedEntityActionBarAutoMaskButton.displayName = 'EntityListSelectedEntityActionBarAutoMaskButton';
EntityListSelectedEntityActionBarSelectObjectButton.displayName = 'EntityListSelectedEntityActionBarSelectObjectButton';

View File

@@ -24,7 +24,7 @@ export const EntityListSelectedEntityActionBarTransformButton = memo(() => {
<IconButton
onClick={transform.start}
isDisabled={transform.isDisabled}
size="sm"
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.transform.transform')}

View File

@@ -10,7 +10,7 @@ import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SegmentAnything } from 'features/controlLayers/components/SegmentAnything/SegmentAnything';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
@@ -102,7 +102,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SegmentAnything />
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasDropArea />

View File

@@ -21,7 +21,7 @@ import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/s
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
@@ -93,7 +93,7 @@ export const ControlLayerControlAdapter = memo(() => {
variant="link"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarBold />}
icon={<PiShootingStarFill />}
/>
<IconButton
onClick={pullBboxIntoLayer}

View File

@@ -6,7 +6,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertToSubMenu';
import { ControlLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsCopyToSubMenu';
@@ -24,7 +24,7 @@ export const ControlLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSegment />
<CanvasEntityMenuItemsSelectObject />
<ControlLayerMenuItemsTransparencyEffect />
<MenuDivider />
<CanvasEntityMenuItemsCropToBbox />

View File

@@ -15,7 +15,7 @@ import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiPlayFill, PiXBold } from 'react-icons/pi';
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
@@ -89,7 +89,7 @@ const FilterContent = memo(
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
variant="ghost"
leftIcon={<PiShootingStarBold />}
leftIcon={<PiPlayFill />}
onClick={adapter.filterer.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.process')}

View File

@@ -0,0 +1,69 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { AlphaToOutlineFilterConfig } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<AlphaToOutlineFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.alpha_to_outline.buildDefaults();
export const FilterAlphaToOutline = ({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleInnerLineWidthPercentChange = useCallback(
(v: number) => {
onChange({ ...config, inner_line_width_percent: v });
},
[onChange, config]
);
const handleOuterLineWidthPercentChange = useCallback(
(v: number) => {
onChange({ ...config, outer_line_width_percent: v });
},
[onChange, config]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.alpha_to_outline.inner_line_width_percent')}</FormLabel>
<CompositeSlider
value={config.inner_line_width_percent}
onChange={handleInnerLineWidthPercentChange}
defaultValue={DEFAULTS.inner_line_width_percent}
min={0}
max={100}
/>
<CompositeNumberInput
value={config.inner_line_width_percent}
onChange={handleInnerLineWidthPercentChange}
defaultValue={DEFAULTS.inner_line_width_percent}
min={0}
max={100}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.alpha_to_outline.outer_line_width_percent')}</FormLabel>
<CompositeSlider
value={config.outer_line_width_percent}
onChange={handleOuterLineWidthPercentChange}
defaultValue={DEFAULTS.outer_line_width_percent}
min={0}
max={100}
/>
<CompositeNumberInput
value={config.outer_line_width_percent}
onChange={handleOuterLineWidthPercentChange}
defaultValue={DEFAULTS.outer_line_width_percent}
min={0}
max={100}
/>
</FormControl>
</>
);
};
FilterAlphaToOutline.displayName = 'FilterAlphaToOutline';

View File

@@ -1,4 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterAlphaToOutline } from 'features/controlLayers/components/Filters/FilterAlphaToOutline';
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
@@ -63,6 +64,10 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'alpha_to_outline') {
return <FilterAlphaToOutline config={filterConfig} onChange={onChange} />;
}
return (
<IAINoContentFallback
py={4}

View File

@@ -6,7 +6,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
@@ -23,7 +23,7 @@ export const RasterLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSegment />
<CanvasEntityMenuItemsSelectObject />
<MenuDivider />
<CanvasEntityMenuItemsCropToBbox />
<CanvasEntityMenuItemsSave />

View File

@@ -19,7 +19,8 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
import { SelectObjectInvert } from 'features/controlLayers/components/SelectObject/SelectObjectInvert';
import { SelectObjectPointType } from 'features/controlLayers/components/SelectObject/SelectObjectPointType';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
@@ -28,9 +29,9 @@ import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/us
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useRef } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiFloppyDiskBold, PiInfoBold, PiStarBold, PiXBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiFloppyDiskBold, PiInfoBold, PiPlayFill, PiXBold } from 'react-icons/pi';
const SegmentAnythingContent = memo(
const SelectObjectContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
@@ -90,9 +91,9 @@ const SegmentAnythingContent = memo(
<Flex w="full" gap={4} alignItems="center">
<Flex gap={2}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.segment.autoMask')}
{t('controlLayers.selectObject.selectObject')}
</Heading>
<Tooltip label={<SegmentAnythingHelpTooltipContent />}>
<Tooltip label={<SelectObjectHelpTooltipContent />}>
<Flex alignItems="center">
<Icon as={PiInfoBold} color="base.500" />
</Flex>
@@ -103,39 +104,42 @@ const SegmentAnythingContent = memo(
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<SegmentAnythingPointType adapter={adapter} />
<Flex w="full" justifyContent="space-between" py={2}>
<SelectObjectPointType adapter={adapter} />
<SelectObjectInvert adapter={adapter} />
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
leftIcon={<PiStarBold />}
leftIcon={<PiPlayFill />}
onClick={adapter.segmentAnything.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.process')}
loadingText={t('controlLayers.selectObject.process')}
variant="ghost"
isDisabled={!hasPoints || autoProcess}
>
{t('controlLayers.segment.process')}
{t('controlLayers.selectObject.process')}
</Button>
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.segmentAnything.reset}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.reset')}
loadingText={t('controlLayers.selectObject.reset')}
variant="ghost"
>
{t('controlLayers.segment.reset')}
{t('controlLayers.selectObject.reset')}
</Button>
<Menu>
<MenuButton
as={Button}
leftIcon={<PiFloppyDiskBold />}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.saveAs')}
loadingText={t('controlLayers.selectObject.saveAs')}
variant="ghost"
isDisabled={!hasImageState}
>
{t('controlLayers.segment.saveAs')}
{t('controlLayers.selectObject.saveAs')}
</MenuButton>
<MenuList>
<MenuItem isDisabled={!hasImageState} onClick={saveAsInpaintMask}>
@@ -159,7 +163,7 @@ const SegmentAnythingContent = memo(
loadingText={t('common.cancel')}
variant="ghost"
>
{t('controlLayers.segment.cancel')}
{t('controlLayers.selectObject.cancel')}
</Button>
</ButtonGroup>
</Flex>
@@ -167,9 +171,9 @@ const SegmentAnythingContent = memo(
}
);
SegmentAnythingContent.displayName = 'SegmentAnythingContent';
SelectObjectContent.displayName = 'SegmentAnythingContent';
export const SegmentAnything = () => {
export const SelectObject = memo(() => {
const canvasManager = useCanvasManager();
const adapter = useStore(canvasManager.stateApi.$segmentingAdapter);
@@ -177,8 +181,10 @@ export const SegmentAnything = () => {
return null;
}
return <SegmentAnythingContent adapter={adapter} />;
};
return <SelectObjectContent adapter={adapter} />;
});
SelectObject.displayName = 'SelectObject';
const Bold = (props: PropsWithChildren) => (
<Text as="span" fontWeight="semibold">
@@ -186,24 +192,27 @@ const Bold = (props: PropsWithChildren) => (
</Text>
);
const SegmentAnythingHelpTooltipContent = memo(() => {
const SelectObjectHelpTooltipContent = memo(() => {
const { t } = useTranslation();
return (
<Flex gap={3} flexDir="column">
<Text>
<Trans i18nKey="controlLayers.segment.help1" components={{ Bold: <Bold /> }} />
<Trans i18nKey="controlLayers.selectObject.help1" components={{ Bold: <Bold /> }} />
</Text>
<Text>
<Trans i18nKey="controlLayers.segment.help2" components={{ Bold: <Bold /> }} />
<Trans i18nKey="controlLayers.selectObject.help2" components={{ Bold: <Bold /> }} />
</Text>
<Text>
<Trans i18nKey="controlLayers.selectObject.help3" />
</Text>
<UnorderedList>
<ListItem>{t('controlLayers.segment.clickToAdd')}</ListItem>
<ListItem>{t('controlLayers.segment.dragToMove')}</ListItem>
<ListItem>{t('controlLayers.segment.clickToRemove')}</ListItem>
<ListItem>{t('controlLayers.selectObject.clickToAdd')}</ListItem>
<ListItem>{t('controlLayers.selectObject.dragToMove')}</ListItem>
<ListItem>{t('controlLayers.selectObject.clickToRemove')}</ListItem>
</UnorderedList>
</Flex>
);
});
SegmentAnythingHelpTooltipContent.displayName = 'SegmentAnythingHelpTooltipContent';
SelectObjectHelpTooltipContent.displayName = 'SelectObjectHelpTooltipContent';

View File

@@ -0,0 +1,26 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SelectObjectInvert = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const invert = useStore(adapter.segmentAnything.$invert);
const onChange = useCallback(() => {
adapter.segmentAnything.$invert.set(!adapter.segmentAnything.$invert.get());
}, [adapter.segmentAnything.$invert]);
return (
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.selectObject.invertSelection')}</FormLabel>
<Switch size="sm" isChecked={invert} onChange={onChange} />
</FormControl>
);
}
);
SelectObjectInvert.displayName = 'SelectObjectInvert';

View File

@@ -6,7 +6,7 @@ import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SegmentAnythingPointType = memo(
export const SelectObjectPointType = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const pointType = useStore(adapter.segmentAnything.$pointTypeString);
@@ -21,15 +21,15 @@ export const SegmentAnythingPointType = memo(
);
return (
<FormControl w="full">
<FormLabel>{t('controlLayers.segment.pointType')}</FormLabel>
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.selectObject.pointType')}</FormLabel>
<RadioGroup value={pointType} onChange={onChange} w="full" size="md">
<Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
<Radio value="foreground">
<Text>{t('controlLayers.segment.include')}</Text>
<Text>{t('controlLayers.selectObject.include')}</Text>
</Radio>
<Radio value="background">
<Text>{t('controlLayers.segment.exclude')}</Text>
<Text>{t('controlLayers.selectObject.exclude')}</Text>
</Radio>
</Flex>
</RadioGroup>
@@ -38,4 +38,4 @@ export const SegmentAnythingPointType = memo(
}
);
SegmentAnythingPointType.displayName = 'SegmentAnythingPointType';
SelectObjectPointType.displayName = 'SelectObject';

View File

@@ -3,7 +3,7 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShootingStarBold } from 'react-icons/pi';
import { PiShootingStarFill } from 'react-icons/pi';
export const CanvasEntityMenuItemsFilter = memo(() => {
const { t } = useTranslation();
@@ -11,7 +11,7 @@ export const CanvasEntityMenuItemsFilter = memo(() => {
const filter = useEntityFilter(entityIdentifier);
return (
<MenuItem onClick={filter.start} icon={<PiShootingStarBold />} isDisabled={filter.isDisabled}>
<MenuItem onClick={filter.start} icon={<PiShootingStarFill />} isDisabled={filter.isDisabled}>
{t('controlLayers.filter.filter')}
</MenuItem>
);

View File

@@ -3,18 +3,18 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';
import { PiShapesFill } from 'react-icons/pi';
export const CanvasEntityMenuItemsSegment = memo(() => {
export const CanvasEntityMenuItemsSelectObject = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const segmentAnything = useEntitySegmentAnything(entityIdentifier);
return (
<MenuItem onClick={segmentAnything.start} icon={<PiMaskHappyBold />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.segment.autoMask')}
<MenuItem onClick={segmentAnything.start} icon={<PiShapesFill />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.selectObject.selectObject')}
</MenuItem>
);
});
CanvasEntityMenuItemsSegment.displayName = 'CanvasEntityMenuItemsSegment';
CanvasEntityMenuItemsSelectObject.displayName = 'CanvasEntityMenuItemsSelectObject';

View File

@@ -172,6 +172,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
$hasPoints = computed(this.$points, (points) => points.length > 0);
/**
* Whether the module should invert the mask image.
*/
$invert = atom<boolean>(false);
/**
* The masked image object, if it exists.
*/
@@ -456,6 +461,19 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
})
);
// When the invert flag changes, process if autoProcess is enabled
this.subscriptions.add(
this.$invert.listen(() => {
if (this.$points.get().length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcess) {
this.process();
}
})
);
// When auto-process is enabled, process the points if they have not been processed
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
@@ -529,7 +547,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
const hash = stableHash(points);
const invert = this.$invert.get();
const hash = stableHash({ points, invert });
if (hash === this.$lastProcessedHash.get()) {
this.log.trace('Already processed points');
return;
@@ -556,7 +576,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = controller;
// Build the graph for segmenting the image, using the rasterized image DTO
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value, points);
const { graph, outputNodeId } = CanvasSegmentAnythingModule.buildGraph(rasterizeResult.value, points, invert);
// Run the graph and get the segmented image output
const segmentResult = await withResultAsync(() =>
@@ -793,6 +813,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.$points.set([]);
this.$imageState.set(null);
this.$pointType.set(1);
this.$invert.set(false);
this.$lastProcessedHash.set('');
this.$isProcessing.set(false);
@@ -808,7 +829,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* Builds a graph for segmenting an image with the given image DTO.
*/
buildGraph = ({ image_name }: ImageDTO, points: SAMPointWithId[]): { graph: Graph; outputNodeId: string } => {
static buildGraph = (
{ image_name }: ImageDTO,
points: SAMPointWithId[],
invert: boolean
): { graph: Graph; outputNodeId: string } => {
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
// TODO(psyche): When SAM2 is available in transformers, use it here
@@ -827,6 +852,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name },
invert,
});
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');

View File

@@ -95,6 +95,13 @@ const zSpandrelFilterConfig = z.object({
});
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
const zAlphaToOutlineFilterConfig = z.object({
type: z.literal('alpha_to_outline'),
inner_line_width_percent: z.number().gte(0).lte(100),
outer_line_width_percent: z.number().gte(0).lte(100),
});
export type AlphaToOutlineFilterConfig = z.infer<typeof zAlphaToOutlineFilterConfig>;
const zFilterConfig = z.discriminatedUnion('type', [
zCannyEdgeDetectionFilterConfig,
zColorMapFilterConfig,
@@ -109,6 +116,7 @@ const zFilterConfig = z.discriminatedUnion('type', [
zPiDiNetEdgeDetectionFilterConfig,
zDWOpenposeDetectionFilterConfig,
zSpandrelFilterConfig,
zAlphaToOutlineFilterConfig,
]);
export type FilterConfig = z.infer<typeof zFilterConfig>;
@@ -126,6 +134,7 @@ const zFilterType = z.enum([
'pidi_edge_detection',
'dw_openpose_detection',
'spandrel_filter',
'alpha_to_outline',
]);
export type FilterType = z.infer<typeof zFilterType>;
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
@@ -163,6 +172,28 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
};
},
},
alpha_to_outline: {
type: 'alpha_to_outline',
buildDefaults: () => ({
type: 'alpha_to_outline',
inner_line_width_percent: 5,
outer_line_width_percent: 5,
}),
buildGraph: ({ image_name }, { inner_line_width_percent, outer_line_width_percent }) => {
const graph = new Graph(getPrefixedId('alpha_to_outline_filter'));
const node = graph.addNode({
id: getPrefixedId('alpha_to_outline'),
type: 'img_alpha_to_outline',
image: { image_name },
inner_line_width_percent,
outer_line_width_percent,
});
return {
graph,
outputNodeId: node.id,
};
},
},
color_map: {
type: 'color_map',
buildDefaults: () => ({

File diff suppressed because one or more lines are too long