WIP - model selection for LLaVA

This commit is contained in:
Billy
2025-03-11 15:57:03 +11:00
committed by psychedelicious
parent 9ed46f60b7
commit fc82775d7a
10 changed files with 136 additions and 6 deletions

View File

@@ -32,6 +32,8 @@ import {
isColorFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
isControlLoRAModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
@@ -105,6 +107,7 @@ import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInp
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
@@ -322,6 +325,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,55 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
import type { LlavaOnevisionConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
/** Props for this input: the LLaVA model field instance together with its template. */
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;

/**
 * Combobox input used to choose the model for a node's LLaVA model field.
 *
 * The selectable models come from `useLLaVAModels`; picking one dispatches
 * `fieldLLaVAModelValueChanged` for the given node and field.
 */
const LLaVAModelFieldInputComponent = ({ nodeId, field }: Props) => {
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useLLaVAModels();
  const fieldName = field.name;

  const handleModelSelected = useCallback(
    (selected: LlavaOnevisionConfig | null) => {
      // A null selection is ignored; only a concrete model choice is dispatched.
      if (selected) {
        dispatch(fieldLLaVAModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  const combobox = useGroupedModelCombobox({
    modelConfigs,
    onChange: handleModelSelected,
    selectedModel: field.value,
    isLoading,
  });

  return (
    <FormControl
      className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
      isInvalid={!combobox.value}
      isDisabled={combobox.options.length === 0}
    >
      <Combobox
        value={combobox.value}
        placeholder={combobox.placeholder}
        noOptionsMessage={combobox.noOptionsMessage}
        options={combobox.options}
        onChange={combobox.onChange}
      />
    </FormControl>
  );
};

export default memo(LLaVAModelFieldInputComponent);

View File

@@ -28,6 +28,7 @@ import type {
IntegerGeneratorFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
LLaVAModelFieldValue,
MainModelFieldValue,
ModelIdentifierFieldValue,
SchedulerFieldValue,
@@ -65,6 +66,7 @@ import {
zIntegerGeneratorFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zLLaVAModelFieldValue,
zMainModelFieldValue,
zModelIdentifierFieldValue,
zSchedulerFieldValue,
@@ -380,6 +382,9 @@ export const nodesSlice = createSlice({
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
fieldValueReducer(state, action, zLoRAModelFieldValue);
},
// Handles a value change for a LLaVA model field by delegating to the shared
// fieldValueReducer, passing zLLaVAModelFieldValue as the schema for the new value.
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
fieldValueReducer(state, action, zLLaVAModelFieldValue);
},
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
fieldValueReducer(state, action, zControlNetModelFieldValue);
},
@@ -509,6 +514,7 @@ export const {
fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldLLaVAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldIntegerValueChanged,
@@ -633,6 +639,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldLLaVAModelValueChanged,
fieldMainModelValueChanged,
fieldIntegerValueChanged,
fieldIntegerCollectionValueChanged,

View File

@@ -69,6 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
"llava_onevision",
'control_lora',
'controlnet',
't2i_adapter',

View File

@@ -189,6 +189,10 @@ const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
// Field type schema for LLaVA model fields: extends the base field type with the
// literal name 'LLaVAModelField' and an optional stateless `originalType`.
const zLLaVAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LLaVAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
originalType: zStatelessFieldType.optional(),
@@ -273,6 +277,7 @@ const zStatefulFieldType = z.union([
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zLLaVAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
@@ -309,6 +314,7 @@ const modelFieldTypeNames = [
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
zLLaVAModelFieldType.shape.name.value,
zControlNetModelFieldType.shape.name.value,
zIPAdapterModelFieldType.shape.name.value,
zT2IAdapterModelFieldType.shape.name.value,
@@ -891,6 +897,27 @@ export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFi
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
// #endregion
// #region LLaVAModelField
// Value schema: a model identifier that may be absent (undefined).
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
// Input instance schema: the base input instance with a LLaVA model value.
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLLaVAModelFieldValue,
});
// Input template schema: the base input template typed as a LLaVA model field,
// with an optional original type and a default that uses the value schema.
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLLaVAModelFieldType,
originalType: zFieldType.optional(),
default: zLLaVAModelFieldValue,
});
// Output template schema: the base output template typed as a LLaVA model field.
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLLaVAModelFieldType,
});
// Static types inferred from the schemas above.
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
// Type guards for LLaVA model field input instances and input templates; the
// template guard is built for the 'LLaVAModelField' type name.
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
// #endregion
// #region ControlNetModelField
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1739,6 +1766,7 @@ export const zStatefulFieldValue = z.union([
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zLLaVAModelFieldValue,
zControlNetModelFieldValue,
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
@@ -1785,6 +1813,7 @@ const zStatefulFieldInputInstance = z.union([
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zLLaVAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
@@ -1825,6 +1854,7 @@ const zStatefulFieldInputTemplate = z.union([
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zLLaVAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
@@ -1871,6 +1901,7 @@ const zStatefulFieldOutputTemplate = z.union([
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zLLaVAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate,

View File

@@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
IntegerField: 0,
IPAdapterModelField: undefined,
LoRAModelField: undefined,
LLaVAModelField: undefined,
ModelIdentifierField: undefined,
MainModelField: undefined,
SchedulerField: 'dpmpp_3m_k',

View File

@@ -24,6 +24,7 @@ import type {
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
LLaVAModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
@@ -448,6 +449,19 @@ const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<Control
return template;
};
/**
 * Builds the input template for a LLaVA model field.
 *
 * The result is the base field properties plus the given field type, with
 * `default` taken from the schema object. A null or missing schema default
 * becomes `undefined`.
 */
const buildLLaVAModelFieldInputTemplate: FieldInputTemplateBuilder<LLaVAModelFieldInputTemplate> = (args) => {
  const { schemaObject, baseField, fieldType } = args;
  return {
    ...baseField,
    type: fieldType,
    default: schemaObject.default ?? undefined,
  };
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -741,6 +755,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
IntegerField: buildIntegerFieldInputTemplate,
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
LoRAModelField: buildLoRAModelFieldInputTemplate,
LLaVAModelField: buildLLaVAModelFieldInputTemplate,
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,