feat(ui): improved dynamicprompts generator

- Split into two (random and combinatorial) - lots of fiddly logic to do both in one generator.
- Update to support seeds for random.
This commit is contained in:
psychedelicious
2025-01-19 10:15:48 +11:00
parent 7dcc2dafbc
commit ec816d3c04
7 changed files with 239 additions and 138 deletions

View File

@@ -186,7 +186,8 @@
"max": "Max",
"values": "Values",
"resetToDefaults": "Reset to Defaults",
"seed": "Seed"
"seed": "Seed",
"combinatorial": "Combinatorial"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -868,6 +869,9 @@
"generatorNRandomValues_one": "{{count}} random value",
"generatorNRandomValues_other": "{{count}} random values",
"generatorNoValues": "empty",
"generatorLoading": "loading",
"dynamicPromptsRandom": "Dynamic Prompts (Random)",
"dynamicPromptsCombinatorial": "Dynamic Prompts (Combinatorial)",
"addNode": "Add Node",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
@@ -1133,7 +1137,8 @@
"perPromptLabel": "Seed per Image",
"perPromptDesc": "Use a different seed for each image"
},
"loading": "Generating Dynamic Prompts..."
"loading": "Generating Dynamic Prompts...",
"promptsToGenerate": "Prompts to Generate"
},
"sdxl": {
"cfgScale": "CFG Scale",

View File

@@ -0,0 +1,70 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import type { StringGeneratorDynamicPromptsCombinatorial } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
state: StringGeneratorDynamicPromptsCombinatorial;
onChange: (state: StringGeneratorDynamicPromptsCombinatorial) => void;
};
export const StringGeneratorDynamicPromptsCombinatorialSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsCombinatorialSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange({ ...state, input: e.target.value, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onChangeMaxPrompts = useCallback(
(v: number) => {
onChange({ ...state, maxPrompts: v, values: loadingValues });
},
[onChange, state, loadingValues]
);
const arg = useMemo(() => {
const { input, maxPrompts } = state;
return { prompt: input, max_prompts: maxPrompts, combinatorial: true };
}, [state]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
onChange({ ...state, values: loadingValues });
} else if (data) {
onChange({ ...state, values: data.prompts });
} else {
onChange({ ...state, values: [] });
}
}, [data, isLoading, loadingValues, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('dynamicPrompts.maxPrompts')}</FormLabel>
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} w="full"/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.input')}</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={state.input}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
/>
</FormControl>
</Flex>
);
}
);
StringGeneratorDynamicPromptsCombinatorialSettings.displayName = 'StringGeneratorDynamicPromptsCombinatorialSettings';

View File

@@ -0,0 +1,97 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import type { StringGeneratorDynamicPromptsRandom } from 'features/nodes/types/field';
import { isNil, random } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsRandomSettingsProps = {
state: StringGeneratorDynamicPromptsRandom;
onChange: (state: StringGeneratorDynamicPromptsRandom) => void;
};
export const StringGeneratorDynamicPromptsRandomSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsRandomSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange({ ...state, input: e.target.value, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onChangeCount = useCallback(
(v: number) => {
onChange({ ...state, count: v, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null, values: loadingValues });
}, [onChange, state, loadingValues]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed, values: loadingValues });
},
[onChange, state, loadingValues]
);
const arg = useMemo(() => {
const { input, count, seed } = state;
return { prompt: input, max_prompts: count, combinatorial: false, seed: seed ?? random() };
}, [state]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
onChange({ ...state, values: loadingValues });
} else if (data) {
onChange({ ...state, values: data.prompts });
} else {
onChange({ ...state, values: [] });
}
}, [data, isLoading, loadingValues, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2}>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" display="flex" w="full" pe={0.5}>
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is save only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={1000} />
</FormControl>
</Flex>
<FormControl orientation="vertical">
<FormLabel>{t('common.input')}</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={state.input}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
/>
</FormControl>
</Flex>
);
}
);
StringGeneratorDynamicPromptsRandomSettings.displayName = 'StringGeneratorDynamicPromptsRandomSettings';

View File

@@ -1,115 +0,0 @@
import {
Checkbox,
CompositeNumberInput,
Flex,
FormControl,
FormLabel,
IconButton,
Textarea,
} from '@invoke-ai/ui-library';
import { getStore } from 'app/store/nanostores/store';
import type { StringGeneratorDynamicPrompts } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShuffleSimpleBold } from 'react-icons/pi';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
const processDynamicPrompts = async (state: StringGeneratorDynamicPrompts) => {
const { input, maxPrompts, combinatorial } = state;
const { dispatch } = getStore();
const req = dispatch(
utilitiesApi.endpoints.dynamicPrompts.initiate(
{ prompt: input, max_prompts: maxPrompts, combinatorial },
{ subscribe: false }
)
);
try {
const { prompts } = await req.unwrap();
return prompts;
} catch {
return [];
}
};
type StringGeneratorDynamicPromptsSettingsProps = {
state: StringGeneratorDynamicPrompts;
onChange: (state: StringGeneratorDynamicPrompts) => void;
};
export const StringGeneratorDynamicPromptsSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsSettingsProps) => {
const { t } = useTranslation();
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange({ ...state, input: e.target.value });
},
[onChange, state]
);
const onChangeMaxPrompts = useCallback(
(v: number) => {
onChange({ ...state, maxPrompts: v });
},
[onChange, state]
);
const onChangeCombinatorial = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, combinatorial: e.target.checked });
},
[onChange, state]
);
const [debouncedState] = useDebounce(state, 1000);
useEffect(() => {
processDynamicPrompts(debouncedState).then((prompts) => {
onChange({ ...debouncedState, values: prompts });
});
}, [onChange, debouncedState]);
const reroll = useCallback(() => {
processDynamicPrompts(debouncedState).then((prompts) => {
onChange({ ...debouncedState, values: prompts });
});
}, [debouncedState, onChange]);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2}>
<FormControl orientation="vertical">
<FormLabel>Max Prompts</FormLabel>
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>Combinatorial</FormLabel>
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
<Checkbox isChecked={state.combinatorial} onChange={onChangeCombinatorial} />
</FormControl>
<IconButton
aria-label="Reroll"
isDisabled={state.combinatorial}
onClick={reroll}
icon={<PiShuffleSimpleBold />}
variant="ghost"
/>
</Flex>
<FormControl orientation="vertical">
<FormLabel>{t('common.input')}</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={state.input}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
/>
</FormControl>
</Flex>
);
}
);
StringGeneratorDynamicPromptsSettings.displayName = 'StringGeneratorDynamicPromptsSettings';

View File

@@ -1,7 +1,8 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { StringGeneratorDynamicPromptsSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsSettings';
import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings';
import { StringGeneratorDynamicPromptsRandomSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsRandomSettings';
import { StringGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorParseStringSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldStringGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
@@ -9,9 +10,11 @@ import type { StringGeneratorFieldInputInstance, StringGeneratorFieldInputTempla
import {
getStringGeneratorDefaults,
resolveStringGeneratorField,
StringGeneratorDynamicPromptsType,
StringGeneratorDynamicPromptsCombinatorialType,
StringGeneratorDynamicPromptsRandomType,
StringGeneratorParseStringType,
} from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
@@ -55,6 +58,11 @@ export const StringGeneratorFieldInputComponent = memo(
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (debouncedField.value.type === StringGeneratorDynamicPromptsRandomType && isNil(debouncedField.value.seed)) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveStringGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
@@ -67,13 +75,19 @@ export const StringGeneratorFieldInputComponent = memo(
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
<option value={StringGeneratorDynamicPromptsType}>{t('nodes.dynamicPrompts')}</option>
<option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
<option value={StringGeneratorDynamicPromptsCombinatorialType}>
{t('nodes.dynamicPromptsCombinatorial')}
</option>
</Select>
{field.value.type === StringGeneratorParseStringType && (
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
{field.value.type === StringGeneratorDynamicPromptsType && (
<StringGeneratorDynamicPromptsSettings state={field.value} onChange={onChange} />
{field.value.type === StringGeneratorDynamicPromptsRandomType && (
<StringGeneratorDynamicPromptsRandomSettings state={field.value} onChange={onChange} />
)}
{field.value.type === StringGeneratorDynamicPromptsCombinatorialType && (
<StringGeneratorDynamicPromptsCombinatorialSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">

View File

@@ -1352,22 +1352,44 @@ const getStringGeneratorParseStringValues = (generator: StringGeneratorParseStri
const values = splitValues.filter((s) => s.length > 0);
return values;
};
export const StringGeneratorDynamicPromptsType = 'string_generator_dynamic_prompts';
const zStringGeneratorDynamicPrompts = z.object({
type: z.literal(StringGeneratorDynamicPromptsType).default(StringGeneratorDynamicPromptsType),
export const StringGeneratorDynamicPromptsCombinatorialType = 'string_generator_dynamic_prompts_combinatorial';
const zStringGeneratorDynamicPromptsCombinatorial = z.object({
type: z
.literal(StringGeneratorDynamicPromptsCombinatorialType)
.default(StringGeneratorDynamicPromptsCombinatorialType),
input: z.string().default('a super {cute|ferocious} {dog|cat}'),
maxPrompts: z.number().int().gte(1).default(20),
combinatorial: z.boolean().default(true),
maxPrompts: z.number().int().gte(1).default(10),
values: z.array(z.string()).nullish(),
});
export type StringGeneratorDynamicPrompts = z.infer<typeof zStringGeneratorDynamicPrompts>;
export const getStringGeneratorDynamicPromptsDefaults = () => zStringGeneratorDynamicPrompts.parse({});
const getStringGeneratorDynamicPromptsValues = (generator: StringGeneratorDynamicPrompts) => {
export type StringGeneratorDynamicPromptsCombinatorial = z.infer<typeof zStringGeneratorDynamicPromptsCombinatorial>;
export const getStringGeneratorDynamicPromptsCombinatorialDefaults = () =>
zStringGeneratorDynamicPromptsCombinatorial.parse({});
const getStringGeneratorDynamicPromptsCombinatorialValues = (generator: StringGeneratorDynamicPromptsCombinatorial) => {
const { values } = generator;
return values ?? [];
};
export const zStringGeneratorFieldValue = z.union([zStringGeneratorParseString, zStringGeneratorDynamicPrompts]);
export const StringGeneratorDynamicPromptsRandomType = 'string_generator_dynamic_prompts_random';
const zStringGeneratorDynamicPromptsRandom = z.object({
type: z.literal(StringGeneratorDynamicPromptsRandomType).default(StringGeneratorDynamicPromptsRandomType),
input: z.string().default('a super {cute|ferocious} {dog|cat}'),
count: z.number().int().gte(1).default(10),
seed: z.number().int().nullish(),
values: z.array(z.string()).nullish(),
});
export type StringGeneratorDynamicPromptsRandom = z.infer<typeof zStringGeneratorDynamicPromptsRandom>;
export const getStringGeneratorDynamicPromptsRandomDefaults = () => zStringGeneratorDynamicPromptsRandom.parse({});
const getStringGeneratorDynamicPromptsRandomValues = (generator: StringGeneratorDynamicPromptsRandom) => {
const { values } = generator;
return values ?? [];
};
export const zStringGeneratorFieldValue = z.union([
zStringGeneratorParseString,
zStringGeneratorDynamicPromptsCombinatorial,
zStringGeneratorDynamicPromptsRandom,
]);
const zStringGeneratorFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStringGeneratorFieldValue,
});
@@ -1391,8 +1413,11 @@ export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInput
if (value.type === StringGeneratorParseStringType) {
return getStringGeneratorParseStringValues(value);
}
if (value.type === StringGeneratorDynamicPromptsType) {
return getStringGeneratorDynamicPromptsValues(value);
if (value.type === StringGeneratorDynamicPromptsRandomType) {
return getStringGeneratorDynamicPromptsRandomValues(value);
}
if (value.type === StringGeneratorDynamicPromptsCombinatorialType) {
return getStringGeneratorDynamicPromptsCombinatorialValues(value);
}
assert(false, 'Invalid string generator type');
};
@@ -1400,8 +1425,11 @@ export const getStringGeneratorDefaults = (type: StringGeneratorFieldValue['type
if (type === StringGeneratorParseStringType) {
return getStringGeneratorParseStringDefaults();
}
if (type === StringGeneratorDynamicPromptsType) {
return getStringGeneratorDynamicPromptsDefaults();
if (type === StringGeneratorDynamicPromptsRandomType) {
return getStringGeneratorDynamicPromptsRandomDefaults();
}
if (type === StringGeneratorDynamicPromptsCombinatorialType) {
return getStringGeneratorDynamicPromptsCombinatorialDefaults();
}
assert(false, 'Invalid string generator type');
};

View File

@@ -1,4 +1,4 @@
import type { components } from 'services/api/schema';
import type { paths } from 'services/api/schema';
import { api, buildV1Url } from '..';
@@ -13,8 +13,8 @@ const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`)
export const utilitiesApi = api.injectEndpoints({
endpoints: (build) => ({
dynamicPrompts: build.query<
components['schemas']['DynamicPromptsResponse'],
{ prompt: string; max_prompts: number; combinatorial?: boolean }
paths['/api/v1/utilities/dynamicprompts']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/utilities/dynamicprompts']['post']['requestBody']['content']['application/json']
>({
query: (arg) => ({
url: buildUtilitiesUrl('dynamicprompts'),
@@ -28,3 +28,5 @@ export const utilitiesApi = api.injectEndpoints({
}),
}),
});
export const { useDynamicPromptsQuery, useLazyDynamicPromptsQuery } = utilitiesApi;