Add CogView4 to frontend.

This commit is contained in:
Ryan Dick
2025-03-06 18:30:51 +00:00
committed by psychedelicious
parent e1133bc53f
commit f4e00ab261
11 changed files with 136 additions and 2 deletions

View File

@@ -15,6 +15,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
cogview4: 'orange',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@@ -29,6 +29,8 @@ import {
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isCogView4MainModelFieldInputInstance,
isCogView4MainModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
@@ -106,6 +108,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
@@ -412,6 +415,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCogView4MainModelFieldInputTemplate(template)) {
if (!isCogView4MainModelFieldInputInstance(field)) {
return null;
}
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSDXLMainModelFieldInputTemplate(template)) {
if (!isSDXLMainModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,63 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
CogView4MainModelFieldInputInstance,
CogView4MainModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCogView4Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
const CogView4MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCogView4Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(CogView4MainModelFieldInputComponent);

View File

@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux', 'cogview4']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([

View File

@@ -53,6 +53,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
CogView4MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',

View File

@@ -176,6 +176,10 @@ const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCogView4MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('CogView4MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -276,6 +280,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zCogView4MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
@@ -313,6 +318,7 @@ const modelFieldTypeNames = [
zMainModelFieldType.shape.name.value,
zSDXLMainModelFieldType.shape.name.value,
zSD3MainModelFieldType.shape.name.value,
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
@@ -817,6 +823,26 @@ export const isSD3MainModelFieldInputTemplate =
buildTemplateTypeGuard<SD3MainModelFieldInputTemplate>('SD3MainModelField');
// #endregion
// #region CogView4MainModelField
const zCogView4MainModelFieldValue = zMainModelFieldValue;
const zCogView4MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCogView4MainModelFieldValue,
});
const zCogView4MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCogView4MainModelFieldType,
originalType: zFieldType.optional(),
default: zCogView4MainModelFieldValue,
});
const zCogView4MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zCogView4MainModelFieldType,
});
export type CogView4MainModelFieldInputInstance = z.infer<typeof zCogView4MainModelFieldInputInstance>;
export type CogView4MainModelFieldInputTemplate = z.infer<typeof zCogView4MainModelFieldInputTemplate>;
export const isCogView4MainModelFieldInputInstance = buildInstanceTypeGuard(zCogView4MainModelFieldInputInstance);
export const isCogView4MainModelFieldInputTemplate =
buildTemplateTypeGuard<CogView4MainModelFieldInputTemplate>('CogView4MainModelField');
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1765,6 +1791,7 @@ export const zStatefulFieldValue = z.union([
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -1811,6 +1838,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -1852,6 +1880,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -1899,6 +1928,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@@ -5,6 +5,7 @@ import type {
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
CogView4MainModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlLoRAModelFieldInputTemplate,
ControlNetModelFieldInputTemplate,
@@ -351,6 +352,20 @@ const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainMode
return template;
};
const buildCogView4MainModelFieldInputTemplate: FieldInputTemplateBuilder<CogView4MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CogView4MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -761,6 +776,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,

View File

@@ -11,6 +11,7 @@ export const MODEL_TYPE_MAP = {
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'FLUX',
cogview4: 'CogView4',
};
/**
@@ -24,6 +25,7 @@ export const MODEL_TYPE_SHORT_MAP = {
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
cogview4: 'CogView4',
};
/**
@@ -58,6 +60,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 0,
markers: [],
},
cogview4: {
maxClip: 0,
markers: [],
},
};
/**