mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 09:54:56 -05:00
Implement automatic reference image model switching on base model change
Co-authored-by: kent <kent@invoke.ai>
This commit is contained in:
committed by
psychedelicious
parent
50079ea349
commit
1caab2b9c4
@@ -4,14 +4,56 @@ import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canva
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
|
||||
import {
|
||||
refImageModelChanged,
|
||||
selectReferenceImageEntities
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import { selectBboxModelBase, selectAllEntities } from 'features/controlLayers/store/selectors';
|
||||
import {
|
||||
rgRefImageModelChanged
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
getEntityIdentifier,
|
||||
isRegionalGuidanceEntityIdentifier
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type {
|
||||
CanvasEntityState,
|
||||
RefImageState
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import {
|
||||
selectIPAdapterModels
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
import type {
|
||||
AnyModelConfig
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
isIPAdapterModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isChatGPT4oModelConfig,
|
||||
isFluxKontextApiModelConfig,
|
||||
isFluxKontextModelConfig
|
||||
} from 'services/api/types';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { t } from 'i18next';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
// Selector for global reference image models
|
||||
const selectGlobalReferenceImageModels = (state: RootState): AnyModelConfig[] => {
|
||||
const allModels = selectIPAdapterModels(state);
|
||||
// Add other model types that can be used as reference images
|
||||
return allModels.filter((model: AnyModelConfig) =>
|
||||
isIPAdapterModelConfig(model) ||
|
||||
isFluxReduxModelConfig(model) ||
|
||||
isChatGPT4oModelConfig(model) ||
|
||||
isFluxKontextApiModelConfig(model) ||
|
||||
isFluxKontextModelConfig(model)
|
||||
);
|
||||
};
|
||||
|
||||
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: modelSelected,
|
||||
@@ -34,7 +76,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
let modelsCleared = 0;
|
||||
|
||||
// handle incompatible loras
|
||||
state.loras.loras.forEach((lora) => {
|
||||
state.loras.loras.forEach((lora: any) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
@@ -58,6 +100,61 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
// }
|
||||
// });
|
||||
|
||||
// Handle incompatible reference image models - switch to first compatible model
|
||||
const availableRefImageModels = selectGlobalReferenceImageModels(state).filter((model: AnyModelConfig) => model.base === newBaseModel);
|
||||
const firstCompatibleModel = availableRefImageModels[0] || null;
|
||||
|
||||
// Handle global reference images
|
||||
const refImageEntities = selectReferenceImageEntities(state);
|
||||
refImageEntities.forEach((entity: RefImageState) => {
|
||||
if (entity.config.model && entity.config.model.base !== newBaseModel) {
|
||||
dispatch(refImageModelChanged({
|
||||
id: entity.id,
|
||||
modelConfig: firstCompatibleModel
|
||||
}));
|
||||
if (firstCompatibleModel) {
|
||||
log.debug(
|
||||
{ oldModel: entity.config.model, newModel: firstCompatibleModel },
|
||||
'Switched global reference image model to compatible model'
|
||||
);
|
||||
} else {
|
||||
log.debug(
|
||||
{ oldModel: entity.config.model },
|
||||
'Cleared global reference image model - no compatible models available'
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Handle regional guidance reference images
|
||||
const canvasEntities = selectAllEntities(state.canvas.present);
|
||||
canvasEntities.forEach((entity: CanvasEntityState) => {
|
||||
if (isRegionalGuidanceEntityIdentifier(getEntityIdentifier(entity))) {
|
||||
entity.referenceImages.forEach((refImage: any) => {
|
||||
if (refImage.config.model && refImage.config.model.base !== newBaseModel) {
|
||||
dispatch(rgRefImageModelChanged({
|
||||
entityIdentifier: getEntityIdentifier(entity),
|
||||
referenceImageId: refImage.id,
|
||||
modelConfig: firstCompatibleModel
|
||||
}));
|
||||
if (firstCompatibleModel) {
|
||||
log.debug(
|
||||
{ oldModel: refImage.config.model, newModel: firstCompatibleModel },
|
||||
'Switched regional guidance reference image model to compatible model'
|
||||
);
|
||||
} else {
|
||||
log.debug(
|
||||
{ oldModel: refImage.config.model },
|
||||
'Cleared regional guidance reference image model - no compatible models available'
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
toast({
|
||||
id: 'BASE_MODEL_CHANGED',
|
||||
|
||||
@@ -400,3 +400,12 @@ export type UploadImageArg = {
|
||||
|
||||
export type ImageUploadEntryResponse = S['ImageUploadEntry'];
|
||||
export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json'];
|
||||
|
||||
export const isApiModelConfig = (config: AnyModelConfig): config is ApiModelConfig => {
|
||||
return (
|
||||
isChatGPT4oModelConfig(config) ||
|
||||
isImagen3ModelConfig(config) ||
|
||||
isImagen4ModelConfig(config) ||
|
||||
isFluxKontextApiModelConfig(config)
|
||||
);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user