mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Imagen4 working in UI
This commit is contained in:
committed by
psychedelicious
parent
2f35d74902
commit
27dc843046
@@ -3,6 +3,7 @@ import {
|
||||
selectIsChatGTP4o,
|
||||
selectIsCogView4,
|
||||
selectIsImagen3,
|
||||
selectIsImagen4,
|
||||
selectIsSD3,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CanvasEntityType } from 'features/controlLayers/store/types';
|
||||
@@ -14,24 +15,25 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
const isCogView4 = useAppSelector(selectIsCogView4);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
|
||||
const isEntityTypeEnabled = useMemo<boolean>(() => {
|
||||
switch (entityType) {
|
||||
case 'reference_image':
|
||||
return !isSD3 && !isCogView4 && !isImagen3;
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
|
||||
case 'regional_guidance':
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
|
||||
case 'control_layer':
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
|
||||
case 'inpaint_mask':
|
||||
return !isImagen3 && !isChatGPT4o;
|
||||
return !isImagen3 && !isImagen4 && !isChatGPT4o;
|
||||
case 'raster_layer':
|
||||
return !isImagen3 && !isChatGPT4o;
|
||||
return !isImagen3 && !isImagen4 && !isChatGPT4o;
|
||||
default:
|
||||
assert<Equals<typeof entityType, never>>(false);
|
||||
}
|
||||
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
|
||||
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
|
||||
|
||||
return isEntityTypeEnabled;
|
||||
};
|
||||
|
||||
@@ -235,7 +235,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
if (tool !== 'bbox') {
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
|
||||
if (model?.base === 'imagen3' || model?.base === 'imagen4' || model?.base === 'chatgpt-4o') {
|
||||
// The bbox is not resizable in these modes
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ import type {
|
||||
IPMethodV2,
|
||||
T2IAdapterConfig,
|
||||
} from './types';
|
||||
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
|
||||
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
|
||||
import {
|
||||
converters,
|
||||
getControlLayerState,
|
||||
@@ -1236,7 +1236,10 @@ export const canvasSlice = createSlice({
|
||||
state.bbox.aspectRatio.id = id;
|
||||
if (id === 'Free') {
|
||||
state.bbox.aspectRatio.isLocked = false;
|
||||
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
|
||||
} else if (
|
||||
(state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'imagen4') &&
|
||||
isImagenAspectRatioID(id)
|
||||
) {
|
||||
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
|
||||
if (id === '16:9') {
|
||||
state.bbox.rect.width = 1408;
|
||||
@@ -1742,7 +1745,7 @@ export const canvasSlice = createSlice({
|
||||
const base = model?.base;
|
||||
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
|
||||
state.bbox.modelBase = base;
|
||||
if (base === 'imagen3' || base === 'chatgpt-4o') {
|
||||
if (base === 'imagen3' || base === 'chatgpt-4o' || base === 'imagen4') {
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
state.bbox.aspectRatio.value = 1;
|
||||
state.bbox.aspectRatio.id = '1:1';
|
||||
@@ -1881,7 +1884,11 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
};
|
||||
|
||||
const syncScaledSize = (state: CanvasState) => {
|
||||
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
|
||||
if (
|
||||
state.bbox.modelBase === 'imagen3' ||
|
||||
state.bbox.modelBase === 'chatgpt-4o' ||
|
||||
state.bbox.modelBase === 'imagen4'
|
||||
) {
|
||||
// Imagen3 has fixed sizes. Scaled bbox is not supported.
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
|
||||
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
|
||||
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
|
||||
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
|
||||
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
|
||||
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
|
||||
|
||||
export const selectModel = createParamsSelector((params) => params.model);
|
||||
|
||||
@@ -406,7 +406,7 @@ export type StagingAreaImage = {
|
||||
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
|
||||
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
|
||||
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
|
||||
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
|
||||
zImagen3AspectRatioID.safeParse(v).success;
|
||||
|
||||
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
|
||||
|
||||
@@ -17,6 +17,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
flux: 'gold',
|
||||
cogview4: 'red',
|
||||
imagen3: 'pink',
|
||||
imagen4: 'pink',
|
||||
'chatgpt-4o': 'pink',
|
||||
};
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flo
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
|
||||
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
|
||||
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
|
||||
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
@@ -63,6 +64,8 @@ import {
|
||||
isImageGeneratorFieldInputTemplate,
|
||||
isImagen3ModelFieldInputInstance,
|
||||
isImagen3ModelFieldInputTemplate,
|
||||
isImagen4ModelFieldInputInstance,
|
||||
isImagen4ModelFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@@ -407,6 +410,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen4ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen4ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isChatGPT4oModelFieldInputTemplate(template)) {
|
||||
if (!isChatGPT4oModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useImagen4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Imagen4ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useImagen4Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldImagen4ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Imagen4ModelFieldInputComponent);
|
||||
@@ -123,6 +123,8 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
|
||||
'metadata_to_t2i_adapters',
|
||||
'google_imagen3_generate',
|
||||
'google_imagen3_edit',
|
||||
'google_imagen4_generate',
|
||||
'google_imagen4_edit',
|
||||
'chatgpt_create_image',
|
||||
'chatgpt_edit_image',
|
||||
];
|
||||
|
||||
@@ -40,6 +40,7 @@ import type {
|
||||
ImageFieldValue,
|
||||
ImageGeneratorFieldValue,
|
||||
Imagen3ModelFieldValue,
|
||||
Imagen4ModelFieldValue,
|
||||
IntegerFieldCollectionValue,
|
||||
IntegerFieldValue,
|
||||
IntegerGeneratorFieldValue,
|
||||
@@ -80,6 +81,7 @@ import {
|
||||
zImageFieldValue,
|
||||
zImageGeneratorFieldValue,
|
||||
zImagen3ModelFieldValue,
|
||||
zImagen4ModelFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zIntegerFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
@@ -519,6 +521,9 @@ export const nodesSlice = createSlice({
|
||||
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImagen3ModelFieldValue);
|
||||
},
|
||||
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImagen4ModelFieldValue);
|
||||
},
|
||||
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
|
||||
},
|
||||
@@ -690,6 +695,7 @@ export const {
|
||||
fieldSigLipModelValueChanged,
|
||||
fieldFluxReduxModelValueChanged,
|
||||
fieldImagen3ModelValueChanged,
|
||||
fieldImagen4ModelValueChanged,
|
||||
fieldChatGPT4oModelValueChanged,
|
||||
fieldFloatGeneratorValueChanged,
|
||||
fieldIntegerGeneratorValueChanged,
|
||||
|
||||
@@ -76,10 +76,21 @@ const zBaseModel = z.enum([
|
||||
'flux',
|
||||
'cogview4',
|
||||
'imagen3',
|
||||
'imagen4',
|
||||
'chatgpt-4o',
|
||||
]);
|
||||
export type BaseModelType = z.infer<typeof zBaseModel>;
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
|
||||
export const zMainModelBase = z.enum([
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sd-3',
|
||||
'sdxl',
|
||||
'flux',
|
||||
'cogview4',
|
||||
'imagen3',
|
||||
'imagen4',
|
||||
'chatgpt-4o',
|
||||
]);
|
||||
export type MainModelBase = z.infer<typeof zMainModelBase>;
|
||||
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
|
||||
const zModelType = z.enum([
|
||||
|
||||
@@ -252,6 +252,10 @@ const zImagen3ModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('Imagen3ModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImagen4ModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('Imagen4ModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ChatGPT4oModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -307,6 +311,7 @@ const zStatefulFieldType = z.union([
|
||||
zSigLipModelFieldType,
|
||||
zFluxReduxModelFieldType,
|
||||
zImagen3ModelFieldType,
|
||||
zImagen4ModelFieldType,
|
||||
zChatGPT4oModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
@@ -347,6 +352,7 @@ const modelFieldTypeNames = [
|
||||
zSigLipModelFieldType.shape.name.value,
|
||||
zFluxReduxModelFieldType.shape.name.value,
|
||||
zImagen3ModelFieldType.shape.name.value,
|
||||
zImagen4ModelFieldType.shape.name.value,
|
||||
zChatGPT4oModelFieldType.shape.name.value,
|
||||
// Stateless model fields
|
||||
'UNetField',
|
||||
@@ -1207,6 +1213,24 @@ export const isImagen3ModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
|
||||
// #endregion
|
||||
|
||||
// #region Imagen4ModelField
|
||||
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
|
||||
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImagen4ModelFieldValue,
|
||||
});
|
||||
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImagen4ModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImagen4ModelFieldValue,
|
||||
});
|
||||
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
|
||||
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
|
||||
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
|
||||
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
|
||||
export const isImagen4ModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ChatGPT4oModelField
|
||||
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
|
||||
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -1857,6 +1881,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zSigLipModelFieldValue,
|
||||
zFluxReduxModelFieldValue,
|
||||
zImagen3ModelFieldValue,
|
||||
zImagen4ModelFieldValue,
|
||||
zChatGPT4oModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
@@ -1949,6 +1974,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zSigLipModelFieldInputTemplate,
|
||||
zFluxReduxModelFieldInputTemplate,
|
||||
zImagen3ModelFieldInputTemplate,
|
||||
zImagen4ModelFieldInputTemplate,
|
||||
zChatGPT4oModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
@@ -24,7 +24,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
|
||||
const generationMode = await manager.compositor.getGenerationMode();
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagen3IncompatibleGenerationMode'));
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
|
||||
}
|
||||
|
||||
log.debug({ generationMode }, 'Building Imagen3 graph');
|
||||
@@ -38,7 +38,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
|
||||
|
||||
assert(model, 'No model found for Imagen3 graph');
|
||||
assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
|
||||
assert(isImagen3AspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
|
||||
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
|
||||
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
|
||||
|
||||
const is_intermediate = canvasSettings.sendToCanvas;
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
CANVAS_OUTPUT_PREFIX,
|
||||
getBoardField,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import { selectMainModelConfig } from 'services/api/endpoints/models';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildImagen4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
|
||||
const generationMode = await manager.compositor.getGenerationMode();
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
|
||||
}
|
||||
|
||||
log.debug({ generationMode }, 'Building Imagen4 graph');
|
||||
|
||||
const canvas = selectCanvasSlice(state);
|
||||
const canvasSettings = selectCanvasSettingsSlice(state);
|
||||
|
||||
const { bbox } = canvas;
|
||||
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
|
||||
const model = selectMainModelConfig(state);
|
||||
|
||||
assert(model, 'No model found for Imagen4 graph');
|
||||
assert(model.base === 'imagen4', 'Imagen4 graph requires Imagen4 model');
|
||||
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
|
||||
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
|
||||
|
||||
const is_intermediate = canvasSettings.sendToCanvas;
|
||||
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
|
||||
const imagen4 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen4_generate_image',
|
||||
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
enhance_prompt: true,
|
||||
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
|
||||
use_cache: false,
|
||||
is_intermediate,
|
||||
board,
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
});
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen4');
|
||||
};
|
||||
@@ -34,6 +34,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
SigLipModelField: undefined,
|
||||
FluxReduxModelField: undefined,
|
||||
Imagen3ModelField: undefined,
|
||||
Imagen4ModelField: undefined,
|
||||
ChatGPT4oModelField: undefined,
|
||||
FloatGeneratorField: undefined,
|
||||
IntegerGeneratorField: undefined,
|
||||
|
||||
@@ -23,6 +23,7 @@ import type {
|
||||
ImageFieldInputTemplate,
|
||||
ImageGeneratorFieldInputTemplate,
|
||||
Imagen3ModelFieldInputTemplate,
|
||||
Imagen4ModelFieldInputTemplate,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
@@ -600,6 +601,18 @@ const buildImagen3ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen3Mode
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4ModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: Imagen4ModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -820,6 +833,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SigLipModelField: buildSigLipModelFieldInputTemplate,
|
||||
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
|
||||
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
|
||||
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
|
||||
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
|
||||
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
|
||||
import {
|
||||
isAspectRatioID,
|
||||
@@ -23,10 +23,10 @@ export const BboxAspectRatioSelect = memo(() => {
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
const options = useMemo(() => {
|
||||
// Imagen3 and ChatGPT4o have different aspect ratio options, and do not support freeform sizes
|
||||
if (isImagen3) {
|
||||
if (isImagen3 || isImagen4) {
|
||||
return zImagen3AspectRatioID.options;
|
||||
}
|
||||
if (isChatGPT4o) {
|
||||
@@ -34,7 +34,7 @@ export const BboxAspectRatioSelect = memo(() => {
|
||||
}
|
||||
// All other models
|
||||
return zAspectRatioID.options;
|
||||
}, [isImagen3, isChatGPT4o]);
|
||||
}, [isImagen3, isChatGPT4o, isImagen4]);
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
|
||||
|
||||
export const useIsBboxSizeLocked = () => {
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
return isImagen3 || isChatGPT4o || isStaging;
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
|
||||
return isImagen3 || isChatGPT4o || isImagen4 || isStaging;
|
||||
};
|
||||
|
||||
@@ -59,28 +59,28 @@ const NoOptionsFallback = memo(() => {
|
||||
NoOptionsFallback.displayName = 'NoOptionsFallback';
|
||||
|
||||
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
|
||||
return 'api';
|
||||
}
|
||||
return modelConfig.base;
|
||||
};
|
||||
|
||||
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
|
||||
return 'External API';
|
||||
}
|
||||
return MODEL_TYPE_MAP[modelConfig.base];
|
||||
};
|
||||
|
||||
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
|
||||
return 'api';
|
||||
}
|
||||
return MODEL_TYPE_SHORT_MAP[modelConfig.base];
|
||||
};
|
||||
|
||||
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
|
||||
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
|
||||
return 'pink';
|
||||
}
|
||||
return BASE_COLOR_MAP[modelConfig.base];
|
||||
|
||||
@@ -14,6 +14,7 @@ export const MODEL_TYPE_MAP: Record<BaseModelType, string> = {
|
||||
flux: 'FLUX',
|
||||
cogview4: 'CogView4',
|
||||
imagen3: 'Imagen3',
|
||||
imagen4: 'Imagen4',
|
||||
'chatgpt-4o': 'ChatGPT 4o',
|
||||
};
|
||||
|
||||
@@ -30,6 +31,7 @@ export const MODEL_TYPE_SHORT_MAP: Record<BaseModelType, string> = {
|
||||
flux: 'FLUX',
|
||||
cogview4: 'CogView4',
|
||||
imagen3: 'Imagen3',
|
||||
imagen4: 'Imagen4',
|
||||
'chatgpt-4o': 'ChatGPT 4o',
|
||||
};
|
||||
|
||||
@@ -73,6 +75,10 @@ export const CLIP_SKIP_MAP: Record<BaseModelType, { maxClip: number; markers: nu
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
},
|
||||
imagen4: {
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
},
|
||||
'chatgpt-4o': {
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
|
||||
@@ -19,6 +19,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => {
|
||||
case 'sd-3':
|
||||
case 'cogview4':
|
||||
case 'imagen3':
|
||||
case 'imagen4':
|
||||
case 'chatgpt-4o':
|
||||
default:
|
||||
return 1024;
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
selectIsCogView4,
|
||||
selectIsFLUX,
|
||||
selectIsImagen3,
|
||||
selectIsImagen4,
|
||||
selectIsSD3,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
@@ -41,11 +42,12 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
const isCogView4 = useAppSelector(selectIsCogView4);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
|
||||
const isApiModel = useMemo(() => {
|
||||
return isImagen3 || isChatGPT4o;
|
||||
}, [isImagen3, isChatGPT4o]);
|
||||
return isImagen3 || isImagen4 || isChatGPT4o;
|
||||
}, [isImagen3, isImagen4, isChatGPT4o]);
|
||||
|
||||
const isUpscaling = useMemo(() => {
|
||||
return activeTabName === 'upscaling';
|
||||
@@ -56,7 +58,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
|
||||
const accordionBadges =
|
||||
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o'
|
||||
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o' || modelConfig?.base === 'imagen4'
|
||||
? [modelConfig.name]
|
||||
: modelConfig
|
||||
? [modelConfig.name, modelConfig.base]
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
selectIsChatGTP4o,
|
||||
selectIsFLUX,
|
||||
selectIsImagen3,
|
||||
selectIsImagen4,
|
||||
selectIsSD3,
|
||||
selectParamsSlice,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -67,10 +68,10 @@ export const ImageSettingsAccordion = memo(() => {
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
const isApiModel = useMemo(() => {
|
||||
return isImagen3 || isChatGPT4o;
|
||||
}, [isImagen3, isChatGPT4o]);
|
||||
return isImagen3 || isChatGPT4o || isImagen4;
|
||||
}, [isImagen3, isChatGPT4o, isImagen4]);
|
||||
|
||||
return (
|
||||
<StandaloneAccordion
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
selectIsChatGTP4o,
|
||||
selectIsCogView4,
|
||||
selectIsImagen3,
|
||||
selectIsImagen4,
|
||||
selectIsSDXL,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
|
||||
@@ -30,12 +31,13 @@ const ParametersPanelTextToImage = () => {
|
||||
const isSDXL = useAppSelector(selectIsSDXL);
|
||||
const isCogview4 = useAppSelector(selectIsCogView4);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isImagen4 = useAppSelector(selectIsImagen4);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
|
||||
|
||||
const isApiModel = useMemo(() => {
|
||||
return isImagen3 || isChatGPT4o;
|
||||
}, [isImagen3, isChatGPT4o]);
|
||||
return isImagen3 || isChatGPT4o || isImagen4;
|
||||
}, [isImagen3, isChatGPT4o, isImagen4]);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="column" gap={2}>
|
||||
|
||||
Reference in New Issue
Block a user