perf(ui): cache image data & transparency mode during generation mode calculation

Perf boost and reduces the number of images we create on the backend.
This commit is contained in:
psychedelicious
2024-10-30 09:57:28 +10:00
parent f4b7c63002
commit 8f02ce54a0
5 changed files with 120 additions and 28 deletions

View File

@@ -1,15 +1,32 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { Transparency } from 'features/controlLayers/konva/util';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { GenerationMode } from 'features/controlLayers/store/types';
import { LRUCache } from 'lru-cache';
import type { Logger } from 'roarr';
type GetCacheEntryWithFallbackArg<T extends NonNullable<unknown>> = {
cache: LRUCache<string, T>;
key: string;
getValue: () => Promise<T>;
onHit?: (value: T) => void;
onMiss?: () => void;
};
type CanvasCacheModuleConfig = {
/**
* The maximum size of the image name cache.
*/
imageNameCacheSize: number;
/**
* The maximum size of the image data cache.
*/
imageDataCacheSize: number;
/**
* The maximum size of the transparency calculation cache.
*/
transparencyCalculationCacheSize: number;
/**
* The maximum size of the canvas element cache.
*/
@@ -21,7 +38,9 @@ type CanvasCacheModuleConfig = {
};
const DEFAULT_CONFIG: CanvasCacheModuleConfig = {
imageNameCacheSize: 100,
imageNameCacheSize: 1000,
imageDataCacheSize: 32,
transparencyCalculationCacheSize: 1000,
canvasElementCacheSize: 32,
generationModeCacheSize: 100,
};
@@ -41,26 +60,38 @@ export class CanvasCacheModule extends CanvasModuleBase {
config: CanvasCacheModuleConfig = DEFAULT_CONFIG;
/**
* A cache for storing image names. Used as a cache for results of layer/canvas/entity exports. For example, when we
* rasterize a layer and upload it to the server, we store the image name in this cache.
* A cache for storing image names.
*
* The cache key is a hash of the exported entity's state and the export rect.
* For example, the key might be a hash of a composite of entities with the uploaded image name as the value.
*/
imageNameCache = new LRUCache<string, string>({ max: this.config.imageNameCacheSize });
/**
* A cache for storing canvas elements. Similar to the image name cache, but for canvas elements. The primary use is
* for caching composite layers. For example, the canvas compositor module uses this to store the canvas elements for
* individual raster layers when creating a composite of the layers.
* A cache for storing canvas elements.
*
* The cache key is a hash of the exported entity's state and the export rect.
* For example, the key might be a hash of a composite of entities with the canvas element as the value.
*/
canvasElementCache = new LRUCache<string, HTMLCanvasElement>({ max: this.config.canvasElementCacheSize });
/**
* A cache for the generation mode calculation, which is fairly expensive.
* A cache for image data objects.
*
* The cache key is a hash of all the objects that contribute to the generation mode calculation (e.g. the composite
* raster layer, the composite inpaint mask, and bounding box), and the value is the generation mode.
* For example, the key might be a hash of a composite of entities with the image data as the value.
*/
imageDataCache = new LRUCache<string, ImageData>({ max: this.config.imageDataCacheSize });
/**
* A cache for transparency calculation results.
*
* For example, the key might be a hash of a composite of entities with the transparency as the value.
*/
transparencyCalculationCache = new LRUCache<string, Transparency>({ max: this.config.imageDataCacheSize });
/**
* A cache for generation mode calculation results.
*
* For example, the key might be a hash of a composite of raster and inpaint mask entities with the generation mode
* as the value.
*/
generationModeCache = new LRUCache<string, GenerationMode>({ max: this.config.generationModeCacheSize });
@@ -75,6 +106,33 @@ export class CanvasCacheModule extends CanvasModuleBase {
this.log.debug('Creating cache module');
}
/**
* A helper function for getting a cache entry with a fallback.
* @param param0.cache The LRUCache to get the entry from.
* @param param0.key The key to use to retrieve the entry.
* @param param0.getValue An async function to generate the value if the entry is not in the cache.
* @param param0.onHit An optional function to call when the entry is in the cache.
* @param param0.onMiss An optional function to call when the entry is not in the cache.
* @returns
*/
static getWithFallback = async <T extends NonNullable<unknown>>({
cache,
getValue,
key,
onHit,
onMiss,
}: GetCacheEntryWithFallbackArg<T>): Promise<T> => {
let value = cache.get(key);
if (value === undefined) {
onMiss?.();
value = await getValue();
cache.set(key, value);
} else {
onHit?.(value);
}
return value;
};
/**
* Clears all caches.
*/

View File

@@ -1,8 +1,10 @@
import type { SerializableObject } from 'common/types';
import { withResultAsync } from 'common/util/result';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import type { CanvasEntityAdapter, CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { Transparency } from 'features/controlLayers/konva/util';
import {
canvasToBlob,
canvasToImageData,
@@ -415,6 +417,38 @@ export class CanvasCompositorModule extends CanvasModuleBase {
return this.mergeByEntityIdentifiers(entityIdentifiers, false);
};
/**
* Calculates the transparency of the composite of the give adapters.
* @param adapters The adapters to composite
* @param rect The region to include in the composite
* @param hash The hash to use for caching the result
* @returns A promise that resolves to the transparency of the composite
*/
getTransparency = (adapters: CanvasEntityAdapter[], rect: Rect, hash: string): Promise<Transparency> => {
const entityIdentifiers = adapters.map((adapter) => adapter.entityIdentifier);
const logCtx = { entityIdentifiers, rect };
return CanvasCacheModule.getWithFallback({
cache: this.manager.cache.transparencyCalculationCache,
key: hash,
getValue: async () => {
this.$isProcessing.set(true);
const compositeInpaintMaskCanvas = this.getCompositeCanvas(adapters, rect);
const compositeInpaintMaskImageData = await CanvasCacheModule.getWithFallback({
cache: this.manager.cache.imageDataCache,
key: hash,
getValue: () => Promise.resolve(canvasToImageData(compositeInpaintMaskCanvas)),
onHit: () => this.log.trace(logCtx, 'Using cached image data'),
onMiss: () => this.log.trace(logCtx, 'Calculating image data'),
});
return getImageDataTransparency(compositeInpaintMaskImageData);
},
onHit: () => this.log.trace(logCtx, 'Using cached transparency'),
onMiss: () => this.log.trace(logCtx, 'Calculating transparency'),
});
};
/**
* Calculates the generation mode for the current canvas state. This is determined by the transparency of the
* composite raster layer and composite inpaint mask:
@@ -433,11 +467,11 @@ export class CanvasCompositorModule extends CanvasModuleBase {
*
* @returns The generation mode
*/
getGenerationMode(): GenerationMode {
getGenerationMode = async (): Promise<GenerationMode> => {
const { rect } = this.manager.stateApi.getBbox();
const rasterAdapters = this.manager.compositor.getVisibleAdaptersOfType('raster_layer');
const compositeRasterLayerHash = this.getCompositeHash(rasterAdapters, { rect });
const rasterLayerAdapters = this.manager.compositor.getVisibleAdaptersOfType('raster_layer');
const compositeRasterLayerHash = this.getCompositeHash(rasterLayerAdapters, { rect });
const inpaintMaskAdapters = this.manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
const compositeInpaintMaskHash = this.getCompositeHash(inpaintMaskAdapters, { rect });
@@ -452,17 +486,17 @@ export class CanvasCompositorModule extends CanvasModuleBase {
this.log.debug({ rect }, 'Calculating generation mode');
const compositeInpaintMaskCanvas = this.getCompositeCanvas(inpaintMaskAdapters, rect);
this.$isProcessing.set(true);
const compositeInpaintMaskImageData = canvasToImageData(compositeInpaintMaskCanvas);
const compositeInpaintMaskTransparency = getImageDataTransparency(compositeInpaintMaskImageData);
this.$isProcessing.set(false);
const compositeRasterLayerTransparency = await this.getTransparency(
rasterLayerAdapters,
rect,
compositeRasterLayerHash
);
const compositeRasterLayerCanvas = this.getCompositeCanvas(rasterAdapters, rect);
this.$isProcessing.set(true);
const compositeRasterLayerImageData = canvasToImageData(compositeRasterLayerCanvas);
const compositeRasterLayerTransparency = getImageDataTransparency(compositeRasterLayerImageData);
this.$isProcessing.set(false);
const compositeInpaintMaskTransparency = await this.getTransparency(
inpaintMaskAdapters,
rect,
compositeInpaintMaskHash
);
let generationMode: GenerationMode;
if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') {
@@ -482,7 +516,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
this.manager.cache.generationModeCache.set(hash, generationMode);
return generationMode;
}
};
repr = () => {
return {

View File

@@ -34,7 +34,7 @@ export const buildFLUXGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
const generationMode = manager.compositor.getGenerationMode();
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building FLUX graph');
const params = selectParamsSlice(state);

View File

@@ -37,7 +37,7 @@ export const buildSD1Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => {
const generationMode = manager.compositor.getGenerationMode();
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD1/SD2 graph');
const params = selectParamsSlice(state);

View File

@@ -37,7 +37,7 @@ export const buildSDXLGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => {
const generationMode = manager.compositor.getGenerationMode();
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SDXL graph');
const params = selectParamsSlice(state);