Compare commits

...

2 Commits

Author SHA1 Message Date
psychedelicious
ac012721b0 feat(ui): iterate on simple tab 2025-05-14 17:58:17 +10:00
psychedelicious
9706df02d4 feat(ui): rough out simple generation tab state (wip) 2025-05-14 10:46:38 +10:00
25 changed files with 1106 additions and 44 deletions

View File

@@ -23,6 +23,10 @@ import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nod
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import {
simpleGenerationPersistConfig,
simpleGenerationSlice,
} from 'features/simpleGeneration/store/slice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
@@ -68,6 +72,7 @@ const allReducers = {
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[simpleGenerationSlice.name]: simpleGenerationSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -113,6 +118,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[simpleGenerationSlice.name]: simpleGenerationPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -14,10 +14,11 @@ import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { typedMemo } from 'common/util/typedMemo';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import { isEqual } from 'lodash-es';
import type { AnyStore, ReadableAtom, Task, WritableAtom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { StoreValues } from 'nanostores/computed';
import type { ChangeEvent, MouseEventHandler, PropsWithChildren, RefObject } from 'react';
import type { ChangeEvent, PropsWithChildren, RefObject } from 'react';
import React, {
createContext,
useCallback,
@@ -198,11 +199,17 @@ type PickerProps<T extends object> = {
* Whether the picker should be searchable. If true, renders a search input.
*/
searchable?: boolean;
/**
* The default groups to enable. If omitted, all groups are disabled by default, meaning all groups and options
* are visible by default.
*/
defaultEnabledGroups?: string[];
};
export type PickerContextState<T extends object> = {
$optionsOrGroups: WritableAtom<OptionOrGroup<T>[]>;
$groupStatusMap: WritableAtom<GroupStatusMap>;
defaultStatusMap: GroupStatusMap;
$compactView: WritableAtom<boolean>;
$activeOptionId: WritableAtom<string | undefined>;
$filteredOptions: WritableAtom<OptionOrGroup<T>[]>;
@@ -217,6 +224,7 @@ export type PickerContextState<T extends object> = {
$searchTerm: WritableAtom<string>;
searchPlaceholder?: string;
toggleGroup: (id: string) => void;
resetGroups: () => void;
getOptionId: (option: T) => string;
isMatch: (option: T, searchTerm: string) => boolean;
getIsOptionDisabled?: (option: T) => boolean;
@@ -229,6 +237,7 @@ export type PickerContextState<T extends object> = {
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
NextToSearchBar?: React.ReactNode;
searchable?: boolean;
defaultEnabledGroups?: string[];
};
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -312,7 +321,7 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
type GroupStatusMap = Record<string, boolean>;
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], defaultEnabledGroups?: string[]) => {
const groupsWithOptions = useMemo(() => {
const ids: string[] = [];
for (const optionOrGroup of options) {
@@ -323,23 +332,41 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
return ids;
}, [options]);
const [$groupStatusMap] = useState(atom<GroupStatusMap>({}));
const defaultStatusMap = useMemo(() => {
const map: GroupStatusMap = {};
if (!defaultEnabledGroups || defaultEnabledGroups.length === 0) {
for (const id of groupsWithOptions) {
map[id] = false;
}
} else {
for (const id of groupsWithOptions) {
map[id] = defaultEnabledGroups.includes(id);
}
}
return map;
}, [defaultEnabledGroups, groupsWithOptions]);
const [$groupStatusMap] = useState(atom<GroupStatusMap>(defaultStatusMap));
const [$areAllGroupsDisabled] = useState(() =>
computed($groupStatusMap, (groupStatusMap) => Object.values(groupStatusMap).every((status) => status === false))
);
useEffect(() => {
const groupStatusMap = $groupStatusMap.get();
const newMap: GroupStatusMap = {};
for (const id of groupsWithOptions) {
if (newMap[id] === undefined) {
newMap[id] = false;
} else if (groupStatusMap[id] !== undefined) {
newMap[id] = groupStatusMap[id];
}
}
$groupStatusMap.set(newMap);
}, [groupsWithOptions, $groupStatusMap]);
$groupStatusMap.set(defaultStatusMap);
}, [$groupStatusMap, defaultStatusMap]);
// useEffect(() => {
// const groupStatusMap = $groupStatusMap.get();
// const newMap: GroupStatusMap = {};
// for (const id of groupsWithOptions) {
// if (newMap[id] === undefined) {
// newMap[id] = false;
// } else if (groupStatusMap[id] !== undefined) {
// newMap[id] = groupStatusMap[id];
// }
// }
// $groupStatusMap.set(newMap);
// }, [groupsWithOptions, $groupStatusMap]);
const toggleGroup = useCallback(
(idToToggle: string) => {
@@ -354,7 +381,11 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
[$groupStatusMap, groupsWithOptions]
);
return { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } as const;
const resetGroups = useCallback(() => {
$groupStatusMap.set(defaultStatusMap);
}, [$groupStatusMap, defaultStatusMap]);
return { $groupStatusMap, $areAllGroupsDisabled, toggleGroup, resetGroups, defaultStatusMap } as const;
};
const useKeyboardNavigation = <T extends object>() => {
@@ -511,10 +542,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
OptionComponent = DefaultOptionComponent,
NextToSearchBar,
searchable,
defaultEnabledGroups,
} = props;
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup, resetGroups, defaultStatusMap } = useTogglableGroups(
optionsOrGroups,
defaultEnabledGroups
);
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
const $compactView = useAtom(true);
const $optionsOrGroups = useAtom(optionsOrGroups);
@@ -576,11 +611,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
NextToSearchBar,
onClose,
searchable,
defaultEnabledGroups,
$areAllGroupsDisabled,
$selectedItemId,
$hasOptions,
$hasFilteredOptions,
$filteredOptionsCount,
resetGroups,
defaultStatusMap,
}) satisfies PickerContextState<T>,
[
$optionsOrGroups,
@@ -604,11 +642,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
NextToSearchBar,
onClose,
searchable,
defaultEnabledGroups,
$areAllGroupsDisabled,
$selectedItemId,
$hasOptions,
$hasFilteredOptions,
$filteredOptionsCount,
resetGroups,
defaultStatusMap,
]
);
@@ -807,8 +848,10 @@ const SearchInput = typedMemo(<T extends object>() => {
});
SearchInput.displayName = 'SearchInput';
const GroupToggleButtons = typedMemo(<T extends object>() => {
const { $optionsOrGroups, $groupStatusMap, $areAllGroupsDisabled } = usePickerContext<T>();
const { $optionsOrGroups, $groupStatusMap, defaultStatusMap, resetGroups } = usePickerContext<T>();
const { t } = useTranslation();
const groupStatusMap = useStore($groupStatusMap);
const isResetDisabled = useMemo(() => isEqual(defaultStatusMap, groupStatusMap), [defaultStatusMap, groupStatusMap]);
const $groups = useComputed([$optionsOrGroups], (optionsOrGroups) => {
const _groups: Group<T>[] = [];
for (const optionOrGroup of optionsOrGroups) {
@@ -819,15 +862,6 @@ const GroupToggleButtons = typedMemo(<T extends object>() => {
return _groups;
});
const groups = useStore($groups);
const areAllGroupsDisabled = useStore($areAllGroupsDisabled);
const onClick = useCallback<MouseEventHandler>(() => {
const newMap: GroupStatusMap = {};
for (const { id } of groups) {
newMap[id] = false;
}
$groupStatusMap.set(newMap);
}, [$groupStatusMap, groups]);
if (!groups.length) {
return null;
@@ -846,11 +880,11 @@ const GroupToggleButtons = typedMemo(<T extends object>() => {
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClick}
onClick={resetGroups}
// When a focused element is disabled, it blurs. This closes the popover. Fake the disabled state to prevent this.
// See: https://github.com/chakra-ui/chakra-ui/issues/7965
opacity={areAllGroupsDisabled ? 0.5 : undefined}
pointerEvents={areAllGroupsDisabled ? 'none' : undefined}
opacity={isResetDisabled ? 0.5 : undefined}
pointerEvents={isResetDisabled ? 'none' : undefined}
/>
</Flex>
);

View File

@@ -23,7 +23,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
}
});
const zImageWithDims = z
export const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
@@ -410,6 +410,7 @@ export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export type ChatGPT4oAspectRatioID = z.infer<typeof zChatGPT4oAspectRatioID>;
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;

View File

@@ -0,0 +1,361 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { selectSimpleGenerationSlice } from 'features/simpleGeneration/store/slice';
import { ASPECT_RATIO_MAP } from 'features/simpleGeneration/util/aspectRatioToDimensions';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
const EMPTY_ENTITY_STATE = {
ids: [],
entities: {},
};
export const getFLUXModels = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
/**
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
*
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
* HTTP query response.
*
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
* configs we get from the query use a single colon `:`.
*
* For example:
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
*
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
*/
const flux = allModelConfigs.find(
({ type, base, source }) =>
type === 'main' &&
base === 'flux' &&
(source === 'InvokeAI/flux_dev:transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors' ||
source === 'InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors')
);
const t5Encoder = allModelConfigs.find(
({ type, base, source }) =>
type === 't5_encoder' &&
base === 'any' &&
(source === 'InvokeAI/t5-v1_1-xxl:bnb_llm_int8' || source === 'InvokeAI/t5-v1_1-xxl::bnb_llm_int8')
);
const clipEmbed = allModelConfigs.find(
({ type, base, source }) =>
type === 'clip_embed' &&
base === 'any' &&
(source === 'InvokeAI/clip-vit-large-patch14-text-encoder:bfloat16' ||
source === 'InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16')
);
const clipVision = allModelConfigs.find(
({ type, base, source }) => type === 'clip_vision' && base === 'any' && source === 'InvokeAI/clip-vit-large-patch14'
);
const vae = allModelConfigs.find(
({ type, base, source }) =>
type === 'vae' &&
base === 'flux' &&
(source === 'black-forest-labs/FLUX.1-schnell:ae.safetensors' ||
source === 'black-forest-labs/FLUX.1-schnell::ae.safetensors')
);
const ipAdapter = allModelConfigs.find(
({ type, base, source }) =>
type === 'ip_adapter' &&
base === 'flux' &&
source === 'https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors'
);
return {
flux,
t5Encoder,
clipEmbed,
clipVision,
vae,
ipAdapter,
};
};
export const getSD1Models = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
/**
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
*
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
* HTTP query response.
*
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
* configs we get from the query use a single colon `:`.
*
* For example:
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
*
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
*/
const main = allModelConfigs.find(
({ type, base, source }) =>
type === 'main' &&
base === 'sd-1' &&
source === 'https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors'
);
return {
main,
};
};
export const getSDXLModels = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
/**
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
*
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
* HTTP query response.
*
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
* configs we get from the query use a single colon `:`.
*
* For example:
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
*
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
*/
const main = allModelConfigs.find(
({ type, base, source }) => type === 'main' && base === 'sdxl' && source === 'RunDiffusion/Juggernaut-XL-v9'
);
const vae = allModelConfigs.find(
({ type, base, source }) => type === 'vae' && base === 'sdxl' && source === 'madebyollin/sdxl-vae-fp16-fix'
);
return {
main,
vae,
};
};
const buildSimpleSD1Graph = (state: RootState) => {
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
const { main } = getSD1Models(selectModelConfigsQuery(state));
const g = new Graph(getPrefixedId('simple_sd1'));
const dimensions = ASPECT_RATIO_MAP['sd-1'][aspectRatio];
const modelLoader = g.addNode({
type: 'main_model_loader',
id: getPrefixedId('main_model_loader'),
model: zModelIdentifierField.parse(main),
});
const posCond = g.addNode({
type: 'compel',
id: getPrefixedId('compel_prompt_pos'),
prompt: positivePrompt,
});
const negCond = g.addNode({
type: 'compel',
id: getPrefixedId('compel_prompt_neg'),
prompt: '',
});
const seed = g.addNode({
type: 'rand_int',
id: getPrefixedId('rand_int'),
low: NUMPY_RAND_MIN,
high: NUMPY_RAND_MAX,
use_cache: false,
});
const noise = g.addNode({
type: 'noise',
id: getPrefixedId('noise'),
width: dimensions.width,
height: dimensions.height,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: getPrefixedId('denoise_latents'),
steps: 30,
});
const l2i = g.addNode({
type: 'l2i',
id: getPrefixedId('l2i'),
is_intermediate: false,
board: getBoardField(state),
});
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(denoise, 'latents', l2i, 'latents');
return g;
};
const buildSimpleSDXLGraph = (state: RootState) => {
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
const { main, vae } = getSDXLModels(selectModelConfigsQuery(state));
const g = new Graph(getPrefixedId('simple_sdxl'));
const dimensions = ASPECT_RATIO_MAP['sdxl'][aspectRatio];
const modelLoader = g.addNode({
type: 'sdxl_model_loader',
id: getPrefixedId('sdxl_model_loader'),
model: zModelIdentifierField.parse(main),
});
const vaeLoader = g.addNode({
type: 'vae_loader',
id: getPrefixedId('vae_loader'),
vae_model: zModelIdentifierField.parse(vae),
});
const posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('sdxl_compel_prompt_pos'),
prompt: positivePrompt,
});
const negCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('sdxl_compel_prompt_neg'),
prompt: '',
});
const seed = g.addNode({
type: 'rand_int',
id: getPrefixedId('rand_int'),
low: NUMPY_RAND_MIN,
high: NUMPY_RAND_MAX,
use_cache: false,
});
const noise = g.addNode({
type: 'noise',
id: getPrefixedId('noise'),
width: dimensions.width,
height: dimensions.height,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: getPrefixedId('denoise_latents'),
steps: 30,
});
const l2i = g.addNode({
type: 'l2i',
id: getPrefixedId('l2i'),
is_intermediate: false,
board: getBoardField(state),
});
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
g.addEdge(vaeLoader, 'vae', l2i, 'vae');
g.addEdge(denoise, 'latents', l2i, 'latents');
return g;
};
const buildSimpleFLUXGraph = (state: RootState) => {
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
const { flux, t5Encoder, clipEmbed, vae } = getFLUXModels(selectModelConfigsQuery(state));
const g = new Graph(getPrefixedId('simple_flux'));
const dimensions = ASPECT_RATIO_MAP['flux'][aspectRatio];
const modelLoader = g.addNode({
type: 'flux_model_loader',
id: getPrefixedId('flux_model_loader'),
model: zModelIdentifierField.parse(flux),
t5_encoder_model: zModelIdentifierField.parse(t5Encoder),
clip_embed_model: zModelIdentifierField.parse(clipEmbed),
vae_model: zModelIdentifierField.parse(vae),
});
const posCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
const seed = g.addNode({
type: 'rand_int',
id: getPrefixedId('rand_int'),
low: NUMPY_RAND_MIN,
high: NUMPY_RAND_MAX,
use_cache: false,
});
const denoise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
guidance: 4.0,
num_steps: 30,
denoising_start: 0,
denoising_end: 1,
width: dimensions.width,
height: dimensions.height,
});
const l2i = g.addNode({
type: 'flux_vae_decode',
id: getPrefixedId('flux_vae_decode'),
is_intermediate: false,
board: getBoardField(state),
});
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'vae', denoise, 'controlnet_vae');
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(denoise, 'latents', l2i, 'latents');
return g;
};
export const buildSimpleGraph = (state: RootState): Graph => {
const { model } = selectSimpleGenerationSlice(state);
if (model === 'flux') {
return buildSimpleFLUXGraph(state);
}
if (model === 'sdxl') {
return buildSimpleSDXLGraph(state);
}
if (model === 'sd-1') {
return buildSimpleSD1Graph(state);
}
};

View File

@@ -106,6 +106,7 @@ export const ModelPicker = typedMemo(
isDisabled,
isInvalid,
className,
defaultEnabledGroups,
}: {
modelConfigs: T[];
selectedModelConfig: T | undefined;
@@ -117,6 +118,7 @@ export const ModelPicker = typedMemo(
isDisabled?: boolean;
isInvalid?: boolean;
className?: string;
defaultEnabledGroups?: string[];
}) => {
const { t } = useTranslation();
const options = useMemo<T[] | Group<T>[]>(() => {
@@ -223,6 +225,7 @@ export const ModelPicker = typedMemo(
noMatchesFallback={t('modelManager.noMatchingModels')}
NextToSearchBar={<NavigateToModelManagerButton />}
getIsOptionDisabled={getIsOptionDisabled}
defaultEnabledGroups={defaultEnabledGroups}
searchable
/>
</PopoverBody>

View File

@@ -36,6 +36,10 @@ export const InvokeButtonTooltip = ({ prepend, children, ...rest }: PropsWithChi
const TooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
const activeTab = useAppSelector(selectActiveTab);
if (activeTab === 'simple') {
return <SimpleTabTooltipContent prepend={prepend} />;
}
if (activeTab === 'canvas') {
return <CanvasTabTooltipContent prepend={prepend} />;
}
@@ -52,6 +56,27 @@ const TooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
});
TooltipContent.displayName = 'TooltipContent';
const SimpleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
const isReady = useStore($isReadyToEnqueue);
const reasons = useStore($reasonsWhyCannotEnqueue);
return (
<Flex flexDir="column" gap={1}>
<IsReadyText isReady={isReady} prepend={prepend} />
{/* <QueueCountPredictionSimpleOrUpscaleTab /> */}
{reasons.length > 0 && (
<>
<StyledDivider />
<ReasonsList reasons={reasons} />
</>
)}
<StyledDivider />
<AddingToText />
</Flex>
);
});
SimpleTabTooltipContent.displayName = 'SimpleTabTooltipContent';
const CanvasTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
const isReady = useStore($isReadyToEnqueue);
const reasons = useStore($reasonsWhyCannotEnqueue);

View File

@@ -0,0 +1,39 @@
import { createAction } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import { buildSimpleGraph } from 'features/nodes/util/graph/buildSimpleGraph';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { EnqueueBatchArg } from 'services/api/types';
const enqueueRequestedSimple = createAction('app/enqueueRequestedSimple');
export const useEnqueueSimple = () => {
const { getState, dispatch } = useAppStore();
const enqueue = useCallback(
async (prepend: boolean) => {
dispatch(enqueueRequestedSimple());
const state = getState();
const g = buildSimpleGraph(state);
const batchConfig: EnqueueBatchArg = {
batch: {
graph: g.getGraph(),
runs: state.params.iterations,
origin: 'simple',
destination: 'gallery',
},
prepend,
};
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
},
[dispatch, getState]
);
return enqueue;
};

View File

@@ -6,6 +6,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useEnqueueSimple } from 'features/queue/hooks/useEnqueueSimple';
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -21,6 +22,7 @@ export const useInvoke = () => {
const isReady = useStore($isReadyToEnqueue);
const isLocked = useIsWorkflowEditorLocked();
const enqueueWorkflows = useEnqueueWorkflows();
const enqueueSimple = useEnqueueSimple();
const [_, { isLoading }] = useEnqueueBatchMutation(enqueueMutationFixedCacheKeyOptions);
@@ -30,6 +32,15 @@ export const useInvoke = () => {
return;
}
if (tabName === 'simple') {
const result = await withResultAsync(() => enqueueSimple(prepend));
if (result.isErr()) {
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
} else {
log.debug(parseify(result.value), 'Enqueued batch');
}
}
if (tabName === 'workflows') {
const result = await withResultAsync(() => enqueueWorkflows(prepend, isApiValidationRun));
if (result.isErr()) {
@@ -51,7 +62,7 @@ export const useInvoke = () => {
// Else we are not on a generation tab and should not queue
},
[dispatch, enqueueWorkflows, isReady, tabName]
[dispatch, enqueueSimple, enqueueWorkflows, isReady, tabName]
);
const enqueueBack = useCallback(() => {

View File

@@ -30,10 +30,13 @@ import { getInvocationNodeErrors } from 'features/nodes/store/util/fieldValidato
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { getFLUXModels, getSD1Models, getSDXLModels } from 'features/nodes/util/graph/buildSimpleGraph';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { getGridSize } from 'features/parameters/util/optimalDimension';
import { selectSimpleGenerationSlice } from 'features/simpleGeneration/store/slice';
import type { SimpleGenerationState } from 'features/simpleGeneration/store/types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -42,7 +45,7 @@ import i18n from 'i18next';
import { debounce, groupBy, upperFirst } from 'lodash-es';
import { atom, computed } from 'nanostores';
import { useEffect } from 'react';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import { selectMainModelConfig, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
@@ -89,9 +92,14 @@ const debouncedUpdateReasons = debounce(
config: AppConfig,
store: AppStore,
isInPublishFlow: boolean,
areChatGPT4oModelsEnabled: boolean
areChatGPT4oModelsEnabled: boolean,
modelConfigsQuery: ReturnType<typeof selectModelConfigsQuery>,
simple: SimpleGenerationState
) => {
if (tab === 'canvas') {
if (tab === 'simple') {
const reasons = getReasonsWhyCannotEnqueueSimpleTab({ isConnected, simple, modelConfigsQuery });
$reasonsWhyCannotEnqueue.set(reasons);
} else if (tab === 'canvas') {
const model = selectMainModelConfig(store.getState());
const reasons = await getReasonsWhyCannotEnqueueCanvasTab({
isConnected,
@@ -143,6 +151,7 @@ export const useReadinessWatcher = () => {
const nodes = useAppSelector(selectNodesSlice);
const workflowSettings = useAppSelector(selectWorkflowSettingsSlice);
const upscale = useAppSelector(selectUpscaleSlice);
const simple = useAppSelector(selectSimpleGenerationSlice);
const config = useAppSelector(selectConfigSlice);
const templates = useStore($templates);
const isConnected = useStore($isConnected);
@@ -152,6 +161,7 @@ export const useReadinessWatcher = () => {
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const isInPublishFlow = useStore($isInPublishFlow);
const modelConfigsQuery = useAppSelector(selectModelConfigsQuery);
const areChatGPT4oModelsEnabled = useFeatureStatus('chatGPT4oModels');
useEffect(() => {
@@ -173,7 +183,9 @@ export const useReadinessWatcher = () => {
config,
store,
isInPublishFlow,
areChatGPT4oModelsEnabled
areChatGPT4oModelsEnabled,
modelConfigsQuery,
simple
);
}, [
store,
@@ -194,6 +206,8 @@ export const useReadinessWatcher = () => {
workflowSettings,
isInPublishFlow,
areChatGPT4oModelsEnabled,
modelConfigsQuery,
simple,
]);
};
@@ -330,6 +344,61 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
return reasons;
};
const getReasonsWhyCannotEnqueueSimpleTab = (arg: {
isConnected: boolean;
simple: SimpleGenerationState;
modelConfigsQuery: ReturnType<typeof selectModelConfigsQuery>;
}) => {
const { isConnected, simple, modelConfigsQuery } = arg;
const { model } = simple;
const reasons: Reason[] = [];
if (!isConnected) {
reasons.push(disconnectedReason(i18n.t));
}
if (model === 'flux') {
const models = getFLUXModels(modelConfigsQuery);
if (!models.flux) {
reasons.push({ content: 'FLUX is not installed' });
}
if (!models.t5Encoder) {
reasons.push({ content: 'T5 Encoder is not installed' });
}
if (!models.clipEmbed) {
reasons.push({ content: 'CLIP Embed is not installed' });
}
if (!models.clipVision) {
reasons.push({ content: 'CLIP Vision is not installed' });
}
if (!models.ipAdapter) {
reasons.push({ content: 'IP Adapter is not installed' });
}
if (!models.vae) {
reasons.push({ content: 'VAE is not installed' });
}
}
if (model === 'sdxl') {
const models = getSDXLModels(modelConfigsQuery);
if (!models.main) {
reasons.push({ content: 'Main SDXL model is not installed' });
}
if (!models.vae) {
reasons.push({ content: 'VAE is not installed' });
}
}
if (model === 'sd-1') {
const models = getSD1Models(modelConfigsQuery);
if (!models.main) {
reasons.push({ content: 'Main SDXL model is not installed' });
}
}
return reasons;
};
const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
isConnected: boolean;
model: MainModelConfig | null | undefined;

View File

@@ -12,7 +12,6 @@ import {
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
@@ -20,6 +19,7 @@ import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import { DisabledModelWarning } from 'features/parameters/components/MainModel/DisabledModelWarning';
import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale';
import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler';
import { LoRAModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/LoRAModelPicker';
import { MainModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
@@ -86,7 +86,7 @@ export const GenerationSettingsAccordion = memo(() => {
<Flex gap={4} flexDir="column" pb={isApiModel ? 4 : 0}>
<DisabledModelWarning />
<MainModelPicker />
{!isApiModel && <LoRASelect />}
{!isApiModel && <LoRAModelPicker />}
{!isApiModel && <LoRAList />}
</Flex>
{!isApiModel && (

View File

@@ -0,0 +1,67 @@
import { Flex, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
export const LoRAModelPicker = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector(selectBase);
const addedLoRAs = useAppSelector(selectLoRAs);
const getIsOptionDisabled = useCallback(
(model: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key));
return !isCompatible || isAdded;
},
[addedLoRAs, currentBaseModel]
);
const loraFilter = useCallback(
(loraConfig: LoRAModelConfig) => {
if (!currentBaseModel) {
return true;
}
return currentBaseModel === loraConfig.base;
},
[currentBaseModel]
);
const [modelConfigs] = useLoRAModels();
const onChange = useCallback(
(loraConfig: LoRAModelConfig) => {
dispatch(loraAdded({ model: loraConfig }));
},
[dispatch]
);
return (
<Flex alignItems="center" gap={2}>
<InformationalPopover feature="lora">
<FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover>
<ModelPicker
modelConfigs={modelConfigs}
selectedModelConfig={undefined}
onChange={onChange}
placeholder={t('models.addLora')}
getIsOptionDisabled={getIsOptionDisabled}
defaultEnabledGroups={currentBaseModel ? [currentBaseModel] : undefined}
allowEmpty
grouped
/>
<UseDefaultSettingsButton />
</Flex>
);
});
LoRAModelPicker.displayName = 'LoRAModelPicker';

View File

@@ -0,0 +1,84 @@
import { Flex, FormLabel, Icon, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { modelChanged, selectModel } from 'features/simpleGeneration/store/slice';
import { isModel } from 'features/simpleGeneration/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
export const SimpleTabModel = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const model = useAppSelector(selectModel);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {
if (!isModel(e.target.value)) {
return;
}
dispatch(modelChanged({ model: e.target.value }));
},
[dispatch]
);
return (
<Flex alignItems="center" gap={2}>
<InformationalPopover feature="paramModel">
<FormLabel m={0}>{t('modelManager.model')}</FormLabel>
</InformationalPopover>
{model === 'flux' && (
<InformationalPopover feature="fluxDevLicense" hideDisable={true}>
<Flex justifyContent="flex-start">
<Icon as={MdMoneyOff} />
</Flex>
</InformationalPopover>
)}
<Select value={model} onChange={onChange}>
<option value="chatgpt-4o">ChatGPT 4o</option>
<option value="flux">FLUX</option>
<option value="sdxl">SDXL</option>
<option value="sd-1">SD 1.5</option>
</Select>
</Flex>
);
});
SimpleTabModel.displayName = 'SimpleTabModel';
// export const SimpleTabModel = memo(() => {
// const { t } = useTranslation();
// const dispatch = useAppDispatch();
// const [modelConfigs] = useSimpleTabModels();
// const selectedModelConfig = useSimpleTabModelConfig();
// const onChange = useCallback(
// (modelConfig: AnyModelConfig) => {
// dispatch(modelChanged({ model: zModelIdentifierField.parse(modelConfig) }));
// },
// [dispatch]
// );
// const isFluxDevSelected = useMemo(
// () =>
// selectedModelConfig &&
// isCheckpointMainModelConfig(selectedModelConfig) &&
// selectedModelConfig.config_path === 'flux-dev',
// [selectedModelConfig]
// );
// return (
// <Flex alignItems="center" gap={2}>
// <InformationalPopover feature="paramModel">
// <FormLabel>{t('modelManager.model')}</FormLabel>
// </InformationalPopover>
// {isFluxDevSelected && (
// <InformationalPopover feature="fluxDevLicense" hideDisable={true}>
// <Flex justifyContent="flex-start">
// <Icon as={MdMoneyOff} />
// </Flex>
// </InformationalPopover>
// )}
// <ModelPicker modelConfigs={modelConfigs} selectedModelConfig={selectedModelConfig} onChange={onChange} grouped />
// </Flex>
// );
// });
// SimpleTabModel.displayName = 'SimpleTabModel';

View File

@@ -0,0 +1,52 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { zChatGPT4oAspectRatioID } from 'features/controlLayers/store/types';
import { aspectRatioChanged, selectAspectRatio, selectModel } from 'features/simpleGeneration/store/slice';
import { isAspectRatio, zAspectRatio } from 'features/simpleGeneration/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const SimpleTabAspectRatio = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const id = useAppSelector(selectAspectRatio);
const model = useAppSelector(selectModel);
const options = useMemo(() => {
// ChatGPT4o has different aspect ratio options
if (model === 'chatgpt-4o') {
return zChatGPT4oAspectRatioID.options;
}
// All other models
return zAspectRatio.options;
}, [model]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {
if (!isAspectRatio(e.target.value)) {
return;
}
dispatch(aspectRatioChanged({ aspectRatio: e.target.value }));
},
[dispatch]
);
return (
<FormControl>
<InformationalPopover feature="paramAspect">
<FormLabel m={0}>{t('parameters.aspect')}</FormLabel>
</InformationalPopover>
<Select value={id} onChange={onChange}>
{options.map((ratio) => (
<option key={ratio} value={ratio}>
{ratio}
</option>
))}
</Select>
</FormControl>
);
});
SimpleTabAspectRatio.displayName = 'SimpleTabAspectRatio';

View File

@@ -0,0 +1,17 @@
import { Flex } from '@invoke-ai/ui-library';
import { SimpleTabAspectRatio } from 'features/simpleGeneration/components/SimpleTabAspectRatio';
import { SimpleTabModel } from 'features/simpleGeneration/components/SImpleTabModel';
import { SimpleTabPositivePrompt } from 'features/simpleGeneration/components/SimpleTabPositivePrompt';
import { memo } from 'react';
export const SimpleTabLeftPanel = memo(() => {
return (
<Flex w="full" h="full" flexDir="column" gap={2}>
<SimpleTabPositivePrompt />
<SimpleTabModel />
<SimpleTabAspectRatio />
</Flex>
);
});
SimpleTabLeftPanel.displayName = 'SimpleTabLeftPanel';

View File

@@ -0,0 +1,50 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { positivePromptChanged, selectPositivePrompt } from 'features/simpleGeneration/store/slice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SimpleTabPositivePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector(selectPositivePrompt);
const { t } = useTranslation();
const onChange = useCallback<ChangeEventHandler<HTMLTextAreaElement>>(
(e) => {
dispatch(positivePromptChanged({ positivePrompt: e.target.value }));
},
[dispatch]
);
useRegisteredHotkeys({
id: 'focusPrompt',
category: 'app',
callback: focus,
options: { preventDefault: true, enableOnFormTags: ['INPUT', 'SELECT', 'TEXTAREA'] },
dependencies: [focus],
});
return (
<Box pos="relative">
<Textarea
id="prompt"
name="prompt"
value={prompt}
onChange={onChange}
minH={40}
variant="darkFilled"
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
/>
<PromptLabel label={t('parameters.positivePromptPlaceholder')} />
</Box>
);
});
SimpleTabPositivePrompt.displayName = 'SimpleTabPositivePrompt';

View File

@@ -0,0 +1,11 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { selectModelKey } from 'features/simpleGeneration/store/slice';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
export const useSimpleTabModelConfig = () => {
const key = useAppSelector(selectModelKey);
const { data: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
return modelConfig;
};

View File

@@ -0,0 +1,105 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { SimpleGenerationState } from 'features/simpleGeneration/store/types';
import { zSimpleGenerationState } from 'features/simpleGeneration/store/types';
const getInitialState = (): SimpleGenerationState => zSimpleGenerationState.parse({});
export const simpleGenerationSlice = createSlice({
name: 'simpleGeneration',
initialState: getInitialState(),
reducers: {
positivePromptChanged: (
state,
action: PayloadAction<{
positivePrompt: SimpleGenerationState['positivePrompt'];
}>
) => {
const { positivePrompt } = action.payload;
state.positivePrompt = positivePrompt;
},
modelChanged: (
state,
action: PayloadAction<{
model: SimpleGenerationState['model'];
}>
) => {
const { model } = action.payload;
state.model = model;
},
aspectRatioChanged: (
state,
action: PayloadAction<{
aspectRatio: SimpleGenerationState['aspectRatio'];
}>
) => {
const { aspectRatio } = action.payload;
state.aspectRatio = aspectRatio;
},
startingImageChanged: (
state,
action: PayloadAction<{
startingImage: SimpleGenerationState['startingImage'];
}>
) => {
const { startingImage } = action.payload;
state.startingImage = startingImage;
},
referenceImageChanged: (
state,
action: PayloadAction<{
index: number;
referenceImage: SimpleGenerationState['referenceImages'][number];
}>
) => {
const { index, referenceImage } = action.payload;
state.referenceImages[index] = referenceImage;
},
controlImageChanged: (
state,
action: PayloadAction<{
controlImage: SimpleGenerationState['controlImage'];
}>
) => {
const { controlImage } = action.payload;
state.controlImage = controlImage;
},
reset: () => getInitialState(),
},
});
export const {
aspectRatioChanged,
controlImageChanged,
modelChanged,
positivePromptChanged,
referenceImageChanged,
startingImageChanged,
reset,
} = simpleGenerationSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateSimpleGenerationState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const simpleGenerationPersistConfig: PersistConfig<SimpleGenerationState> = {
name: simpleGenerationSlice.name,
initialState: getInitialState(),
migrate: migrateSimpleGenerationState,
persistDenylist: [],
};
export const selectSimpleGenerationSlice = (state: RootState) => state.simpleGeneration;
const createSliceSelector = <T>(selector: Selector<SimpleGenerationState, T>) =>
createSelector(selectSimpleGenerationSlice, selector);
export const selectPositivePrompt = createSliceSelector((slice) => slice.positivePrompt);
export const selectModel = createSliceSelector((slice) => slice.model);
// export const selectModelBase = createSliceSelector((slice) => slice.model?.base);
// export const selectModelKey = createSliceSelector((slice) => slice.model?.key);
export const selectAspectRatio = createSliceSelector((slice) => slice.aspectRatio);

View File

@@ -0,0 +1,54 @@
import { zImageWithDims } from 'features/controlLayers/store/types';
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { nanoid } from 'nanoid';
import { z } from 'zod';
const zLowMedHigh = z.enum(['low', 'med', 'high']);
const zControlType = z.enum(['line', 'depth']);
export const zAspectRatio = z.enum(['16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export type AspectRatio = z.infer<typeof zAspectRatio>;
export const isAspectRatio = (val: unknown): val is AspectRatio => zAspectRatio.safeParse(val).success;
const STARTING_IMAGE_TYPE = 'starting_image';
const zStartingImage = z.object({
type: z.literal(STARTING_IMAGE_TYPE).default(STARTING_IMAGE_TYPE),
id: z.string().default(nanoid),
image: zImageWithDims.nullable().default(null),
variation: zLowMedHigh.default('med'),
});
export type StartingImage = z.infer<typeof zStartingImage>;
export const getStartingImage = (overrides: Partial<Omit<StartingImage, 'type'>>) => zStartingImage.parse(overrides);
const REFERENCE_IMAGE_TYPE = 'reference_image';
const zReferenceImage = z.object({
type: z.literal(REFERENCE_IMAGE_TYPE).default(REFERENCE_IMAGE_TYPE),
id: z.string().default(nanoid),
image: zImageWithDims.nullable().default(null),
});
export type ReferenceImage = z.infer<typeof zReferenceImage>;
export const getReferenceImage = (overrides: Partial<Omit<ReferenceImage, 'type'>>) => zReferenceImage.parse(overrides);
const CONTROL_IMAGE_TYPE = 'control_image';
const zControlImage = z.object({
type: z.literal(CONTROL_IMAGE_TYPE).default(CONTROL_IMAGE_TYPE),
id: z.string().default(nanoid),
control_type: zControlType.default('line'),
image: zImageWithDims.nullable().default(null),
});
export type ControlImage = z.infer<typeof zControlImage>;
export const getControlImage = (overrides: Partial<Omit<ControlImage, 'type'>>) => zControlImage.parse(overrides);
const zModel = z.enum(['chatgpt-4o', 'flux', 'sdxl', 'sd-1']);
export const isModel = buildTypeGuard(zModel);
export const zSimpleGenerationState = z.object({
_version: z.literal(1).default(1),
positivePrompt: z.string().default(''),
model: zModel.default('flux'),
aspectRatio: zAspectRatio.default('1:1'),
startingImage: zStartingImage.nullable().default(null),
referenceImages: z.array(zReferenceImage).default(() => []),
controlImage: zControlImage.nullable().default(null),
});
export type SimpleGenerationState = z.infer<typeof zSimpleGenerationState>;

View File

@@ -0,0 +1,48 @@
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { ChatGPT4oAspectRatioID, Dimensions } from 'features/controlLayers/store/types';
import type { AspectRatio } from 'features/simpleGeneration/store/types';
export const getDimensions = (ratio: number, area: number): Dimensions => {
const exactWidth = Math.sqrt(area * ratio);
const exactHeight = exactWidth / ratio;
return {
width: roundToMultiple(exactWidth, 64),
height: roundToMultiple(exactHeight, 64),
};
};
const FLUX_SDXL_AREA = 1024 * 1024;
export const FLUX_SDXL_ASPECT_RATIO_MAP: Record<AspectRatio, Dimensions> = {
'16:9': getDimensions(16 / 9, FLUX_SDXL_AREA),
'3:2': getDimensions(3 / 2, FLUX_SDXL_AREA),
'4:3': getDimensions(4 / 3, FLUX_SDXL_AREA),
'1:1': getDimensions(1, FLUX_SDXL_AREA),
'3:4': getDimensions(3 / 4, FLUX_SDXL_AREA),
'2:3': getDimensions(2 / 3, FLUX_SDXL_AREA),
'9:16': getDimensions(9 / 16, FLUX_SDXL_AREA),
};
const SD_1_AREA = 768 * 768;
export const SD_1_ASPECT_RATIO_MAP: Record<AspectRatio, Dimensions> = {
'16:9': getDimensions(16 / 9, SD_1_AREA),
'3:2': getDimensions(3 / 2, SD_1_AREA),
'4:3': getDimensions(4 / 3, SD_1_AREA),
'1:1': getDimensions(1, SD_1_AREA),
'3:4': getDimensions(3 / 4, SD_1_AREA),
'2:3': getDimensions(2 / 3, SD_1_AREA),
'9:16': getDimensions(9 / 16, SD_1_AREA),
};
export const CHATGPT_4O_ASPECT_RATIO_MAP: Record<ChatGPT4oAspectRatioID, Dimensions> = {
'1:1': { width: 1024, height: 1024 },
'2:3': { width: 1024, height: 1536 },
'3:2': { width: 1536, height: 1024 },
};
export const ASPECT_RATIO_MAP = {
flux: FLUX_SDXL_ASPECT_RATIO_MAP,
sdxl: FLUX_SDXL_ASPECT_RATIO_MAP,
'chatgpt-4o': CHATGPT_4O_ASPECT_RATIO_MAP,
'sd-1': SD_1_ASPECT_RATIO_MAP,
} as const;

View File

@@ -7,6 +7,7 @@ import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import WorkflowsTabLeftPanel from 'features/nodes/components/sidePanel/WorkflowsTabLeftPanel';
import QueueControls from 'features/queue/components/QueueControls';
import { SimpleTabLeftPanel } from 'features/simpleGeneration/components/SimpleTabLeftPanel';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
@@ -165,7 +166,7 @@ const RightPanelContent = memo(() => {
if (tab === 'canvas') {
return <CanvasRightPanel />;
}
if (tab === 'upscaling' || tab === 'workflows') {
if (tab === 'simple' || tab === 'upscaling' || tab === 'workflows') {
return <GalleryPanelContent />;
}
@@ -176,6 +177,9 @@ RightPanelContent.displayName = 'RightPanelContent';
const LeftPanelContent = memo(() => {
const tab = useAppSelector(selectActiveTab);
if (tab === 'simple') {
return <SimpleTabLeftPanel />;
}
if (tab === 'canvas') {
return <ParametersPanelTextToImage />;
}
@@ -199,6 +203,9 @@ const MainPanelContent = memo(() => {
if (tab === 'upscaling') {
return <ImageViewer />;
}
if (tab === 'simple') {
return <ImageViewer />;
}
if (tab === 'workflows') {
return <WorkflowsMainPanel />;
}

View File

@@ -8,7 +8,7 @@ import { VideosModalButton } from 'features/system/components/VideosModal/Videos
import { TabMountGate } from 'features/ui/components/TabMountGate';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiCubeBold, PiFlowArrowBold, PiFrameCornersBold, PiQueueBold } from 'react-icons/pi';
import { PiBoundingBoxBold, PiBowlSteam, PiCubeBold, PiFlowArrowBold, PiFrameCornersBold, PiQueueBold } from 'react-icons/pi';
import { Notifications } from './Notifications';
import { TabButton } from './TabButton';
@@ -21,6 +21,9 @@ export const VerticalNavBar = memo(() => {
<Flex flexDir="column" alignItems="center" py={2} gap={4} minW={0}>
<InvokeAILogoComponent />
<Flex gap={4} pt={6} h="full" flexDir="column">
<TabMountGate tab="simple">
<TabButton tab="simple" icon={<PiBowlSteam />} label={t('ui.tabs.simple')} />
</TabMountGate>
<TabMountGate tab="canvas">
<TabButton tab="canvas" icon={<PiBoundingBoxBold />} label={t('ui.tabs.canvas')} />
</TabMountGate>

View File

@@ -8,7 +8,7 @@ import { atom } from 'nanostores';
import type { CanvasRightPanelTabName, TabName, UIState } from './uiTypes';
const initialUIState: UIState = {
_version: 3,
_version: 4,
activeTab: 'canvas',
activeTabCanvasRightPanel: 'gallery',
shouldShowImageDetails: false,
@@ -81,6 +81,10 @@ const migrateUIState = (state: any): any => {
state.activeTab = 'canvas';
state._version = 3;
}
if (state._version === 3) {
state.activeTab = 'simple';
state._version = 4;
}
return state;
};
@@ -91,12 +95,12 @@ export const uiPersistConfig: PersistConfig<UIState> = {
persistDenylist: ['shouldShowImageDetails'],
};
const TABS_WITH_LEFT_PANEL: TabName[] = ['canvas', 'upscaling', 'workflows'] as const;
const TABS_WITH_LEFT_PANEL: TabName[] = ['simple', 'canvas', 'upscaling', 'workflows'] as const;
export const LEFT_PANEL_MIN_SIZE_PX = 400;
export const $isLeftPanelOpen = atom(true);
export const selectWithLeftPanel = createSelector(selectUiSlice, (ui) => TABS_WITH_LEFT_PANEL.includes(ui.activeTab));
const TABS_WITH_RIGHT_PANEL: TabName[] = ['canvas', 'upscaling', 'workflows'] as const;
const TABS_WITH_RIGHT_PANEL: TabName[] = ['simple', 'canvas', 'upscaling', 'workflows'] as const;
export const RIGHT_PANEL_MIN_SIZE_PX = 390;
export const $isRightPanelOpen = atom(true);
export const selectWithRightPanel = createSelector(selectUiSlice, (ui) => TABS_WITH_RIGHT_PANEL.includes(ui.activeTab));

View File

@@ -1,11 +1,11 @@
export type TabName = 'canvas' | 'upscaling' | 'workflows' | 'models' | 'queue';
export type TabName = 'canvas' | 'upscaling' | 'workflows' | 'models' | 'queue' | 'simple';
export type CanvasRightPanelTabName = 'layers' | 'gallery';
export interface UIState {
/**
* Slice schema version.
*/
_version: 3;
_version: 4;
/**
* The currently active tab.
*/

View File

@@ -29,6 +29,7 @@ import {
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSigLipModelConfig,
isSimpleTabModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
@@ -59,6 +60,7 @@ const buildModelsHook =
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useSimpleTabModels = buildModelsHook(isSimpleTabModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);

View File

@@ -276,6 +276,15 @@ export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConf
return config.type === 'embedding';
};
export const isSimpleTabModelConfig = (
config: AnyModelConfig
): config is Extract<MainModelConfig, { base: 'chatgpt-4o' | 'flux' | 'sdxl' | 'sd-1' }> => {
return (
config.type === 'main' &&
(config.base === 'chatgpt-4o' || config.base === 'flux' || config.base === 'sdxl' || config.base === 'sd-1')
);
};
export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus'];