mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): add utils for getting images from canvas
This commit is contained in:
@@ -1,96 +1,9 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { LayerEntity } from 'features/controlLayers/store/types';
|
||||
import type Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const isValidLayer = (entity: LayerEntity) => {
|
||||
export const isValidLayer = (entity: LayerEntity) => {
|
||||
return (
|
||||
entity.isEnabled &&
|
||||
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
|
||||
entity.objects.length > 0
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
|
||||
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
|
||||
const manager = $nodeManager.get();
|
||||
assert(manager, 'Node manager is null');
|
||||
|
||||
const stage = manager.stage.clone();
|
||||
|
||||
stage.scaleX(1);
|
||||
stage.scaleY(1);
|
||||
stage.x(0);
|
||||
stage.y(0);
|
||||
|
||||
const validLayers = layers.filter(isValidLayer);
|
||||
|
||||
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
||||
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
||||
// to delete in a separate array and then destroy them.
|
||||
// TODO(psyche): Maybe report this?
|
||||
const toDelete: Konva.Layer[] = [];
|
||||
|
||||
for (const konvaLayer of stage.getLayers()) {
|
||||
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
||||
if (!layer) {
|
||||
toDelete.push(konvaLayer);
|
||||
}
|
||||
}
|
||||
|
||||
for (const konvaLayer of toDelete) {
|
||||
konvaLayer.destroy();
|
||||
}
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
|
||||
}
|
||||
|
||||
stage.destroy();
|
||||
|
||||
return blob;
|
||||
};
|
||||
|
||||
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
|
||||
const { dispatch, getState } = getStore();
|
||||
const state = getState();
|
||||
if (state.canvasV2.layers.baseLayerImageCache) {
|
||||
const imageDTO = await getImageDTO(state.canvasV2.layers.baseLayerImageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const blob = await getBaseLayer(state.canvasV2.layers.entities, state.canvasV2.bbox, true);
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(baseLayerImageCacheChanged(imageDTO));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
|
||||
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
|
||||
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
@@ -16,8 +11,7 @@ import {
|
||||
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
@@ -34,6 +28,7 @@ import { assert } from 'tsafe';
|
||||
*/
|
||||
|
||||
export const addRegions = async (
|
||||
manager: KonvaNodeManager,
|
||||
regions: RegionEntity[],
|
||||
g: Graph,
|
||||
documentSize: Dimensions,
|
||||
@@ -51,7 +46,7 @@ export const addRegions = async (
|
||||
|
||||
for (const region of validRegions) {
|
||||
// Upload the mask image, or get the cached image if it exists
|
||||
const { image_name } = await getRegionMaskImage(region, bbox, true);
|
||||
const { image_name } = await manager.util.getRegionMaskImage({ id: region.id, bbox, preview: true });
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensor = g.addNode({
|
||||
@@ -217,90 +212,3 @@ export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => {
|
||||
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
};
|
||||
|
||||
export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageDTO> => {
|
||||
const { id, imageCache } = rg;
|
||||
if (imageCache) {
|
||||
const imageDTO = await getImageDTO(imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${rg.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgMaskImageUploaded({ id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgMaskImageUploaded({ id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
|
||||
export const getRegionMaskImage = async (
|
||||
region: RegionEntity,
|
||||
bbox: IRect,
|
||||
preview: boolean = false
|
||||
): Promise<ImageDTO> => {
|
||||
const manager = $nodeManager.get();
|
||||
assert(manager, 'Node manager is null');
|
||||
|
||||
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
|
||||
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
|
||||
assert(adapter, `Adapter for region ${region.id} not found`);
|
||||
if (region.imageCache) {
|
||||
const imageDTO = await getImageDTO(region.imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const layer = adapter.konvaLayer.clone();
|
||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
||||
layer.destroyChildren();
|
||||
layer.add(objectGroup);
|
||||
objectGroup.opacity(1);
|
||||
objectGroup.cache();
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
layer.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
|
||||
openBase64ImageInTab([{ base64, caption }]);
|
||||
}
|
||||
|
||||
layer.destroy();
|
||||
|
||||
return await uploadMaskImage(region, blob);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user