feat(ui): add dynamic prompts string generator (WIP)

This commit is contained in:
psychedelicious
2025-01-18 23:12:16 +11:00
parent 724028d974
commit eb976a2ab0
4 changed files with 143 additions and 2 deletions

View File

@@ -0,0 +1,115 @@
import {
Checkbox,
CompositeNumberInput,
Flex,
FormControl,
FormLabel,
IconButton,
Textarea,
} from '@invoke-ai/ui-library';
import { getStore } from 'app/store/nanostores/store';
import type { StringGeneratorDynamicPrompts } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShuffleSimpleBold } from 'react-icons/pi';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
const processDynamicPrompts = async (state: StringGeneratorDynamicPrompts) => {
const { input, maxPrompts, combinatorial } = state;
const { dispatch } = getStore();
const req = dispatch(
utilitiesApi.endpoints.dynamicPrompts.initiate(
{ prompt: input, max_prompts: maxPrompts, combinatorial },
{ subscribe: false }
)
);
try {
const { prompts } = await req.unwrap();
return prompts;
} catch {
return [];
}
};
type StringGeneratorDynamicPromptsSettingsProps = {
state: StringGeneratorDynamicPrompts;
onChange: (state: StringGeneratorDynamicPrompts) => void;
};
export const StringGeneratorDynamicPromptsSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsSettingsProps) => {
const { t } = useTranslation();
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange({ ...state, input: e.target.value });
},
[onChange, state]
);
const onChangeMaxPrompts = useCallback(
(v: number) => {
onChange({ ...state, maxPrompts: v });
},
[onChange, state]
);
const onChangeCombinatorial = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, combinatorial: e.target.checked });
},
[onChange, state]
);
const [debouncedState] = useDebounce(state, 1000);
useEffect(() => {
processDynamicPrompts(debouncedState).then((prompts) => {
onChange({ ...debouncedState, values: prompts });
});
}, [onChange, debouncedState]);
const reroll = useCallback(() => {
processDynamicPrompts(debouncedState).then((prompts) => {
onChange({ ...debouncedState, values: prompts });
});
}, [debouncedState, onChange]);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2}>
<FormControl orientation="vertical">
<FormLabel>Max Prompts</FormLabel>
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>Combinatorial</FormLabel>
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
<Checkbox isChecked={state.combinatorial} onChange={onChangeCombinatorial} />
</FormControl>
<IconButton
aria-label="Reroll"
isDisabled={state.combinatorial}
onClick={reroll}
icon={<PiShuffleSimpleBold />}
variant="ghost"
/>
</Flex>
<FormControl orientation="vertical">
<FormLabel>{t('common.input')}</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={state.input}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
/>
</FormControl>
</Flex>
);
}
);
StringGeneratorDynamicPromptsSettings.displayName = 'StringGeneratorDynamicPromptsSettings';

View File

@@ -1,6 +1,7 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { StringGeneratorDynamicPromptsSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsSettings';
import { StringGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorParseStringSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldStringGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
@@ -8,6 +9,7 @@ import type { StringGeneratorFieldInputInstance, StringGeneratorFieldInputTempla
import {
getStringGeneratorDefaults,
resolveStringGeneratorField,
StringGeneratorDynamicPromptsType,
StringGeneratorParseStringType,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@@ -65,10 +67,14 @@ export const StringGeneratorFieldInputComponent = memo(
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
<option value={StringGeneratorDynamicPromptsType}>{t('nodes.dynamicPrompts')}</option>
</Select>
{field.value.type === StringGeneratorParseStringType && (
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
{field.value.type === StringGeneratorDynamicPromptsType && (
<StringGeneratorDynamicPromptsSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent

View File

@@ -1352,8 +1352,22 @@ const getStringGeneratorParseStringValues = (generator: StringGeneratorParseStri
const values = splitValues.filter((s) => s.length > 0);
return values;
};
export const StringGeneratorDynamicPromptsType = 'string_generator_dynamic_prompts';
const zStringGeneratorDynamicPrompts = z.object({
type: z.literal(StringGeneratorDynamicPromptsType).default(StringGeneratorDynamicPromptsType),
input: z.string().default('a super {cute|ferocious} {dog|cat}'),
maxPrompts: z.number().int().gte(1).default(20),
combinatorial: z.boolean().default(true),
values: z.array(z.string()).nullish(),
});
export type StringGeneratorDynamicPrompts = z.infer<typeof zStringGeneratorDynamicPrompts>;
export const getStringGeneratorDynamicPromptsDefaults = () => zStringGeneratorDynamicPrompts.parse({});
const getStringGeneratorDynamicPromptsValues = (generator: StringGeneratorDynamicPrompts) => {
const { values } = generator;
return values ?? [];
};
export const zStringGeneratorFieldValue = zStringGeneratorParseString;
export const zStringGeneratorFieldValue = z.union([zStringGeneratorParseString, zStringGeneratorDynamicPrompts]);
const zStringGeneratorFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStringGeneratorFieldValue,
});
@@ -1377,12 +1391,18 @@ export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInput
if (value.type === StringGeneratorParseStringType) {
return getStringGeneratorParseStringValues(value);
}
if (value.type === StringGeneratorDynamicPromptsType) {
return getStringGeneratorDynamicPromptsValues(value);
}
assert(false, 'Invalid string generator type');
};
export const getStringGeneratorDefaults = (type: StringGeneratorFieldValue['type']) => {
if (type === StringGeneratorParseStringType) {
return getStringGeneratorParseStringDefaults();
}
if (type === StringGeneratorDynamicPromptsType) {
return getStringGeneratorDynamicPromptsDefaults();
}
assert(false, 'Invalid string generator type');
};
// #endregion

View File

@@ -14,7 +14,7 @@ export const utilitiesApi = api.injectEndpoints({
endpoints: (build) => ({
dynamicPrompts: build.query<
components['schemas']['DynamicPromptsResponse'],
{ prompt: string; max_prompts: number }
{ prompt: string; max_prompts: number; combinatorial?: boolean }
>({
query: (arg) => ({
url: buildUtilitiesUrl('dynamicprompts'),