mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add CogView4 to frontend.
This commit is contained in:
committed by
psychedelicious
parent
e1133bc53f
commit
f4e00ab261
@@ -15,6 +15,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
flux: 'gold',
|
||||
cogview4: 'orange',
|
||||
};
|
||||
|
||||
const ModelBaseBadge = ({ base }: Props) => {
|
||||
|
||||
@@ -29,6 +29,8 @@ import {
|
||||
isCLIPGEmbedModelFieldInputTemplate,
|
||||
isCLIPLEmbedModelFieldInputInstance,
|
||||
isCLIPLEmbedModelFieldInputTemplate,
|
||||
isCogView4MainModelFieldInputInstance,
|
||||
isCogView4MainModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlLoRAModelFieldInputInstance,
|
||||
@@ -106,6 +108,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
@@ -412,6 +415,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCogView4MainModelFieldInputTemplate(template)) {
|
||||
if (!isCogView4MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputTemplate(template)) {
|
||||
if (!isSDXLMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
CogView4MainModelFieldInputInstance,
|
||||
CogView4MainModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCogView4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
|
||||
|
||||
const CogView4MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCogView4Models();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && props.fieldTemplate.required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CogView4MainModelFieldInputComponent);
|
||||
@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux', 'cogview4']);
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']);
|
||||
export type MainModelBase = z.infer<typeof zMainModelBase>;
|
||||
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
|
||||
const zModelType = z.enum([
|
||||
|
||||
@@ -53,6 +53,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
MainModelField: 'teal.500',
|
||||
FluxMainModelField: 'teal.500',
|
||||
SD3MainModelField: 'teal.500',
|
||||
CogView4MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
SpandrelImageToImageModelField: 'teal.500',
|
||||
|
||||
@@ -176,6 +176,10 @@ const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SD3MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCogView4MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CogView4MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -276,6 +280,7 @@ const zStatefulFieldType = z.union([
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSD3MainModelFieldType,
|
||||
zCogView4MainModelFieldType,
|
||||
zFluxMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
@@ -313,6 +318,7 @@ const modelFieldTypeNames = [
|
||||
zMainModelFieldType.shape.name.value,
|
||||
zSDXLMainModelFieldType.shape.name.value,
|
||||
zSD3MainModelFieldType.shape.name.value,
|
||||
zCogView4MainModelFieldType.shape.name.value,
|
||||
zFluxMainModelFieldType.shape.name.value,
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
@@ -817,6 +823,26 @@ export const isSD3MainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SD3MainModelFieldInputTemplate>('SD3MainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region CogView4MainModelField
|
||||
const zCogView4MainModelFieldValue = zMainModelFieldValue;
|
||||
const zCogView4MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCogView4MainModelFieldValue,
|
||||
});
|
||||
const zCogView4MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCogView4MainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCogView4MainModelFieldValue,
|
||||
});
|
||||
const zCogView4MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zCogView4MainModelFieldType,
|
||||
});
|
||||
export type CogView4MainModelFieldInputInstance = z.infer<typeof zCogView4MainModelFieldInputInstance>;
|
||||
export type CogView4MainModelFieldInputTemplate = z.infer<typeof zCogView4MainModelFieldInputTemplate>;
|
||||
export const isCogView4MainModelFieldInputInstance = buildInstanceTypeGuard(zCogView4MainModelFieldInputInstance);
|
||||
export const isCogView4MainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<CogView4MainModelFieldInputTemplate>('CogView4MainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region FluxMainModelField
|
||||
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -1765,6 +1791,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zSDXLMainModelFieldValue,
|
||||
zFluxMainModelFieldValue,
|
||||
zSD3MainModelFieldValue,
|
||||
zCogView4MainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
@@ -1811,6 +1838,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zMainModelFieldInputInstance,
|
||||
zFluxMainModelFieldInputInstance,
|
||||
zSD3MainModelFieldInputInstance,
|
||||
zCogView4MainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
@@ -1852,6 +1880,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zMainModelFieldInputTemplate,
|
||||
zFluxMainModelFieldInputTemplate,
|
||||
zSD3MainModelFieldInputTemplate,
|
||||
zCogView4MainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
@@ -1899,6 +1928,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zMainModelFieldOutputTemplate,
|
||||
zFluxMainModelFieldOutputTemplate,
|
||||
zSD3MainModelFieldOutputTemplate,
|
||||
zCogView4MainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
|
||||
@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
SDXLMainModelField: undefined,
|
||||
FluxMainModelField: undefined,
|
||||
SD3MainModelField: undefined,
|
||||
CogView4MainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
|
||||
@@ -5,6 +5,7 @@ import type {
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
CLIPGEmbedModelFieldInputTemplate,
|
||||
CLIPLEmbedModelFieldInputTemplate,
|
||||
CogView4MainModelFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlLoRAModelFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
@@ -351,6 +352,20 @@ const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainMode
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCogView4MainModelFieldInputTemplate: FieldInputTemplateBuilder<CogView4MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CogView4MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -761,6 +776,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
|
||||
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
|
||||
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
|
||||
@@ -11,6 +11,7 @@ export const MODEL_TYPE_MAP = {
|
||||
sdxl: 'Stable Diffusion XL',
|
||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||
flux: 'FLUX',
|
||||
cogview4: 'CogView4',
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -24,6 +25,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
flux: 'FLUX',
|
||||
cogview4: 'CogView4',
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -58,6 +60,10 @@ export const CLIP_SKIP_MAP = {
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
},
|
||||
cogview4: {
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user