diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts new file mode 100644 index 0000000000..ba76e43632 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts @@ -0,0 +1,88 @@ +import { isModelIdentifier } from 'features/nodes/types/common'; +import { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import { pick } from 'lodash-es'; +import type { AnyModelConfig } from 'services/api/types'; +import { AssertionError } from 'tsafe'; +import { describe, expect, it } from 'vitest'; + +describe('MetadataUtil', () => { + describe('getNode', () => { + it('should return the metadata node if one exists', () => { + const g = new Graph(); + const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' }); + expect(MetadataUtil.getNode(g)).toEqual(metadataNode); + }); + it('should raise an error if the metadata node does not exist', () => { + const g = new Graph(); + expect(() => MetadataUtil.getNode(g)).toThrowError(AssertionError); + }); + }); + + describe('add', () => { + const g = new Graph(); + it("should add metadata, creating the node if it doesn't exist", () => { + MetadataUtil.add(g, { foo: 'bar' }); + const metadataNode = MetadataUtil.getNode(g); + expect(metadataNode['type']).toBe('core_metadata'); + expect(metadataNode['foo']).toBe('bar'); + }); + it('should update existing metadata keys', () => { + const updatedMetadataNode = MetadataUtil.add(g, { foo: 'bananas', baz: 'qux' }); + expect(updatedMetadataNode['foo']).toBe('bananas'); + expect(updatedMetadataNode['baz']).toBe('qux'); + }); + }); + + describe('remove', () => { + it('should remove a single key', () => { + const g = new Graph(); + MetadataUtil.add(g, { foo: 'bar', baz: 'qux' }); + const updatedMetadataNode = MetadataUtil.remove(g, 'foo'); + expect(updatedMetadataNode['foo']).toBeUndefined(); + expect(updatedMetadataNode['baz']).toBe('qux'); + }); + it('should remove multiple keys', () => { + const g = new Graph(); + MetadataUtil.add(g, { foo: 'bar', baz: 'qux' }); + const updatedMetadataNode = MetadataUtil.remove(g, ['foo', 'baz']); + expect(updatedMetadataNode['foo']).toBeUndefined(); + expect(updatedMetadataNode['baz']).toBeUndefined(); + }); + }); + + describe('setMetadataReceivingNode', () => { + const g = new Graph(); + it('should add an edge from from metadata to the receiving node', () => { + const n = g.addNode({ id: 'my-node', type: 'img_resize' }); + MetadataUtil.add(g, { foo: 'bar' }); + MetadataUtil.setMetadataReceivingNode(g, n.id); + expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n.id, 'metadata')).toBe(true); + }); + it('should remove existing metadata edges', () => { + const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' }); + MetadataUtil.setMetadataReceivingNode(g, n2.id); + expect(g.getIncomers(n2.id).length).toBe(1); + expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n2.id, 'metadata')).toBe(true); + }); + }); + + describe('getModelMetadataField', () => { + it('should return a ModelIdentifierField', () => { + const model: AnyModelConfig = { + key: 'model_key', + type: 'main', + hash: 'model_hash', + base: 'sd-1', + format: 'diffusers', + name: 'my model', + path: '/some/path', + source: 'www.models.com', + source_type: 'url', + }; + const metadataField = MetadataUtil.getModelMetadataField(model); + expect(isModelIdentifier(metadataField)).toBe(true); + expect(pick(model, ['key', 'hash', 'name', 'base', 'type'])).toEqual(metadataField); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts new file mode 100644 index 0000000000..a51cebd21e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts @@ -0,0 +1,57 @@ +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { METADATA } from 'features/nodes/util/graph/constants'; +import { isString, unset } from 'lodash-es'; +import type { AnyModelConfig, Invocation } from 'services/api/types'; + +import type { Graph } from './Graph'; + +export class MetadataUtil { + static metadataNodeId = METADATA; + + static getNode(graph: Graph): Invocation<'core_metadata'> { + return graph.getNode(this.metadataNodeId, 'core_metadata'); + } + + static add(graph: Graph, metadata: Partial>): Invocation<'core_metadata'> { + const metadataNode = graph.getNodeSafe(this.metadataNodeId, 'core_metadata'); + if (!metadataNode) { + return graph.addNode({ + id: this.metadataNodeId, + type: 'core_metadata', + ...metadata, + }); + } else { + return graph.updateNode(this.metadataNodeId, 'core_metadata', metadata); + } + } + + static remove(graph: Graph, key: string): Invocation<'core_metadata'>; + static remove(graph: Graph, keys: string[]): Invocation<'core_metadata'>; + static remove(graph: Graph, keyOrKeys: string | string[]): Invocation<'core_metadata'> { + const metadataNode = this.getNode(graph); + if (isString(keyOrKeys)) { + unset(metadataNode, keyOrKeys); + } else { + for (const key of keyOrKeys) { + unset(metadataNode, key); + } + } + return metadataNode; + } + + static setMetadataReceivingNode(graph: Graph, nodeId: string): void { + // We need to break the rules to update metadata - `addEdge` doesn't allow `core_metadata` as a node type + graph._graph.edges = graph._graph.edges.filter((edge) => edge.source.node_id !== this.metadataNodeId); + graph.addEdge(this.metadataNodeId, 'metadata', nodeId, 'metadata'); + } + + static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField { + return { + key, + hash, + name, + base, + type, + }; + } +}