Imagen4 working in UI

This commit is contained in:
Mary Hipp
2025-05-21 20:54:50 -04:00
committed by psychedelicious
parent 2f35d74902
commit 27dc843046
29 changed files with 601 additions and 37 deletions

View File

@@ -3,6 +3,7 @@ import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
@@ -14,24 +15,25 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -235,7 +235,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
if (model?.base === 'imagen3' || model?.base === 'imagen4' || model?.base === 'chatgpt-4o') {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}

View File

@@ -68,7 +68,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -1236,7 +1236,10 @@ export const canvasSlice = createSlice({
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
} else if (
(state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'imagen4') &&
isImagenAspectRatioID(id)
) {
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
@@ -1742,7 +1745,7 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3' || base === 'chatgpt-4o') {
if (base === 'imagen3' || base === 'chatgpt-4o' || base === 'imagen4') {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
@@ -1881,7 +1884,11 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
if (
state.bbox.modelBase === 'imagen3' ||
state.bbox.modelBase === 'chatgpt-4o' ||
state.bbox.modelBase === 'imagen4'
) {
// Imagen3 has fixed sizes. Scaled bbox is not supported.
return;
}

View File

@@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);

View File

@@ -406,7 +406,7 @@ export type StagingAreaImage = {
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);

View File

@@ -17,6 +17,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
flux: 'gold',
cogview4: 'red',
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',
};

View File

@@ -7,6 +7,7 @@ import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flo
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -63,6 +64,8 @@ import {
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isImagen4ModelFieldInputInstance,
isImagen4ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
@@ -407,6 +410,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen4ModelFieldInputTemplate(template)) {
if (!isImagen4ModelFieldInputInstance(field)) {
return null;
}
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen4Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen4ModelFieldInputComponent = (
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen4Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen4ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen4ModelFieldInputComponent);

View File

@@ -123,6 +123,8 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'google_imagen4_generate',
'google_imagen4_edit',
'chatgpt_create_image',
'chatgpt_edit_image',
];

View File

@@ -40,6 +40,7 @@ import type {
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
Imagen4ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
@@ -80,6 +81,7 @@ import {
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
@@ -519,6 +521,9 @@ export const nodesSlice = createSlice({
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen4ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
@@ -690,6 +695,7 @@ export const {
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,

View File

@@ -76,10 +76,21 @@ const zBaseModel = z.enum([
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export const zMainModelBase = z.enum([
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([

View File

@@ -252,6 +252,10 @@ const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen4ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen4ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
@@ -307,6 +311,7 @@ const zStatefulFieldType = z.union([
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zImagen4ModelFieldType,
zChatGPT4oModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -347,6 +352,7 @@ const modelFieldTypeNames = [
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zImagen4ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
@@ -1207,6 +1213,24 @@ export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion
// #region Imagen4ModelField
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen4ModelFieldValue,
});
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen4ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen4ModelFieldValue,
});
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
export const isImagen4ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1857,6 +1881,7 @@ export const zStatefulFieldValue = z.union([
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
@@ -1949,6 +1974,7 @@ const zStatefulFieldInputTemplate = z.union([
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zImagen4ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,

View File

@@ -4,7 +4,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
@@ -24,7 +24,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagen3IncompatibleGenerationMode'));
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
}
log.debug({ generationMode }, 'Building Imagen3 graph');
@@ -38,7 +38,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
assert(model, 'No model found for Imagen3 graph');
assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
assert(isImagen3AspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
const is_intermediate = canvasSettings.sendToCanvas;

View File

@@ -0,0 +1,78 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
}
log.debug({ generationMode }, 'Building Imagen4 graph');
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const model = selectMainModelConfig(state);
assert(model, 'No model found for Imagen4 graph');
assert(model.base === 'imagen4', 'Imagen4 graph requires Imagen4 model');
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
const imagen4 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen4_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
enhance_prompt: true,
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
use_cache: false,
is_intermediate,
board,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
});
return {
g,
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen4');
};

View File

@@ -34,6 +34,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SigLipModelField: undefined,
FluxReduxModelField: undefined,
Imagen3ModelField: undefined,
Imagen4ModelField: undefined,
ChatGPT4oModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,

View File

@@ -23,6 +23,7 @@ import type {
ImageFieldInputTemplate,
ImageGeneratorFieldInputTemplate,
Imagen3ModelFieldInputTemplate,
Imagen4ModelFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
@@ -600,6 +601,18 @@ const buildImagen3ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen3Mode
return template;
};
const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4ModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: Imagen4ModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -820,6 +833,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SigLipModelField: buildSigLipModelFieldInputTemplate,
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,

View File

@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import {
isAspectRatioID,
@@ -23,10 +23,10 @@ export const BboxAspectRatioSelect = memo(() => {
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isImagen4 = useAppSelector(selectIsImagen4);
const options = useMemo(() => {
// Imagen3 and ChatGPT4o have different aspect ratio options, and do not support freeform sizes
if (isImagen3) {
if (isImagen3 || isImagen4) {
return zImagen3AspectRatioID.options;
}
if (isChatGPT4o) {
@@ -34,7 +34,7 @@ export const BboxAspectRatioSelect = memo(() => {
}
// All other models
return zAspectRatioID.options;
}, [isImagen3, isChatGPT4o]);
}, [isImagen3, isChatGPT4o, isImagen4]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {

View File

@@ -1,10 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
export const useIsBboxSizeLocked = () => {
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
return isImagen3 || isChatGPT4o || isStaging;
const isImagen4 = useAppSelector(selectIsImagen4);
return isImagen3 || isChatGPT4o || isImagen4 || isStaging;
};

View File

@@ -59,28 +59,28 @@ const NoOptionsFallback = memo(() => {
NoOptionsFallback.displayName = 'NoOptionsFallback';
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
return 'api';
}
return modelConfig.base;
};
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
return 'External API';
}
return MODEL_TYPE_MAP[modelConfig.base];
};
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
return 'api';
}
return MODEL_TYPE_SHORT_MAP[modelConfig.base];
};
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3' || modelConfig.base === 'imagen4') {
return 'pink';
}
return BASE_COLOR_MAP[modelConfig.base];

View File

@@ -14,6 +14,7 @@ export const MODEL_TYPE_MAP: Record<BaseModelType, string> = {
flux: 'FLUX',
cogview4: 'CogView4',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
};
@@ -30,6 +31,7 @@ export const MODEL_TYPE_SHORT_MAP: Record<BaseModelType, string> = {
flux: 'FLUX',
cogview4: 'CogView4',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
};
@@ -73,6 +75,10 @@ export const CLIP_SKIP_MAP: Record<BaseModelType, { maxClip: number; markers: nu
maxClip: 0,
markers: [],
},
imagen4: {
maxClip: 0,
markers: [],
},
'chatgpt-4o': {
maxClip: 0,
markers: [],

View File

@@ -19,6 +19,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => {
case 'sd-3':
case 'cogview4':
case 'imagen3':
case 'imagen4':
case 'chatgpt-4o':
default:
return 1024;

View File

@@ -9,6 +9,7 @@ import {
selectIsCogView4,
selectIsFLUX,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
@@ -41,11 +42,12 @@ export const GenerationSettingsAccordion = memo(() => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
return isImagen3 || isImagen4 || isChatGPT4o;
}, [isImagen3, isImagen4, isChatGPT4o]);
const isUpscaling = useMemo(() => {
return activeTabName === 'upscaling';
@@ -56,7 +58,7 @@ export const GenerationSettingsAccordion = memo(() => {
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges =
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o'
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o' || modelConfig?.base === 'imagen4'
? [modelConfig.name]
: modelConfig
? [modelConfig.name, modelConfig.base]

View File

@@ -7,6 +7,7 @@ import {
selectIsChatGTP4o,
selectIsFLUX,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
selectParamsSlice,
} from 'features/controlLayers/store/paramsSlice';
@@ -67,10 +68,10 @@ export const ImageSettingsAccordion = memo(() => {
const isSD3 = useAppSelector(selectIsSD3);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isImagen4 = useAppSelector(selectIsImagen4);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
return isImagen3 || isChatGPT4o || isImagen4;
}, [isImagen3, isChatGPT4o, isImagen4]);
return (
<StandaloneAccordion

View File

@@ -6,6 +6,7 @@ import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsImagen4,
selectIsSDXL,
} from 'features/controlLayers/store/paramsSlice';
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
@@ -30,12 +31,13 @@ const ParametersPanelTextToImage = () => {
const isSDXL = useAppSelector(selectIsSDXL);
const isCogview4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
return isImagen3 || isChatGPT4o || isImagen4;
}, [isImagen3, isChatGPT4o, isImagen4]);
return (
<Flex w="full" h="full" flexDir="column" gap={2}>