mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 06:08:02 -05:00
Compare commits
2 Commits
main
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac012721b0 | ||
|
|
9706df02d4 |
@@ -23,6 +23,10 @@ import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nod
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import {
|
||||
simpleGenerationPersistConfig,
|
||||
simpleGenerationSlice,
|
||||
} from 'features/simpleGeneration/store/slice';
|
||||
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||
@@ -68,6 +72,7 @@ const allReducers = {
|
||||
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
|
||||
[simpleGenerationSlice.name]: simpleGenerationSlice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@@ -113,6 +118,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
|
||||
[simpleGenerationSlice.name]: simpleGenerationPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
||||
@@ -14,10 +14,11 @@ import { useStore } from '@nanostores/react';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { AnyStore, ReadableAtom, Task, WritableAtom } from 'nanostores';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { StoreValues } from 'nanostores/computed';
|
||||
import type { ChangeEvent, MouseEventHandler, PropsWithChildren, RefObject } from 'react';
|
||||
import type { ChangeEvent, PropsWithChildren, RefObject } from 'react';
|
||||
import React, {
|
||||
createContext,
|
||||
useCallback,
|
||||
@@ -198,11 +199,17 @@ type PickerProps<T extends object> = {
|
||||
* Whether the picker should be searchable. If true, renders a search input.
|
||||
*/
|
||||
searchable?: boolean;
|
||||
/**
|
||||
* The default groups to enable. If omitted, all groups are disabled by default, meaning all groups and options
|
||||
* are visible by default.
|
||||
*/
|
||||
defaultEnabledGroups?: string[];
|
||||
};
|
||||
|
||||
export type PickerContextState<T extends object> = {
|
||||
$optionsOrGroups: WritableAtom<OptionOrGroup<T>[]>;
|
||||
$groupStatusMap: WritableAtom<GroupStatusMap>;
|
||||
defaultStatusMap: GroupStatusMap;
|
||||
$compactView: WritableAtom<boolean>;
|
||||
$activeOptionId: WritableAtom<string | undefined>;
|
||||
$filteredOptions: WritableAtom<OptionOrGroup<T>[]>;
|
||||
@@ -217,6 +224,7 @@ export type PickerContextState<T extends object> = {
|
||||
$searchTerm: WritableAtom<string>;
|
||||
searchPlaceholder?: string;
|
||||
toggleGroup: (id: string) => void;
|
||||
resetGroups: () => void;
|
||||
getOptionId: (option: T) => string;
|
||||
isMatch: (option: T, searchTerm: string) => boolean;
|
||||
getIsOptionDisabled?: (option: T) => boolean;
|
||||
@@ -229,6 +237,7 @@ export type PickerContextState<T extends object> = {
|
||||
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
|
||||
NextToSearchBar?: React.ReactNode;
|
||||
searchable?: boolean;
|
||||
defaultEnabledGroups?: string[];
|
||||
};
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
@@ -312,7 +321,7 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
|
||||
|
||||
type GroupStatusMap = Record<string, boolean>;
|
||||
|
||||
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], defaultEnabledGroups?: string[]) => {
|
||||
const groupsWithOptions = useMemo(() => {
|
||||
const ids: string[] = [];
|
||||
for (const optionOrGroup of options) {
|
||||
@@ -323,23 +332,41 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
return ids;
|
||||
}, [options]);
|
||||
|
||||
const [$groupStatusMap] = useState(atom<GroupStatusMap>({}));
|
||||
const defaultStatusMap = useMemo(() => {
|
||||
const map: GroupStatusMap = {};
|
||||
if (!defaultEnabledGroups || defaultEnabledGroups.length === 0) {
|
||||
for (const id of groupsWithOptions) {
|
||||
map[id] = false;
|
||||
}
|
||||
} else {
|
||||
for (const id of groupsWithOptions) {
|
||||
map[id] = defaultEnabledGroups.includes(id);
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}, [defaultEnabledGroups, groupsWithOptions]);
|
||||
|
||||
const [$groupStatusMap] = useState(atom<GroupStatusMap>(defaultStatusMap));
|
||||
const [$areAllGroupsDisabled] = useState(() =>
|
||||
computed($groupStatusMap, (groupStatusMap) => Object.values(groupStatusMap).every((status) => status === false))
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const groupStatusMap = $groupStatusMap.get();
|
||||
const newMap: GroupStatusMap = {};
|
||||
for (const id of groupsWithOptions) {
|
||||
if (newMap[id] === undefined) {
|
||||
newMap[id] = false;
|
||||
} else if (groupStatusMap[id] !== undefined) {
|
||||
newMap[id] = groupStatusMap[id];
|
||||
}
|
||||
}
|
||||
$groupStatusMap.set(newMap);
|
||||
}, [groupsWithOptions, $groupStatusMap]);
|
||||
$groupStatusMap.set(defaultStatusMap);
|
||||
}, [$groupStatusMap, defaultStatusMap]);
|
||||
|
||||
// useEffect(() => {
|
||||
// const groupStatusMap = $groupStatusMap.get();
|
||||
// const newMap: GroupStatusMap = {};
|
||||
// for (const id of groupsWithOptions) {
|
||||
// if (newMap[id] === undefined) {
|
||||
// newMap[id] = false;
|
||||
// } else if (groupStatusMap[id] !== undefined) {
|
||||
// newMap[id] = groupStatusMap[id];
|
||||
// }
|
||||
// }
|
||||
// $groupStatusMap.set(newMap);
|
||||
// }, [groupsWithOptions, $groupStatusMap]);
|
||||
|
||||
const toggleGroup = useCallback(
|
||||
(idToToggle: string) => {
|
||||
@@ -354,7 +381,11 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
[$groupStatusMap, groupsWithOptions]
|
||||
);
|
||||
|
||||
return { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } as const;
|
||||
const resetGroups = useCallback(() => {
|
||||
$groupStatusMap.set(defaultStatusMap);
|
||||
}, [$groupStatusMap, defaultStatusMap]);
|
||||
|
||||
return { $groupStatusMap, $areAllGroupsDisabled, toggleGroup, resetGroups, defaultStatusMap } as const;
|
||||
};
|
||||
|
||||
const useKeyboardNavigation = <T extends object>() => {
|
||||
@@ -511,10 +542,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
OptionComponent = DefaultOptionComponent,
|
||||
NextToSearchBar,
|
||||
searchable,
|
||||
defaultEnabledGroups,
|
||||
} = props;
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup, resetGroups, defaultStatusMap } = useTogglableGroups(
|
||||
optionsOrGroups,
|
||||
defaultEnabledGroups
|
||||
);
|
||||
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
|
||||
const $compactView = useAtom(true);
|
||||
const $optionsOrGroups = useAtom(optionsOrGroups);
|
||||
@@ -576,11 +611,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
NextToSearchBar,
|
||||
onClose,
|
||||
searchable,
|
||||
defaultEnabledGroups,
|
||||
$areAllGroupsDisabled,
|
||||
$selectedItemId,
|
||||
$hasOptions,
|
||||
$hasFilteredOptions,
|
||||
$filteredOptionsCount,
|
||||
resetGroups,
|
||||
defaultStatusMap,
|
||||
}) satisfies PickerContextState<T>,
|
||||
[
|
||||
$optionsOrGroups,
|
||||
@@ -604,11 +642,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
NextToSearchBar,
|
||||
onClose,
|
||||
searchable,
|
||||
defaultEnabledGroups,
|
||||
$areAllGroupsDisabled,
|
||||
$selectedItemId,
|
||||
$hasOptions,
|
||||
$hasFilteredOptions,
|
||||
$filteredOptionsCount,
|
||||
resetGroups,
|
||||
defaultStatusMap,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -807,8 +848,10 @@ const SearchInput = typedMemo(<T extends object>() => {
|
||||
});
|
||||
SearchInput.displayName = 'SearchInput';
|
||||
const GroupToggleButtons = typedMemo(<T extends object>() => {
|
||||
const { $optionsOrGroups, $groupStatusMap, $areAllGroupsDisabled } = usePickerContext<T>();
|
||||
const { $optionsOrGroups, $groupStatusMap, defaultStatusMap, resetGroups } = usePickerContext<T>();
|
||||
const { t } = useTranslation();
|
||||
const groupStatusMap = useStore($groupStatusMap);
|
||||
const isResetDisabled = useMemo(() => isEqual(defaultStatusMap, groupStatusMap), [defaultStatusMap, groupStatusMap]);
|
||||
const $groups = useComputed([$optionsOrGroups], (optionsOrGroups) => {
|
||||
const _groups: Group<T>[] = [];
|
||||
for (const optionOrGroup of optionsOrGroups) {
|
||||
@@ -819,15 +862,6 @@ const GroupToggleButtons = typedMemo(<T extends object>() => {
|
||||
return _groups;
|
||||
});
|
||||
const groups = useStore($groups);
|
||||
const areAllGroupsDisabled = useStore($areAllGroupsDisabled);
|
||||
|
||||
const onClick = useCallback<MouseEventHandler>(() => {
|
||||
const newMap: GroupStatusMap = {};
|
||||
for (const { id } of groups) {
|
||||
newMap[id] = false;
|
||||
}
|
||||
$groupStatusMap.set(newMap);
|
||||
}, [$groupStatusMap, groups]);
|
||||
|
||||
if (!groups.length) {
|
||||
return null;
|
||||
@@ -846,11 +880,11 @@ const GroupToggleButtons = typedMemo(<T extends object>() => {
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClick}
|
||||
onClick={resetGroups}
|
||||
// When a focused element is disabled, it blurs. This closes the popover. Fake the disabled state to prevent this.
|
||||
// See: https://github.com/chakra-ui/chakra-ui/issues/7965
|
||||
opacity={areAllGroupsDisabled ? 0.5 : undefined}
|
||||
pointerEvents={areAllGroupsDisabled ? 'none' : undefined}
|
||||
opacity={isResetDisabled ? 0.5 : undefined}
|
||||
pointerEvents={isResetDisabled ? 'none' : undefined}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -23,7 +23,7 @@ const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async
|
||||
}
|
||||
});
|
||||
|
||||
const zImageWithDims = z
|
||||
export const zImageWithDims = z
|
||||
.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
@@ -410,6 +410,7 @@ export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3
|
||||
zImagen3AspectRatioID.safeParse(v).success;
|
||||
|
||||
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
|
||||
export type ChatGPT4oAspectRatioID = z.infer<typeof zChatGPT4oAspectRatioID>;
|
||||
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
|
||||
zChatGPT4oAspectRatioID.safeParse(v).success;
|
||||
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { selectSimpleGenerationSlice } from 'features/simpleGeneration/store/slice';
|
||||
import { ASPECT_RATIO_MAP } from 'features/simpleGeneration/util/aspectRatioToDimensions';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const EMPTY_ENTITY_STATE = {
|
||||
ids: [],
|
||||
entities: {},
|
||||
};
|
||||
|
||||
export const getFLUXModels = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
|
||||
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
|
||||
|
||||
/**
|
||||
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
|
||||
*
|
||||
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
|
||||
* HTTP query response.
|
||||
*
|
||||
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
|
||||
* configs we get from the query use a single colon `:`.
|
||||
*
|
||||
* For example:
|
||||
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
|
||||
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
|
||||
*
|
||||
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
|
||||
*/
|
||||
|
||||
const flux = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 'main' &&
|
||||
base === 'flux' &&
|
||||
(source === 'InvokeAI/flux_dev:transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors' ||
|
||||
source === 'InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors')
|
||||
);
|
||||
|
||||
const t5Encoder = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 't5_encoder' &&
|
||||
base === 'any' &&
|
||||
(source === 'InvokeAI/t5-v1_1-xxl:bnb_llm_int8' || source === 'InvokeAI/t5-v1_1-xxl::bnb_llm_int8')
|
||||
);
|
||||
|
||||
const clipEmbed = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 'clip_embed' &&
|
||||
base === 'any' &&
|
||||
(source === 'InvokeAI/clip-vit-large-patch14-text-encoder:bfloat16' ||
|
||||
source === 'InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16')
|
||||
);
|
||||
|
||||
const clipVision = allModelConfigs.find(
|
||||
({ type, base, source }) => type === 'clip_vision' && base === 'any' && source === 'InvokeAI/clip-vit-large-patch14'
|
||||
);
|
||||
|
||||
const vae = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 'vae' &&
|
||||
base === 'flux' &&
|
||||
(source === 'black-forest-labs/FLUX.1-schnell:ae.safetensors' ||
|
||||
source === 'black-forest-labs/FLUX.1-schnell::ae.safetensors')
|
||||
);
|
||||
|
||||
const ipAdapter = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 'ip_adapter' &&
|
||||
base === 'flux' &&
|
||||
source === 'https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors'
|
||||
);
|
||||
|
||||
return {
|
||||
flux,
|
||||
t5Encoder,
|
||||
clipEmbed,
|
||||
clipVision,
|
||||
vae,
|
||||
ipAdapter,
|
||||
};
|
||||
};
|
||||
|
||||
export const getSD1Models = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
|
||||
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
|
||||
|
||||
/**
|
||||
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
|
||||
*
|
||||
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
|
||||
* HTTP query response.
|
||||
*
|
||||
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
|
||||
* configs we get from the query use a single colon `:`.
|
||||
*
|
||||
* For example:
|
||||
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
|
||||
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
|
||||
*
|
||||
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
|
||||
*/
|
||||
|
||||
const main = allModelConfigs.find(
|
||||
({ type, base, source }) =>
|
||||
type === 'main' &&
|
||||
base === 'sd-1' &&
|
||||
source === 'https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors'
|
||||
);
|
||||
|
||||
return {
|
||||
main,
|
||||
};
|
||||
};
|
||||
|
||||
export const getSDXLModels = (modelConfigs: ReturnType<typeof selectModelConfigsQuery>) => {
|
||||
const allModelConfigs = modelConfigsAdapterSelectors.selectAll(modelConfigs?.data ?? EMPTY_ENTITY_STATE);
|
||||
|
||||
/**
|
||||
* The sources are taken from `invokeai/backend/model_manager/starter_models.py`
|
||||
*
|
||||
* Note: The sources for HF Repo models are subtly in the python file vs the actual model config we get from the
|
||||
* HTTP query response.
|
||||
*
|
||||
* In python, a double colon `::` separates the HF Repo ID from the folder path within the repo. But the model
|
||||
* configs we get from the query use a single colon `:`.
|
||||
*
|
||||
* For example:
|
||||
* Python: InvokeAI/t5-v1_1-xxl::bnb_llm_int8
|
||||
* Query : InvokeAI/t5-v1_1-xxl:bnb_llm_int8
|
||||
*
|
||||
* I'm not sure if it's always been like this, but to be safe, we check for both formats below.
|
||||
*/
|
||||
|
||||
const main = allModelConfigs.find(
|
||||
({ type, base, source }) => type === 'main' && base === 'sdxl' && source === 'RunDiffusion/Juggernaut-XL-v9'
|
||||
);
|
||||
|
||||
const vae = allModelConfigs.find(
|
||||
({ type, base, source }) => type === 'vae' && base === 'sdxl' && source === 'madebyollin/sdxl-vae-fp16-fix'
|
||||
);
|
||||
|
||||
return {
|
||||
main,
|
||||
vae,
|
||||
};
|
||||
};
|
||||
|
||||
const buildSimpleSD1Graph = (state: RootState) => {
|
||||
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
|
||||
|
||||
const { main } = getSD1Models(selectModelConfigsQuery(state));
|
||||
const g = new Graph(getPrefixedId('simple_sd1'));
|
||||
|
||||
const dimensions = ASPECT_RATIO_MAP['sd-1'][aspectRatio];
|
||||
|
||||
const modelLoader = g.addNode({
|
||||
type: 'main_model_loader',
|
||||
id: getPrefixedId('main_model_loader'),
|
||||
model: zModelIdentifierField.parse(main),
|
||||
});
|
||||
const posCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('compel_prompt_pos'),
|
||||
prompt: positivePrompt,
|
||||
});
|
||||
const negCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('compel_prompt_neg'),
|
||||
prompt: '',
|
||||
});
|
||||
const seed = g.addNode({
|
||||
type: 'rand_int',
|
||||
id: getPrefixedId('rand_int'),
|
||||
low: NUMPY_RAND_MIN,
|
||||
high: NUMPY_RAND_MAX,
|
||||
use_cache: false,
|
||||
});
|
||||
const noise = g.addNode({
|
||||
type: 'noise',
|
||||
id: getPrefixedId('noise'),
|
||||
width: dimensions.width,
|
||||
height: dimensions.height,
|
||||
});
|
||||
const denoise = g.addNode({
|
||||
type: 'denoise_latents',
|
||||
id: getPrefixedId('denoise_latents'),
|
||||
steps: 30,
|
||||
});
|
||||
const l2i = g.addNode({
|
||||
type: 'l2i',
|
||||
id: getPrefixedId('l2i'),
|
||||
is_intermediate: false,
|
||||
board: getBoardField(state),
|
||||
});
|
||||
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||
|
||||
g.addEdge(seed, 'value', noise, 'seed');
|
||||
|
||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
|
||||
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
|
||||
|
||||
g.addEdge(modelLoader, 'vae', l2i, 'vae');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
return g;
|
||||
};
|
||||
|
||||
const buildSimpleSDXLGraph = (state: RootState) => {
|
||||
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
|
||||
|
||||
const { main, vae } = getSDXLModels(selectModelConfigsQuery(state));
|
||||
const g = new Graph(getPrefixedId('simple_sdxl'));
|
||||
|
||||
const dimensions = ASPECT_RATIO_MAP['sdxl'][aspectRatio];
|
||||
|
||||
const modelLoader = g.addNode({
|
||||
type: 'sdxl_model_loader',
|
||||
id: getPrefixedId('sdxl_model_loader'),
|
||||
model: zModelIdentifierField.parse(main),
|
||||
});
|
||||
const vaeLoader = g.addNode({
|
||||
type: 'vae_loader',
|
||||
id: getPrefixedId('vae_loader'),
|
||||
vae_model: zModelIdentifierField.parse(vae),
|
||||
});
|
||||
const posCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('sdxl_compel_prompt_pos'),
|
||||
prompt: positivePrompt,
|
||||
});
|
||||
const negCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('sdxl_compel_prompt_neg'),
|
||||
prompt: '',
|
||||
});
|
||||
const seed = g.addNode({
|
||||
type: 'rand_int',
|
||||
id: getPrefixedId('rand_int'),
|
||||
low: NUMPY_RAND_MIN,
|
||||
high: NUMPY_RAND_MAX,
|
||||
use_cache: false,
|
||||
});
|
||||
const noise = g.addNode({
|
||||
type: 'noise',
|
||||
id: getPrefixedId('noise'),
|
||||
width: dimensions.width,
|
||||
height: dimensions.height,
|
||||
});
|
||||
const denoise = g.addNode({
|
||||
type: 'denoise_latents',
|
||||
id: getPrefixedId('denoise_latents'),
|
||||
steps: 30,
|
||||
});
|
||||
const l2i = g.addNode({
|
||||
type: 'l2i',
|
||||
id: getPrefixedId('l2i'),
|
||||
is_intermediate: false,
|
||||
board: getBoardField(state),
|
||||
});
|
||||
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
|
||||
|
||||
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
|
||||
|
||||
g.addEdge(seed, 'value', noise, 'seed');
|
||||
|
||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
|
||||
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
|
||||
|
||||
g.addEdge(vaeLoader, 'vae', l2i, 'vae');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
return g;
|
||||
};
|
||||
|
||||
const buildSimpleFLUXGraph = (state: RootState) => {
|
||||
const { positivePrompt, aspectRatio } = selectSimpleGenerationSlice(state);
|
||||
|
||||
const { flux, t5Encoder, clipEmbed, vae } = getFLUXModels(selectModelConfigsQuery(state));
|
||||
const g = new Graph(getPrefixedId('simple_flux'));
|
||||
|
||||
const dimensions = ASPECT_RATIO_MAP['flux'][aspectRatio];
|
||||
|
||||
const modelLoader = g.addNode({
|
||||
type: 'flux_model_loader',
|
||||
id: getPrefixedId('flux_model_loader'),
|
||||
model: zModelIdentifierField.parse(flux),
|
||||
t5_encoder_model: zModelIdentifierField.parse(t5Encoder),
|
||||
clip_embed_model: zModelIdentifierField.parse(clipEmbed),
|
||||
vae_model: zModelIdentifierField.parse(vae),
|
||||
});
|
||||
const posCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('flux_text_encoder'),
|
||||
prompt: positivePrompt,
|
||||
});
|
||||
const seed = g.addNode({
|
||||
type: 'rand_int',
|
||||
id: getPrefixedId('rand_int'),
|
||||
low: NUMPY_RAND_MIN,
|
||||
high: NUMPY_RAND_MAX,
|
||||
use_cache: false,
|
||||
});
|
||||
const denoise = g.addNode({
|
||||
type: 'flux_denoise',
|
||||
id: getPrefixedId('flux_denoise'),
|
||||
guidance: 4.0,
|
||||
num_steps: 30,
|
||||
denoising_start: 0,
|
||||
denoising_end: 1,
|
||||
width: dimensions.width,
|
||||
height: dimensions.height,
|
||||
});
|
||||
const l2i = g.addNode({
|
||||
type: 'flux_vae_decode',
|
||||
id: getPrefixedId('flux_vae_decode'),
|
||||
is_intermediate: false,
|
||||
board: getBoardField(state),
|
||||
});
|
||||
|
||||
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
|
||||
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
|
||||
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
|
||||
g.addEdge(modelLoader, 'vae', denoise, 'controlnet_vae');
|
||||
g.addEdge(seed, 'value', denoise, 'seed');
|
||||
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
|
||||
|
||||
g.addEdge(modelLoader, 'vae', l2i, 'vae');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
return g;
|
||||
};
|
||||
|
||||
export const buildSimpleGraph = (state: RootState): Graph => {
|
||||
const { model } = selectSimpleGenerationSlice(state);
|
||||
|
||||
if (model === 'flux') {
|
||||
return buildSimpleFLUXGraph(state);
|
||||
}
|
||||
|
||||
if (model === 'sdxl') {
|
||||
return buildSimpleSDXLGraph(state);
|
||||
}
|
||||
|
||||
if (model === 'sd-1') {
|
||||
return buildSimpleSD1Graph(state);
|
||||
}
|
||||
};
|
||||
@@ -106,6 +106,7 @@ export const ModelPicker = typedMemo(
|
||||
isDisabled,
|
||||
isInvalid,
|
||||
className,
|
||||
defaultEnabledGroups,
|
||||
}: {
|
||||
modelConfigs: T[];
|
||||
selectedModelConfig: T | undefined;
|
||||
@@ -117,6 +118,7 @@ export const ModelPicker = typedMemo(
|
||||
isDisabled?: boolean;
|
||||
isInvalid?: boolean;
|
||||
className?: string;
|
||||
defaultEnabledGroups?: string[];
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<T[] | Group<T>[]>(() => {
|
||||
@@ -223,6 +225,7 @@ export const ModelPicker = typedMemo(
|
||||
noMatchesFallback={t('modelManager.noMatchingModels')}
|
||||
NextToSearchBar={<NavigateToModelManagerButton />}
|
||||
getIsOptionDisabled={getIsOptionDisabled}
|
||||
defaultEnabledGroups={defaultEnabledGroups}
|
||||
searchable
|
||||
/>
|
||||
</PopoverBody>
|
||||
|
||||
@@ -36,6 +36,10 @@ export const InvokeButtonTooltip = ({ prepend, children, ...rest }: PropsWithChi
|
||||
const TooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
|
||||
const activeTab = useAppSelector(selectActiveTab);
|
||||
|
||||
if (activeTab === 'simple') {
|
||||
return <SimpleTabTooltipContent prepend={prepend} />;
|
||||
}
|
||||
|
||||
if (activeTab === 'canvas') {
|
||||
return <CanvasTabTooltipContent prepend={prepend} />;
|
||||
}
|
||||
@@ -52,6 +56,27 @@ const TooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
|
||||
});
|
||||
TooltipContent.displayName = 'TooltipContent';
|
||||
|
||||
const SimpleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
|
||||
const isReady = useStore($isReadyToEnqueue);
|
||||
const reasons = useStore($reasonsWhyCannotEnqueue);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<IsReadyText isReady={isReady} prepend={prepend} />
|
||||
{/* <QueueCountPredictionSimpleOrUpscaleTab /> */}
|
||||
{reasons.length > 0 && (
|
||||
<>
|
||||
<StyledDivider />
|
||||
<ReasonsList reasons={reasons} />
|
||||
</>
|
||||
)}
|
||||
<StyledDivider />
|
||||
<AddingToText />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
SimpleTabTooltipContent.displayName = 'SimpleTabTooltipContent';
|
||||
|
||||
const CanvasTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean }) => {
|
||||
const isReady = useStore($isReadyToEnqueue);
|
||||
const reasons = useStore($reasonsWhyCannotEnqueue);
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { buildSimpleGraph } from 'features/nodes/util/graph/buildSimpleGraph';
|
||||
import { useCallback } from 'react';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { EnqueueBatchArg } from 'services/api/types';
|
||||
|
||||
const enqueueRequestedSimple = createAction('app/enqueueRequestedSimple');
|
||||
|
||||
export const useEnqueueSimple = () => {
|
||||
const { getState, dispatch } = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
async (prepend: boolean) => {
|
||||
dispatch(enqueueRequestedSimple());
|
||||
const state = getState();
|
||||
const g = buildSimpleGraph(state);
|
||||
|
||||
const batchConfig: EnqueueBatchArg = {
|
||||
batch: {
|
||||
graph: g.getGraph(),
|
||||
runs: state.params.iterations,
|
||||
origin: 'simple',
|
||||
destination: 'gallery',
|
||||
},
|
||||
prepend,
|
||||
};
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
return { batchConfig, enqueueResult };
|
||||
},
|
||||
[dispatch, getState]
|
||||
);
|
||||
|
||||
return enqueue;
|
||||
};
|
||||
@@ -6,6 +6,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useEnqueueSimple } from 'features/queue/hooks/useEnqueueSimple';
|
||||
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
|
||||
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
@@ -21,6 +22,7 @@ export const useInvoke = () => {
|
||||
const isReady = useStore($isReadyToEnqueue);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const enqueueWorkflows = useEnqueueWorkflows();
|
||||
const enqueueSimple = useEnqueueSimple();
|
||||
|
||||
const [_, { isLoading }] = useEnqueueBatchMutation(enqueueMutationFixedCacheKeyOptions);
|
||||
|
||||
@@ -30,6 +32,15 @@ export const useInvoke = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (tabName === 'simple') {
|
||||
const result = await withResultAsync(() => enqueueSimple(prepend));
|
||||
if (result.isErr()) {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
} else {
|
||||
log.debug(parseify(result.value), 'Enqueued batch');
|
||||
}
|
||||
}
|
||||
|
||||
if (tabName === 'workflows') {
|
||||
const result = await withResultAsync(() => enqueueWorkflows(prepend, isApiValidationRun));
|
||||
if (result.isErr()) {
|
||||
@@ -51,7 +62,7 @@ export const useInvoke = () => {
|
||||
|
||||
// Else we are not on a generation tab and should not queue
|
||||
},
|
||||
[dispatch, enqueueWorkflows, isReady, tabName]
|
||||
[dispatch, enqueueSimple, enqueueWorkflows, isReady, tabName]
|
||||
);
|
||||
|
||||
const enqueueBack = useCallback(() => {
|
||||
|
||||
@@ -30,10 +30,13 @@ import { getInvocationNodeErrors } from 'features/nodes/store/util/fieldValidato
|
||||
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getFLUXModels, getSD1Models, getSDXLModels } from 'features/nodes/util/graph/buildSimpleGraph';
|
||||
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { getGridSize } from 'features/parameters/util/optimalDimension';
|
||||
import { selectSimpleGenerationSlice } from 'features/simpleGeneration/store/slice';
|
||||
import type { SimpleGenerationState } from 'features/simpleGeneration/store/types';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
@@ -42,7 +45,7 @@ import i18n from 'i18next';
|
||||
import { debounce, groupBy, upperFirst } from 'lodash-es';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { useEffect } from 'react';
|
||||
import { selectMainModelConfig } from 'services/api/endpoints/models';
|
||||
import { selectMainModelConfig, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
@@ -89,9 +92,14 @@ const debouncedUpdateReasons = debounce(
|
||||
config: AppConfig,
|
||||
store: AppStore,
|
||||
isInPublishFlow: boolean,
|
||||
areChatGPT4oModelsEnabled: boolean
|
||||
areChatGPT4oModelsEnabled: boolean,
|
||||
modelConfigsQuery: ReturnType<typeof selectModelConfigsQuery>,
|
||||
simple: SimpleGenerationState
|
||||
) => {
|
||||
if (tab === 'canvas') {
|
||||
if (tab === 'simple') {
|
||||
const reasons = getReasonsWhyCannotEnqueueSimpleTab({ isConnected, simple, modelConfigsQuery });
|
||||
$reasonsWhyCannotEnqueue.set(reasons);
|
||||
} else if (tab === 'canvas') {
|
||||
const model = selectMainModelConfig(store.getState());
|
||||
const reasons = await getReasonsWhyCannotEnqueueCanvasTab({
|
||||
isConnected,
|
||||
@@ -143,6 +151,7 @@ export const useReadinessWatcher = () => {
|
||||
const nodes = useAppSelector(selectNodesSlice);
|
||||
const workflowSettings = useAppSelector(selectWorkflowSettingsSlice);
|
||||
const upscale = useAppSelector(selectUpscaleSlice);
|
||||
const simple = useAppSelector(selectSimpleGenerationSlice);
|
||||
const config = useAppSelector(selectConfigSlice);
|
||||
const templates = useStore($templates);
|
||||
const isConnected = useStore($isConnected);
|
||||
@@ -152,6 +161,7 @@ export const useReadinessWatcher = () => {
|
||||
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
|
||||
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const modelConfigsQuery = useAppSelector(selectModelConfigsQuery);
|
||||
const areChatGPT4oModelsEnabled = useFeatureStatus('chatGPT4oModels');
|
||||
|
||||
useEffect(() => {
|
||||
@@ -173,7 +183,9 @@ export const useReadinessWatcher = () => {
|
||||
config,
|
||||
store,
|
||||
isInPublishFlow,
|
||||
areChatGPT4oModelsEnabled
|
||||
areChatGPT4oModelsEnabled,
|
||||
modelConfigsQuery,
|
||||
simple
|
||||
);
|
||||
}, [
|
||||
store,
|
||||
@@ -194,6 +206,8 @@ export const useReadinessWatcher = () => {
|
||||
workflowSettings,
|
||||
isInPublishFlow,
|
||||
areChatGPT4oModelsEnabled,
|
||||
modelConfigsQuery,
|
||||
simple,
|
||||
]);
|
||||
};
|
||||
|
||||
@@ -330,6 +344,61 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
|
||||
return reasons;
|
||||
};
|
||||
|
||||
const getReasonsWhyCannotEnqueueSimpleTab = (arg: {
|
||||
isConnected: boolean;
|
||||
simple: SimpleGenerationState;
|
||||
modelConfigsQuery: ReturnType<typeof selectModelConfigsQuery>;
|
||||
}) => {
|
||||
const { isConnected, simple, modelConfigsQuery } = arg;
|
||||
const { model } = simple;
|
||||
const reasons: Reason[] = [];
|
||||
|
||||
if (!isConnected) {
|
||||
reasons.push(disconnectedReason(i18n.t));
|
||||
}
|
||||
|
||||
if (model === 'flux') {
|
||||
const models = getFLUXModels(modelConfigsQuery);
|
||||
if (!models.flux) {
|
||||
reasons.push({ content: 'FLUX is not installed' });
|
||||
}
|
||||
if (!models.t5Encoder) {
|
||||
reasons.push({ content: 'T5 Encoder is not installed' });
|
||||
}
|
||||
if (!models.clipEmbed) {
|
||||
reasons.push({ content: 'CLIP Embed is not installed' });
|
||||
}
|
||||
if (!models.clipVision) {
|
||||
reasons.push({ content: 'CLIP Vision is not installed' });
|
||||
}
|
||||
if (!models.ipAdapter) {
|
||||
reasons.push({ content: 'IP Adapter is not installed' });
|
||||
}
|
||||
if (!models.vae) {
|
||||
reasons.push({ content: 'VAE is not installed' });
|
||||
}
|
||||
}
|
||||
|
||||
if (model === 'sdxl') {
|
||||
const models = getSDXLModels(modelConfigsQuery);
|
||||
if (!models.main) {
|
||||
reasons.push({ content: 'Main SDXL model is not installed' });
|
||||
}
|
||||
if (!models.vae) {
|
||||
reasons.push({ content: 'VAE is not installed' });
|
||||
}
|
||||
}
|
||||
|
||||
if (model === 'sd-1') {
|
||||
const models = getSD1Models(modelConfigsQuery);
|
||||
if (!models.main) {
|
||||
reasons.push({ content: 'Main SDXL model is not installed' });
|
||||
}
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
|
||||
const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
isConnected: boolean;
|
||||
model: MainModelConfig | null | undefined;
|
||||
|
||||
@@ -12,7 +12,6 @@ import {
|
||||
selectIsSD3,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
@@ -20,6 +19,7 @@ import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import { DisabledModelWarning } from 'features/parameters/components/MainModel/DisabledModelWarning';
|
||||
import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale';
|
||||
import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler';
|
||||
import { LoRAModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/LoRAModelPicker';
|
||||
import { MainModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
@@ -86,7 +86,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<Flex gap={4} flexDir="column" pb={isApiModel ? 4 : 0}>
|
||||
<DisabledModelWarning />
|
||||
<MainModelPicker />
|
||||
{!isApiModel && <LoRASelect />}
|
||||
{!isApiModel && <LoRAModelPicker />}
|
||||
{!isApiModel && <LoRAList />}
|
||||
</Flex>
|
||||
{!isApiModel && (
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import { Flex, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { ModelPicker } from 'features/parameters/components/ModelPicker';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
|
||||
|
||||
export const LoRAModelPicker = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
|
||||
const getIsOptionDisabled = useCallback(
|
||||
(model: LoRAModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key));
|
||||
return !isCompatible || isAdded;
|
||||
},
|
||||
[addedLoRAs, currentBaseModel]
|
||||
);
|
||||
|
||||
const loraFilter = useCallback(
|
||||
(loraConfig: LoRAModelConfig) => {
|
||||
if (!currentBaseModel) {
|
||||
return true;
|
||||
}
|
||||
return currentBaseModel === loraConfig.base;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const [modelConfigs] = useLoRAModels();
|
||||
const onChange = useCallback(
|
||||
(loraConfig: LoRAModelConfig) => {
|
||||
dispatch(loraAdded({ model: loraConfig }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<InformationalPopover feature="lora">
|
||||
<FormLabel>{t('models.concepts')} </FormLabel>
|
||||
</InformationalPopover>
|
||||
<ModelPicker
|
||||
modelConfigs={modelConfigs}
|
||||
selectedModelConfig={undefined}
|
||||
onChange={onChange}
|
||||
placeholder={t('models.addLora')}
|
||||
getIsOptionDisabled={getIsOptionDisabled}
|
||||
defaultEnabledGroups={currentBaseModel ? [currentBaseModel] : undefined}
|
||||
allowEmpty
|
||||
grouped
|
||||
/>
|
||||
<UseDefaultSettingsButton />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
LoRAModelPicker.displayName = 'LoRAModelPicker';
|
||||
@@ -0,0 +1,84 @@
|
||||
import { Flex, FormLabel, Icon, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { modelChanged, selectModel } from 'features/simpleGeneration/store/slice';
|
||||
import { isModel } from 'features/simpleGeneration/store/types';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MdMoneyOff } from 'react-icons/md';
|
||||
|
||||
export const SimpleTabModel = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const model = useAppSelector(selectModel);
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
if (!isModel(e.target.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(modelChanged({ model: e.target.value }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<InformationalPopover feature="paramModel">
|
||||
<FormLabel m={0}>{t('modelManager.model')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
{model === 'flux' && (
|
||||
<InformationalPopover feature="fluxDevLicense" hideDisable={true}>
|
||||
<Flex justifyContent="flex-start">
|
||||
<Icon as={MdMoneyOff} />
|
||||
</Flex>
|
||||
</InformationalPopover>
|
||||
)}
|
||||
<Select value={model} onChange={onChange}>
|
||||
<option value="chatgpt-4o">ChatGPT 4o</option>
|
||||
<option value="flux">FLUX</option>
|
||||
<option value="sdxl">SDXL</option>
|
||||
<option value="sd-1">SD 1.5</option>
|
||||
</Select>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
SimpleTabModel.displayName = 'SimpleTabModel';
|
||||
|
||||
// export const SimpleTabModel = memo(() => {
|
||||
// const { t } = useTranslation();
|
||||
// const dispatch = useAppDispatch();
|
||||
// const [modelConfigs] = useSimpleTabModels();
|
||||
// const selectedModelConfig = useSimpleTabModelConfig();
|
||||
// const onChange = useCallback(
|
||||
// (modelConfig: AnyModelConfig) => {
|
||||
// dispatch(modelChanged({ model: zModelIdentifierField.parse(modelConfig) }));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
// const isFluxDevSelected = useMemo(
|
||||
// () =>
|
||||
// selectedModelConfig &&
|
||||
// isCheckpointMainModelConfig(selectedModelConfig) &&
|
||||
// selectedModelConfig.config_path === 'flux-dev',
|
||||
// [selectedModelConfig]
|
||||
// );
|
||||
|
||||
// return (
|
||||
// <Flex alignItems="center" gap={2}>
|
||||
// <InformationalPopover feature="paramModel">
|
||||
// <FormLabel>{t('modelManager.model')}</FormLabel>
|
||||
// </InformationalPopover>
|
||||
// {isFluxDevSelected && (
|
||||
// <InformationalPopover feature="fluxDevLicense" hideDisable={true}>
|
||||
// <Flex justifyContent="flex-start">
|
||||
// <Icon as={MdMoneyOff} />
|
||||
// </Flex>
|
||||
// </InformationalPopover>
|
||||
// )}
|
||||
// <ModelPicker modelConfigs={modelConfigs} selectedModelConfig={selectedModelConfig} onChange={onChange} grouped />
|
||||
// </Flex>
|
||||
// );
|
||||
// });
|
||||
// SimpleTabModel.displayName = 'SimpleTabModel';
|
||||
@@ -0,0 +1,52 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { zChatGPT4oAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { aspectRatioChanged, selectAspectRatio, selectModel } from 'features/simpleGeneration/store/slice';
|
||||
import { isAspectRatio, zAspectRatio } from 'features/simpleGeneration/store/types';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const SimpleTabAspectRatio = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const id = useAppSelector(selectAspectRatio);
|
||||
const model = useAppSelector(selectModel);
|
||||
|
||||
const options = useMemo(() => {
|
||||
// ChatGPT4o has different aspect ratio options
|
||||
if (model === 'chatgpt-4o') {
|
||||
return zChatGPT4oAspectRatioID.options;
|
||||
}
|
||||
// All other models
|
||||
return zAspectRatio.options;
|
||||
}, [model]);
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
if (!isAspectRatio(e.target.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(aspectRatioChanged({ aspectRatio: e.target.value }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="paramAspect">
|
||||
<FormLabel m={0}>{t('parameters.aspect')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Select value={id} onChange={onChange}>
|
||||
{options.map((ratio) => (
|
||||
<option key={ratio} value={ratio}>
|
||||
{ratio}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
SimpleTabAspectRatio.displayName = 'SimpleTabAspectRatio';
|
||||
@@ -0,0 +1,17 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { SimpleTabAspectRatio } from 'features/simpleGeneration/components/SimpleTabAspectRatio';
|
||||
import { SimpleTabModel } from 'features/simpleGeneration/components/SImpleTabModel';
|
||||
import { SimpleTabPositivePrompt } from 'features/simpleGeneration/components/SimpleTabPositivePrompt';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const SimpleTabLeftPanel = memo(() => {
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="column" gap={2}>
|
||||
<SimpleTabPositivePrompt />
|
||||
<SimpleTabModel />
|
||||
<SimpleTabAspectRatio />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
SimpleTabLeftPanel.displayName = 'SimpleTabLeftPanel';
|
||||
@@ -0,0 +1,50 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
|
||||
import { positivePromptChanged, selectPositivePrompt } from 'features/simpleGeneration/store/slice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const SimpleTabPositivePrompt = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const prompt = useAppSelector(selectPositivePrompt);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLTextAreaElement>>(
|
||||
(e) => {
|
||||
dispatch(positivePromptChanged({ positivePrompt: e.target.value }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'focusPrompt',
|
||||
category: 'app',
|
||||
callback: focus,
|
||||
options: { preventDefault: true, enableOnFormTags: ['INPUT', 'SELECT', 'TEXTAREA'] },
|
||||
dependencies: [focus],
|
||||
});
|
||||
|
||||
return (
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
value={prompt}
|
||||
onChange={onChange}
|
||||
minH={40}
|
||||
variant="darkFilled"
|
||||
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
|
||||
paddingInlineEnd={10}
|
||||
paddingInlineStart={3}
|
||||
paddingTop={0}
|
||||
paddingBottom={3}
|
||||
/>
|
||||
<PromptLabel label={t('parameters.positivePromptPlaceholder')} />
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
SimpleTabPositivePrompt.displayName = 'SimpleTabPositivePrompt';
|
||||
@@ -0,0 +1,11 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectModelKey } from 'features/simpleGeneration/store/slice';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useSimpleTabModelConfig = () => {
|
||||
const key = useAppSelector(selectModelKey);
|
||||
const { data: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
|
||||
|
||||
return modelConfig;
|
||||
};
|
||||
@@ -0,0 +1,105 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { SimpleGenerationState } from 'features/simpleGeneration/store/types';
|
||||
import { zSimpleGenerationState } from 'features/simpleGeneration/store/types';
|
||||
|
||||
const getInitialState = (): SimpleGenerationState => zSimpleGenerationState.parse({});
|
||||
|
||||
export const simpleGenerationSlice = createSlice({
|
||||
name: 'simpleGeneration',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
positivePromptChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
positivePrompt: SimpleGenerationState['positivePrompt'];
|
||||
}>
|
||||
) => {
|
||||
const { positivePrompt } = action.payload;
|
||||
state.positivePrompt = positivePrompt;
|
||||
},
|
||||
modelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
model: SimpleGenerationState['model'];
|
||||
}>
|
||||
) => {
|
||||
const { model } = action.payload;
|
||||
state.model = model;
|
||||
},
|
||||
aspectRatioChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
aspectRatio: SimpleGenerationState['aspectRatio'];
|
||||
}>
|
||||
) => {
|
||||
const { aspectRatio } = action.payload;
|
||||
state.aspectRatio = aspectRatio;
|
||||
},
|
||||
startingImageChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
startingImage: SimpleGenerationState['startingImage'];
|
||||
}>
|
||||
) => {
|
||||
const { startingImage } = action.payload;
|
||||
state.startingImage = startingImage;
|
||||
},
|
||||
referenceImageChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
index: number;
|
||||
referenceImage: SimpleGenerationState['referenceImages'][number];
|
||||
}>
|
||||
) => {
|
||||
const { index, referenceImage } = action.payload;
|
||||
state.referenceImages[index] = referenceImage;
|
||||
},
|
||||
controlImageChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlImage: SimpleGenerationState['controlImage'];
|
||||
}>
|
||||
) => {
|
||||
const { controlImage } = action.payload;
|
||||
state.controlImage = controlImage;
|
||||
},
|
||||
reset: () => getInitialState(),
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
aspectRatioChanged,
|
||||
controlImageChanged,
|
||||
modelChanged,
|
||||
positivePromptChanged,
|
||||
referenceImageChanged,
|
||||
startingImageChanged,
|
||||
reset,
|
||||
} = simpleGenerationSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateSimpleGenerationState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const simpleGenerationPersistConfig: PersistConfig<SimpleGenerationState> = {
|
||||
name: simpleGenerationSlice.name,
|
||||
initialState: getInitialState(),
|
||||
migrate: migrateSimpleGenerationState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
export const selectSimpleGenerationSlice = (state: RootState) => state.simpleGeneration;
|
||||
const createSliceSelector = <T>(selector: Selector<SimpleGenerationState, T>) =>
|
||||
createSelector(selectSimpleGenerationSlice, selector);
|
||||
|
||||
export const selectPositivePrompt = createSliceSelector((slice) => slice.positivePrompt);
|
||||
export const selectModel = createSliceSelector((slice) => slice.model);
|
||||
// export const selectModelBase = createSliceSelector((slice) => slice.model?.base);
|
||||
// export const selectModelKey = createSliceSelector((slice) => slice.model?.key);
|
||||
export const selectAspectRatio = createSliceSelector((slice) => slice.aspectRatio);
|
||||
@@ -0,0 +1,54 @@
|
||||
import { zImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
|
||||
import { nanoid } from 'nanoid';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zLowMedHigh = z.enum(['low', 'med', 'high']);
|
||||
const zControlType = z.enum(['line', 'depth']);
|
||||
export const zAspectRatio = z.enum(['16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
export type AspectRatio = z.infer<typeof zAspectRatio>;
|
||||
export const isAspectRatio = (val: unknown): val is AspectRatio => zAspectRatio.safeParse(val).success;
|
||||
|
||||
const STARTING_IMAGE_TYPE = 'starting_image';
|
||||
const zStartingImage = z.object({
|
||||
type: z.literal(STARTING_IMAGE_TYPE).default(STARTING_IMAGE_TYPE),
|
||||
id: z.string().default(nanoid),
|
||||
image: zImageWithDims.nullable().default(null),
|
||||
variation: zLowMedHigh.default('med'),
|
||||
});
|
||||
export type StartingImage = z.infer<typeof zStartingImage>;
|
||||
export const getStartingImage = (overrides: Partial<Omit<StartingImage, 'type'>>) => zStartingImage.parse(overrides);
|
||||
|
||||
const REFERENCE_IMAGE_TYPE = 'reference_image';
|
||||
const zReferenceImage = z.object({
|
||||
type: z.literal(REFERENCE_IMAGE_TYPE).default(REFERENCE_IMAGE_TYPE),
|
||||
id: z.string().default(nanoid),
|
||||
image: zImageWithDims.nullable().default(null),
|
||||
});
|
||||
export type ReferenceImage = z.infer<typeof zReferenceImage>;
|
||||
export const getReferenceImage = (overrides: Partial<Omit<ReferenceImage, 'type'>>) => zReferenceImage.parse(overrides);
|
||||
|
||||
const CONTROL_IMAGE_TYPE = 'control_image';
|
||||
const zControlImage = z.object({
|
||||
type: z.literal(CONTROL_IMAGE_TYPE).default(CONTROL_IMAGE_TYPE),
|
||||
id: z.string().default(nanoid),
|
||||
control_type: zControlType.default('line'),
|
||||
image: zImageWithDims.nullable().default(null),
|
||||
});
|
||||
export type ControlImage = z.infer<typeof zControlImage>;
|
||||
export const getControlImage = (overrides: Partial<Omit<ControlImage, 'type'>>) => zControlImage.parse(overrides);
|
||||
|
||||
const zModel = z.enum(['chatgpt-4o', 'flux', 'sdxl', 'sd-1']);
|
||||
export const isModel = buildTypeGuard(zModel);
|
||||
|
||||
export const zSimpleGenerationState = z.object({
|
||||
_version: z.literal(1).default(1),
|
||||
positivePrompt: z.string().default(''),
|
||||
model: zModel.default('flux'),
|
||||
aspectRatio: zAspectRatio.default('1:1'),
|
||||
startingImage: zStartingImage.nullable().default(null),
|
||||
referenceImages: z.array(zReferenceImage).default(() => []),
|
||||
controlImage: zControlImage.nullable().default(null),
|
||||
});
|
||||
|
||||
export type SimpleGenerationState = z.infer<typeof zSimpleGenerationState>;
|
||||
@@ -0,0 +1,48 @@
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import type { ChatGPT4oAspectRatioID, Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { AspectRatio } from 'features/simpleGeneration/store/types';
|
||||
|
||||
export const getDimensions = (ratio: number, area: number): Dimensions => {
|
||||
const exactWidth = Math.sqrt(area * ratio);
|
||||
const exactHeight = exactWidth / ratio;
|
||||
|
||||
return {
|
||||
width: roundToMultiple(exactWidth, 64),
|
||||
height: roundToMultiple(exactHeight, 64),
|
||||
};
|
||||
};
|
||||
|
||||
const FLUX_SDXL_AREA = 1024 * 1024;
|
||||
export const FLUX_SDXL_ASPECT_RATIO_MAP: Record<AspectRatio, Dimensions> = {
|
||||
'16:9': getDimensions(16 / 9, FLUX_SDXL_AREA),
|
||||
'3:2': getDimensions(3 / 2, FLUX_SDXL_AREA),
|
||||
'4:3': getDimensions(4 / 3, FLUX_SDXL_AREA),
|
||||
'1:1': getDimensions(1, FLUX_SDXL_AREA),
|
||||
'3:4': getDimensions(3 / 4, FLUX_SDXL_AREA),
|
||||
'2:3': getDimensions(2 / 3, FLUX_SDXL_AREA),
|
||||
'9:16': getDimensions(9 / 16, FLUX_SDXL_AREA),
|
||||
};
|
||||
|
||||
const SD_1_AREA = 768 * 768;
|
||||
export const SD_1_ASPECT_RATIO_MAP: Record<AspectRatio, Dimensions> = {
|
||||
'16:9': getDimensions(16 / 9, SD_1_AREA),
|
||||
'3:2': getDimensions(3 / 2, SD_1_AREA),
|
||||
'4:3': getDimensions(4 / 3, SD_1_AREA),
|
||||
'1:1': getDimensions(1, SD_1_AREA),
|
||||
'3:4': getDimensions(3 / 4, SD_1_AREA),
|
||||
'2:3': getDimensions(2 / 3, SD_1_AREA),
|
||||
'9:16': getDimensions(9 / 16, SD_1_AREA),
|
||||
};
|
||||
|
||||
export const CHATGPT_4O_ASPECT_RATIO_MAP: Record<ChatGPT4oAspectRatioID, Dimensions> = {
|
||||
'1:1': { width: 1024, height: 1024 },
|
||||
'2:3': { width: 1024, height: 1536 },
|
||||
'3:2': { width: 1536, height: 1024 },
|
||||
};
|
||||
|
||||
export const ASPECT_RATIO_MAP = {
|
||||
flux: FLUX_SDXL_ASPECT_RATIO_MAP,
|
||||
sdxl: FLUX_SDXL_ASPECT_RATIO_MAP,
|
||||
'chatgpt-4o': CHATGPT_4O_ASPECT_RATIO_MAP,
|
||||
'sd-1': SD_1_ASPECT_RATIO_MAP,
|
||||
} as const;
|
||||
@@ -7,6 +7,7 @@ import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent
|
||||
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
||||
import WorkflowsTabLeftPanel from 'features/nodes/components/sidePanel/WorkflowsTabLeftPanel';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
import { SimpleTabLeftPanel } from 'features/simpleGeneration/components/SimpleTabLeftPanel';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
@@ -165,7 +166,7 @@ const RightPanelContent = memo(() => {
|
||||
if (tab === 'canvas') {
|
||||
return <CanvasRightPanel />;
|
||||
}
|
||||
if (tab === 'upscaling' || tab === 'workflows') {
|
||||
if (tab === 'simple' || tab === 'upscaling' || tab === 'workflows') {
|
||||
return <GalleryPanelContent />;
|
||||
}
|
||||
|
||||
@@ -176,6 +177,9 @@ RightPanelContent.displayName = 'RightPanelContent';
|
||||
const LeftPanelContent = memo(() => {
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
|
||||
if (tab === 'simple') {
|
||||
return <SimpleTabLeftPanel />;
|
||||
}
|
||||
if (tab === 'canvas') {
|
||||
return <ParametersPanelTextToImage />;
|
||||
}
|
||||
@@ -199,6 +203,9 @@ const MainPanelContent = memo(() => {
|
||||
if (tab === 'upscaling') {
|
||||
return <ImageViewer />;
|
||||
}
|
||||
if (tab === 'simple') {
|
||||
return <ImageViewer />;
|
||||
}
|
||||
if (tab === 'workflows') {
|
||||
return <WorkflowsMainPanel />;
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import { VideosModalButton } from 'features/system/components/VideosModal/Videos
|
||||
import { TabMountGate } from 'features/ui/components/TabMountGate';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiCubeBold, PiFlowArrowBold, PiFrameCornersBold, PiQueueBold } from 'react-icons/pi';
|
||||
import { PiBoundingBoxBold, PiBowlSteam, PiCubeBold, PiFlowArrowBold, PiFrameCornersBold, PiQueueBold } from 'react-icons/pi';
|
||||
|
||||
import { Notifications } from './Notifications';
|
||||
import { TabButton } from './TabButton';
|
||||
@@ -21,6 +21,9 @@ export const VerticalNavBar = memo(() => {
|
||||
<Flex flexDir="column" alignItems="center" py={2} gap={4} minW={0}>
|
||||
<InvokeAILogoComponent />
|
||||
<Flex gap={4} pt={6} h="full" flexDir="column">
|
||||
<TabMountGate tab="simple">
|
||||
<TabButton tab="simple" icon={<PiBowlSteam />} label={t('ui.tabs.simple')} />
|
||||
</TabMountGate>
|
||||
<TabMountGate tab="canvas">
|
||||
<TabButton tab="canvas" icon={<PiBoundingBoxBold />} label={t('ui.tabs.canvas')} />
|
||||
</TabMountGate>
|
||||
|
||||
@@ -8,7 +8,7 @@ import { atom } from 'nanostores';
|
||||
import type { CanvasRightPanelTabName, TabName, UIState } from './uiTypes';
|
||||
|
||||
const initialUIState: UIState = {
|
||||
_version: 3,
|
||||
_version: 4,
|
||||
activeTab: 'canvas',
|
||||
activeTabCanvasRightPanel: 'gallery',
|
||||
shouldShowImageDetails: false,
|
||||
@@ -81,6 +81,10 @@ const migrateUIState = (state: any): any => {
|
||||
state.activeTab = 'canvas';
|
||||
state._version = 3;
|
||||
}
|
||||
if (state._version === 3) {
|
||||
state.activeTab = 'simple';
|
||||
state._version = 4;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
@@ -91,12 +95,12 @@ export const uiPersistConfig: PersistConfig<UIState> = {
|
||||
persistDenylist: ['shouldShowImageDetails'],
|
||||
};
|
||||
|
||||
const TABS_WITH_LEFT_PANEL: TabName[] = ['canvas', 'upscaling', 'workflows'] as const;
|
||||
const TABS_WITH_LEFT_PANEL: TabName[] = ['simple', 'canvas', 'upscaling', 'workflows'] as const;
|
||||
export const LEFT_PANEL_MIN_SIZE_PX = 400;
|
||||
export const $isLeftPanelOpen = atom(true);
|
||||
export const selectWithLeftPanel = createSelector(selectUiSlice, (ui) => TABS_WITH_LEFT_PANEL.includes(ui.activeTab));
|
||||
|
||||
const TABS_WITH_RIGHT_PANEL: TabName[] = ['canvas', 'upscaling', 'workflows'] as const;
|
||||
const TABS_WITH_RIGHT_PANEL: TabName[] = ['simple', 'canvas', 'upscaling', 'workflows'] as const;
|
||||
export const RIGHT_PANEL_MIN_SIZE_PX = 390;
|
||||
export const $isRightPanelOpen = atom(true);
|
||||
export const selectWithRightPanel = createSelector(selectUiSlice, (ui) => TABS_WITH_RIGHT_PANEL.includes(ui.activeTab));
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
export type TabName = 'canvas' | 'upscaling' | 'workflows' | 'models' | 'queue';
|
||||
export type TabName = 'canvas' | 'upscaling' | 'workflows' | 'models' | 'queue' | 'simple';
|
||||
export type CanvasRightPanelTabName = 'layers' | 'gallery';
|
||||
|
||||
export interface UIState {
|
||||
/**
|
||||
* Slice schema version.
|
||||
*/
|
||||
_version: 3;
|
||||
_version: 4;
|
||||
/**
|
||||
* The currently active tab.
|
||||
*/
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
isSD3MainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isSigLipModelConfig,
|
||||
isSimpleTabModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
@@ -59,6 +60,7 @@ const buildModelsHook =
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useSimpleTabModels = buildModelsHook(isSimpleTabModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
|
||||
@@ -276,6 +276,15 @@ export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConf
|
||||
return config.type === 'embedding';
|
||||
};
|
||||
|
||||
export const isSimpleTabModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is Extract<MainModelConfig, { base: 'chatgpt-4o' | 'flux' | 'sdxl' | 'sd-1' }> => {
|
||||
return (
|
||||
config.type === 'main' &&
|
||||
(config.base === 'chatgpt-4o' || config.base === 'flux' || config.base === 'sdxl' || config.base === 'sd-1')
|
||||
);
|
||||
};
|
||||
|
||||
export type ModelInstallJob = S['ModelInstallJob'];
|
||||
export type ModelInstallStatus = S['InstallStatus'];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user