mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): fix dynamic prompts generators (but break readiness checks)
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { StringGeneratorDynamicPromptsCombinatorial } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
|
||||
state: StringGeneratorDynamicPromptsCombinatorial;
|
||||
@@ -13,41 +11,20 @@ type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
|
||||
export const StringGeneratorDynamicPromptsCombinatorialSettings = memo(
|
||||
({ state, onChange }: StringGeneratorDynamicPromptsCombinatorialSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input, values: loadingValues });
|
||||
onChange({ ...state, input });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeMaxPrompts = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...state, maxPrompts: v, values: loadingValues });
|
||||
onChange({ ...state, maxPrompts: v });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const arg = useMemo(() => {
|
||||
return { prompt: state.input, max_prompts: state.maxPrompts, combinatorial: true };
|
||||
}, [state]);
|
||||
const [debouncedArg] = useDebounce(arg, 300);
|
||||
|
||||
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
onChange({ ...state, values: [] });
|
||||
return;
|
||||
}
|
||||
|
||||
onChange({ ...state, values: data.prompts });
|
||||
}, [data, isLoading, onChange, state]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<FormControl orientation="vertical">
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { StringGeneratorDynamicPromptsRandom } from 'features/nodes/types/field';
|
||||
import { isNil, random } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
type StringGeneratorDynamicPromptsRandomSettingsProps = {
|
||||
state: StringGeneratorDynamicPromptsRandom;
|
||||
@@ -14,51 +12,29 @@ type StringGeneratorDynamicPromptsRandomSettingsProps = {
|
||||
export const StringGeneratorDynamicPromptsRandomSettings = memo(
|
||||
({ state, onChange }: StringGeneratorDynamicPromptsRandomSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input, values: loadingValues });
|
||||
onChange({ ...state, input });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...state, count: v, values: loadingValues });
|
||||
onChange({ ...state, count: v });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
[onChange, state]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null, values: loadingValues });
|
||||
}, [onChange, state, loadingValues]);
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
|
||||
}, [onChange, state]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed, values: loadingValues });
|
||||
onChange({ ...state, seed });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const arg = useMemo(() => {
|
||||
return { prompt: state.input, max_prompts: state.count, combinatorial: false, seed: state.seed ?? random() };
|
||||
}, [state.count, state.input, state.seed]);
|
||||
|
||||
const [debouncedArg] = useDebounce(arg, 300);
|
||||
|
||||
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
onChange({ ...state, values: [] });
|
||||
return;
|
||||
}
|
||||
|
||||
onChange({ ...state, values: data.prompts });
|
||||
}, [data, isLoading, onChange, state]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2}>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Flex, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings';
|
||||
@@ -15,12 +16,11 @@ import {
|
||||
StringGeneratorDynamicPromptsRandomType,
|
||||
StringGeneratorParseStringType,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
@@ -29,6 +29,7 @@ export const StringGeneratorFieldInputComponent = memo(
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const store = useAppStore();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: StringGeneratorFieldInputInstance['value']) => {
|
||||
@@ -57,20 +58,22 @@ export const StringGeneratorFieldInputComponent = memo(
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (debouncedField.value.type === StringGeneratorDynamicPromptsRandomType && isNil(debouncedField.value.seed)) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
|
||||
const resolvedValues = resolveStringGeneratorField(debouncedField);
|
||||
if (resolvedValues.length === 0) {
|
||||
return `<${t('nodes.generatorNoValues')}>`;
|
||||
} else {
|
||||
return resolvedValues.join(', ');
|
||||
}
|
||||
}, [debouncedField, t]);
|
||||
const [resolvedValuesAsString, setResolvedValuesAsString] = useState<string | null>(null);
|
||||
const resolveAndSetValuesAsString = useMemo(
|
||||
() =>
|
||||
debounce(async (field: StringGeneratorFieldInputInstance) => {
|
||||
const resolvedValues = await resolveStringGeneratorField(field, store);
|
||||
if (resolvedValues.length === 0) {
|
||||
setResolvedValuesAsString(`<${t('nodes.generatorNoValues')}>`);
|
||||
} else {
|
||||
setResolvedValuesAsString(resolvedValues.join(', '));
|
||||
}
|
||||
}, 300),
|
||||
[store, t]
|
||||
);
|
||||
useEffect(() => {
|
||||
resolveAndSetValuesAsString(field);
|
||||
}, [field, resolveAndSetValuesAsString]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
@@ -81,10 +84,10 @@ export const StringGeneratorFieldInputComponent = memo(
|
||||
size="sm"
|
||||
>
|
||||
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
|
||||
{/* <option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
|
||||
<option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
|
||||
<option value={StringGeneratorDynamicPromptsCombinatorialType}>
|
||||
{t('nodes.dynamicPromptsCombinatorial')}
|
||||
</option> */}
|
||||
</option>
|
||||
</Select>
|
||||
{field.value.type === StringGeneratorParseStringType && (
|
||||
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { isNil, trim } from 'lodash-es';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { isNil, random, trim } from 'lodash-es';
|
||||
import MersenneTwister from 'mtwist';
|
||||
import { utilitiesApi } from 'services/api/endpoints/utilities';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -1409,9 +1412,32 @@ const zStringGeneratorDynamicPromptsCombinatorial = z.object({
|
||||
export type StringGeneratorDynamicPromptsCombinatorial = z.infer<typeof zStringGeneratorDynamicPromptsCombinatorial>;
|
||||
const getStringGeneratorDynamicPromptsCombinatorialDefaults = () =>
|
||||
zStringGeneratorDynamicPromptsCombinatorial.parse({});
|
||||
const getStringGeneratorDynamicPromptsCombinatorialValues = (generator: StringGeneratorDynamicPromptsCombinatorial) => {
|
||||
const { values } = generator;
|
||||
return values ?? [];
|
||||
const getStringGeneratorDynamicPromptsCombinatorialValues = async (
|
||||
generator: StringGeneratorDynamicPromptsCombinatorial,
|
||||
store: AppStore
|
||||
): Promise<string[]> => {
|
||||
const { input, maxPrompts } = generator;
|
||||
const req = store.dispatch(
|
||||
utilitiesApi.endpoints.dynamicPrompts.initiate(
|
||||
{
|
||||
prompt: input,
|
||||
max_prompts: maxPrompts,
|
||||
combinatorial: true,
|
||||
},
|
||||
{
|
||||
subscribe: false,
|
||||
}
|
||||
)
|
||||
);
|
||||
try {
|
||||
const { prompts, error } = await req.unwrap();
|
||||
if (error) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
return prompts;
|
||||
} catch {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
};
|
||||
|
||||
export const StringGeneratorDynamicPromptsRandomType = 'string_generator_dynamic_prompts_random';
|
||||
@@ -1424,9 +1450,33 @@ const zStringGeneratorDynamicPromptsRandom = z.object({
|
||||
});
|
||||
export type StringGeneratorDynamicPromptsRandom = z.infer<typeof zStringGeneratorDynamicPromptsRandom>;
|
||||
const getStringGeneratorDynamicPromptsRandomDefaults = () => zStringGeneratorDynamicPromptsRandom.parse({});
|
||||
const getStringGeneratorDynamicPromptsRandomValues = (generator: StringGeneratorDynamicPromptsRandom) => {
|
||||
const { values } = generator;
|
||||
return values ?? [];
|
||||
const getStringGeneratorDynamicPromptsRandomValues = async (
|
||||
generator: StringGeneratorDynamicPromptsRandom,
|
||||
store: AppStore
|
||||
): Promise<string[]> => {
|
||||
const { input, seed, count } = generator;
|
||||
const req = store.dispatch(
|
||||
utilitiesApi.endpoints.dynamicPrompts.initiate(
|
||||
{
|
||||
prompt: input,
|
||||
max_prompts: count,
|
||||
combinatorial: false,
|
||||
seed: seed ?? random(),
|
||||
},
|
||||
{
|
||||
subscribe: false,
|
||||
}
|
||||
)
|
||||
);
|
||||
try {
|
||||
const { prompts, error } = await req.unwrap();
|
||||
if (error) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
return prompts;
|
||||
} catch {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
};
|
||||
|
||||
export const zStringGeneratorFieldValue = z.union([
|
||||
@@ -1453,7 +1503,7 @@ export const isStringGeneratorFieldInputTemplate = buildTemplateTypeGuard<String
|
||||
zStringGeneratorFieldType.shape.name.value
|
||||
);
|
||||
|
||||
export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInputInstance) => {
|
||||
export const resolveStringGeneratorField = async ({ value }: StringGeneratorFieldInputInstance, store: AppStore) => {
|
||||
if (value.values) {
|
||||
return value.values;
|
||||
}
|
||||
@@ -1461,10 +1511,10 @@ export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInput
|
||||
return getStringGeneratorParseStringValues(value);
|
||||
}
|
||||
if (value.type === StringGeneratorDynamicPromptsRandomType) {
|
||||
return getStringGeneratorDynamicPromptsRandomValues(value);
|
||||
return await getStringGeneratorDynamicPromptsRandomValues(value, store);
|
||||
}
|
||||
if (value.type === StringGeneratorDynamicPromptsCombinatorialType) {
|
||||
return getStringGeneratorDynamicPromptsCombinatorialValues(value);
|
||||
return await getStringGeneratorDynamicPromptsCombinatorialValues(value, store);
|
||||
}
|
||||
assert(false, 'Invalid string generator type');
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatGeneratorFieldInputInstance,
|
||||
@@ -13,7 +14,12 @@ import {
|
||||
import type { AnyEdge, InvocationNode } from 'features/nodes/types/invocation';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNode[], edges: AnyEdge[]) => {
|
||||
export const resolveBatchValue = async (
|
||||
batchNode: InvocationNode,
|
||||
nodes: InvocationNode[],
|
||||
edges: AnyEdge[],
|
||||
store: AppStore
|
||||
) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
const ownValue = batchNode.data.inputs.images.value ?? [];
|
||||
@@ -34,7 +40,7 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator');
|
||||
|
||||
const generatorValue = resolveStringGeneratorField(generatorField);
|
||||
const generatorValue = await resolveStringGeneratorField(generatorField, store);
|
||||
return generatorValue;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
|
||||
Reference in New Issue
Block a user