mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): add dynamic prompts string generator (WIP)
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
import {
|
||||
Checkbox,
|
||||
CompositeNumberInput,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
IconButton,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { StringGeneratorDynamicPrompts } from 'features/nodes/types/field';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiShuffleSimpleBold } from 'react-icons/pi';
|
||||
import { utilitiesApi } from 'services/api/endpoints/utilities';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const processDynamicPrompts = async (state: StringGeneratorDynamicPrompts) => {
|
||||
const { input, maxPrompts, combinatorial } = state;
|
||||
const { dispatch } = getStore();
|
||||
const req = dispatch(
|
||||
utilitiesApi.endpoints.dynamicPrompts.initiate(
|
||||
{ prompt: input, max_prompts: maxPrompts, combinatorial },
|
||||
{ subscribe: false }
|
||||
)
|
||||
);
|
||||
try {
|
||||
const { prompts } = await req.unwrap();
|
||||
return prompts;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
type StringGeneratorDynamicPromptsSettingsProps = {
|
||||
state: StringGeneratorDynamicPrompts;
|
||||
onChange: (state: StringGeneratorDynamicPrompts) => void;
|
||||
};
|
||||
export const StringGeneratorDynamicPromptsSettings = memo(
|
||||
({ state, onChange }: StringGeneratorDynamicPromptsSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
onChange({ ...state, input: e.target.value });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onChangeMaxPrompts = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...state, maxPrompts: v });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onChangeCombinatorial = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...state, combinatorial: e.target.checked });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const [debouncedState] = useDebounce(state, 1000);
|
||||
|
||||
useEffect(() => {
|
||||
processDynamicPrompts(debouncedState).then((prompts) => {
|
||||
onChange({ ...debouncedState, values: prompts });
|
||||
});
|
||||
}, [onChange, debouncedState]);
|
||||
|
||||
const reroll = useCallback(() => {
|
||||
processDynamicPrompts(debouncedState).then((prompts) => {
|
||||
onChange({ ...debouncedState, values: prompts });
|
||||
});
|
||||
}, [debouncedState, onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2}>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>Max Prompts</FormLabel>
|
||||
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
|
||||
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>Combinatorial</FormLabel>
|
||||
{/* <FormLabel>{t('nodes.splitOn')}</FormLabel> */}
|
||||
<Checkbox isChecked={state.combinatorial} onChange={onChangeCombinatorial} />
|
||||
</FormControl>
|
||||
<IconButton
|
||||
aria-label="Reroll"
|
||||
isDisabled={state.combinatorial}
|
||||
onClick={reroll}
|
||||
icon={<PiShuffleSimpleBold />}
|
||||
variant="ghost"
|
||||
/>
|
||||
</Flex>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.input')}</FormLabel>
|
||||
<Textarea
|
||||
className="nowheel nodrag nopan"
|
||||
value={state.input}
|
||||
onChange={onChangeInput}
|
||||
p={2}
|
||||
resize="none"
|
||||
rows={5}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
StringGeneratorDynamicPromptsSettings.displayName = 'StringGeneratorDynamicPromptsSettings';
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Flex, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { StringGeneratorDynamicPromptsSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsSettings';
|
||||
import { StringGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorParseStringSettings';
|
||||
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
|
||||
import { fieldStringGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
@@ -8,6 +9,7 @@ import type { StringGeneratorFieldInputInstance, StringGeneratorFieldInputTempla
|
||||
import {
|
||||
getStringGeneratorDefaults,
|
||||
resolveStringGeneratorField,
|
||||
StringGeneratorDynamicPromptsType,
|
||||
StringGeneratorParseStringType,
|
||||
} from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
@@ -65,10 +67,14 @@ export const StringGeneratorFieldInputComponent = memo(
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
|
||||
<option value={StringGeneratorDynamicPromptsType}>{t('nodes.dynamicPrompts')}</option>
|
||||
</Select>
|
||||
{field.value.type === StringGeneratorParseStringType && (
|
||||
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === StringGeneratorDynamicPromptsType && (
|
||||
<StringGeneratorDynamicPromptsSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
|
||||
<Flex w="full" h="auto">
|
||||
<OverlayScrollbarsComponent
|
||||
|
||||
@@ -1352,8 +1352,22 @@ const getStringGeneratorParseStringValues = (generator: StringGeneratorParseStri
|
||||
const values = splitValues.filter((s) => s.length > 0);
|
||||
return values;
|
||||
};
|
||||
export const StringGeneratorDynamicPromptsType = 'string_generator_dynamic_prompts';
|
||||
const zStringGeneratorDynamicPrompts = z.object({
|
||||
type: z.literal(StringGeneratorDynamicPromptsType).default(StringGeneratorDynamicPromptsType),
|
||||
input: z.string().default('a super {cute|ferocious} {dog|cat}'),
|
||||
maxPrompts: z.number().int().gte(1).default(20),
|
||||
combinatorial: z.boolean().default(true),
|
||||
values: z.array(z.string()).nullish(),
|
||||
});
|
||||
export type StringGeneratorDynamicPrompts = z.infer<typeof zStringGeneratorDynamicPrompts>;
|
||||
export const getStringGeneratorDynamicPromptsDefaults = () => zStringGeneratorDynamicPrompts.parse({});
|
||||
const getStringGeneratorDynamicPromptsValues = (generator: StringGeneratorDynamicPrompts) => {
|
||||
const { values } = generator;
|
||||
return values ?? [];
|
||||
};
|
||||
|
||||
export const zStringGeneratorFieldValue = zStringGeneratorParseString;
|
||||
export const zStringGeneratorFieldValue = z.union([zStringGeneratorParseString, zStringGeneratorDynamicPrompts]);
|
||||
const zStringGeneratorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStringGeneratorFieldValue,
|
||||
});
|
||||
@@ -1377,12 +1391,18 @@ export const resolveStringGeneratorField = ({ value }: StringGeneratorFieldInput
|
||||
if (value.type === StringGeneratorParseStringType) {
|
||||
return getStringGeneratorParseStringValues(value);
|
||||
}
|
||||
if (value.type === StringGeneratorDynamicPromptsType) {
|
||||
return getStringGeneratorDynamicPromptsValues(value);
|
||||
}
|
||||
assert(false, 'Invalid string generator type');
|
||||
};
|
||||
export const getStringGeneratorDefaults = (type: StringGeneratorFieldValue['type']) => {
|
||||
if (type === StringGeneratorParseStringType) {
|
||||
return getStringGeneratorParseStringDefaults();
|
||||
}
|
||||
if (type === StringGeneratorDynamicPromptsType) {
|
||||
return getStringGeneratorDynamicPromptsDefaults();
|
||||
}
|
||||
assert(false, 'Invalid string generator type');
|
||||
};
|
||||
// #endregion
|
||||
|
||||
@@ -14,7 +14,7 @@ export const utilitiesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
dynamicPrompts: build.query<
|
||||
components['schemas']['DynamicPromptsResponse'],
|
||||
{ prompt: string; max_prompts: number }
|
||||
{ prompt: string; max_prompts: number; combinatorial?: boolean }
|
||||
>({
|
||||
query: (arg) => ({
|
||||
url: buildUtilitiesUrl('dynamicprompts'),
|
||||
|
||||
Reference in New Issue
Block a user