feat(ui): fix dynamic prompts generators (but break readiness checks)

This commit is contained in:
psychedelicious
2025-02-26 07:49:49 +10:00
parent d037eea42a
commit 43349cb5ce
5 changed files with 105 additions and 93 deletions

View File

@@ -1,10 +1,8 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { StringGeneratorDynamicPromptsCombinatorial } from 'features/nodes/types/field';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
state: StringGeneratorDynamicPromptsCombinatorial;
@@ -13,41 +11,20 @@ type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
export const StringGeneratorDynamicPromptsCombinatorialSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsCombinatorialSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input, values: loadingValues });
onChange({ ...state, input });
},
[onChange, state, loadingValues]
[onChange, state]
);
const onChangeMaxPrompts = useCallback(
(v: number) => {
onChange({ ...state, maxPrompts: v, values: loadingValues });
onChange({ ...state, maxPrompts: v });
},
[onChange, state, loadingValues]
[onChange, state]
);
const arg = useMemo(() => {
return { prompt: state.input, max_prompts: state.maxPrompts, combinatorial: true };
}, [state]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
return;
}
if (!data) {
onChange({ ...state, values: [] });
return;
}
onChange({ ...state, values: data.prompts });
}, [data, isLoading, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">

View File

@@ -1,11 +1,9 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { StringGeneratorDynamicPromptsRandom } from 'features/nodes/types/field';
import { isNil, random } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsRandomSettingsProps = {
state: StringGeneratorDynamicPromptsRandom;
@@ -14,51 +12,29 @@ type StringGeneratorDynamicPromptsRandomSettingsProps = {
export const StringGeneratorDynamicPromptsRandomSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsRandomSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input, values: loadingValues });
onChange({ ...state, input });
},
[onChange, state, loadingValues]
[onChange, state]
);
const onChangeCount = useCallback(
(v: number) => {
onChange({ ...state, count: v, values: loadingValues });
onChange({ ...state, count: v });
},
[onChange, state, loadingValues]
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null, values: loadingValues });
}, [onChange, state, loadingValues]);
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed, values: loadingValues });
onChange({ ...state, seed });
},
[onChange, state, loadingValues]
[onChange, state]
);
const arg = useMemo(() => {
return { prompt: state.input, max_prompts: state.count, combinatorial: false, seed: state.seed ?? random() };
}, [state.count, state.input, state.seed]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
return;
}
if (!data) {
onChange({ ...state, values: [] });
return;
}
onChange({ ...state, values: data.prompts });
}, [data, isLoading, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2}>

View File

@@ -1,4 +1,5 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings';
@@ -15,12 +16,11 @@ import {
StringGeneratorDynamicPromptsRandomType,
StringGeneratorParseStringType,
} from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { debounce } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
@@ -29,6 +29,7 @@ export const StringGeneratorFieldInputComponent = memo(
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const store = useAppStore();
const onChange = useCallback(
(value: StringGeneratorFieldInputInstance['value']) => {
@@ -57,20 +58,22 @@ export const StringGeneratorFieldInputComponent = memo(
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (debouncedField.value.type === StringGeneratorDynamicPromptsRandomType && isNil(debouncedField.value.seed)) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveStringGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.join(', ');
}
}, [debouncedField, t]);
const [resolvedValuesAsString, setResolvedValuesAsString] = useState<string | null>(null);
const resolveAndSetValuesAsString = useMemo(
() =>
debounce(async (field: StringGeneratorFieldInputInstance) => {
const resolvedValues = await resolveStringGeneratorField(field, store);
if (resolvedValues.length === 0) {
setResolvedValuesAsString(`<${t('nodes.generatorNoValues')}>`);
} else {
setResolvedValuesAsString(resolvedValues.join(', '));
}
}, 300),
[store, t]
);
useEffect(() => {
resolveAndSetValuesAsString(field);
}, [field, resolveAndSetValuesAsString]);
return (
<Flex flexDir="column" gap={2}>
@@ -81,10 +84,10 @@ export const StringGeneratorFieldInputComponent = memo(
size="sm"
>
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
{/* <option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
<option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
<option value={StringGeneratorDynamicPromptsCombinatorialType}>
{t('nodes.dynamicPromptsCombinatorial')}
</option> */}
</option>
</Select>
{field.value.type === StringGeneratorParseStringType && (
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />

View File

@@ -1,5 +1,8 @@
import { isNil, trim } from 'lodash-es';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { AppStore } from 'app/store/store';
import { isNil, random, trim } from 'lodash-es';
import MersenneTwister from 'mtwist';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { assert } from 'tsafe';
import { z } from 'zod';
@@ -1409,9 +1412,32 @@ const zStringGeneratorDynamicPromptsCombinatorial = z.object({
export type StringGeneratorDynamicPromptsCombinatorial = z.infer<typeof zStringGeneratorDynamicPromptsCombinatorial>;
const getStringGeneratorDynamicPromptsCombinatorialDefaults = () =>
zStringGeneratorDynamicPromptsCombinatorial.parse({});
const getStringGeneratorDynamicPromptsCombinatorialValues = (generator: StringGeneratorDynamicPromptsCombinatorial) => {
const { values } = generator;
return values ?? [];
const getStringGeneratorDynamicPromptsCombinatorialValues = async (
generator: StringGeneratorDynamicPromptsCombinatorial,
store: AppStore
): Promise<string[]> => {
const { input, maxPrompts } = generator;
const req = store.dispatch(
utilitiesApi.endpoints.dynamicPrompts.initiate(
{
prompt: input,
max_prompts: maxPrompts,
combinatorial: true,
},
{
subscribe: false,
}
)
);
try {
const { prompts, error } = await req.unwrap();
if (error) {
return EMPTY_ARRAY;
}
return prompts;
} catch {
return EMPTY_ARRAY;
}
};
export const StringGeneratorDynamicPromptsRandomType = 'string_generator_dynamic_prompts_random';
@@ -1424,9 +1450,33 @@ const zStringGeneratorDynamicPromptsRandom = z.object({
});
export type StringGeneratorDynamicPromptsRandom = z.infer<typeof zStringGeneratorDynamicPromptsRandom>;
const getStringGeneratorDynamicPromptsRandomDefaults = () => zStringGeneratorDynamicPromptsRandom.parse({});
const getStringGeneratorDynamicPromptsRandomValues = (generator: StringGeneratorDynamicPromptsRandom) => {
const { values } = generator;
return values ?? [];
const getStringGeneratorDynamicPromptsRandomValues = async (
generator: StringGeneratorDynamicPromptsRandom,
store: AppStore
): Promise<string[]> => {
const { input, seed, count } = generator;
const req = store.dispatch(
utilitiesApi.endpoints.dynamicPrompts.initiate(
{
prompt: input,
max_prompts: count,
combinatorial: false,
seed: seed ?? random(),
},
{
subscribe: false,
}
)
);
try {
const { prompts, error } = await req.unwrap();
if (error) {
return EMPTY_ARRAY;
}
return prompts;
} catch {
return EMPTY_ARRAY;
}
};
export const zStringGeneratorFieldValue = z.union([
@@ -1453,7 +1503,7 @@ export const isStringGeneratorFieldInputTemplate = buildTemplateTypeGuard<String
zStringGeneratorFieldType.shape.name.value
);
export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInputInstance) => {
export const resolveStringGeneratorField = async ({ value }: StringGeneratorFieldInputInstance, store: AppStore) => {
if (value.values) {
return value.values;
}
@@ -1461,10 +1511,10 @@ export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInput
return getStringGeneratorParseStringValues(value);
}
if (value.type === StringGeneratorDynamicPromptsRandomType) {
return getStringGeneratorDynamicPromptsRandomValues(value);
return await getStringGeneratorDynamicPromptsRandomValues(value, store);
}
if (value.type === StringGeneratorDynamicPromptsCombinatorialType) {
return getStringGeneratorDynamicPromptsCombinatorialValues(value);
return await getStringGeneratorDynamicPromptsCombinatorialValues(value, store);
}
assert(false, 'Invalid string generator type');
};

View File

@@ -1,3 +1,4 @@
import type { AppStore } from 'app/store/store';
import {
isFloatFieldCollectionInputInstance,
isFloatGeneratorFieldInputInstance,
@@ -13,7 +14,12 @@ import {
import type { AnyEdge, InvocationNode } from 'features/nodes/types/invocation';
import { assert } from 'tsafe';
export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNode[], edges: AnyEdge[]) => {
export const resolveBatchValue = async (
batchNode: InvocationNode,
nodes: InvocationNode[],
edges: AnyEdge[],
store: AppStore
) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
const ownValue = batchNode.data.inputs.images.value ?? [];
@@ -34,7 +40,7 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
const generatorField = generatorNode.data.inputs['generator'];
assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator');
const generatorValue = resolveStringGeneratorField(generatorField);
const generatorValue = await resolveStringGeneratorField(generatorField, store);
return generatorValue;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));