feat(ui): add control loras to control adapter model options, add default settings for preprocessor in probe

This commit is contained in:
Mary Hipp
2024-12-12 21:07:16 -05:00
committed by Kent Keirsey
parent 246b59f148
commit da213e4638
11 changed files with 40 additions and 27 deletions

View File

@@ -26,7 +26,7 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
@@ -66,7 +66,7 @@ export const ControlLayerControlAdapter = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.

View File

@@ -4,22 +4,22 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => void;
};
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
const [modelConfigs, { isLoading }] = useControlLayerModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
if (!modelConfig) {
return;
}

View File

@@ -34,7 +34,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
@@ -436,7 +436,7 @@ export const canvasSlice = createSlice({
action: PayloadAction<
EntityIdentifierPayload<
{
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null;
},
'control_layer'
>

View File

@@ -2,7 +2,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { z } from 'zod';
@@ -454,7 +454,7 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
* filter for the model type will be used.
*/
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
if (!modelConfig) {
// No model
return null;

View File

@@ -296,6 +296,12 @@ const zT2IAdapterConfig = z.object({
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
export const zCanvasRasterLayerState = zCanvasEntityBase.extend({
type: z.literal('raster_layer'),
position: zCoordinate,
@@ -307,7 +313,7 @@ export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
type: z.literal('control_layer'),
withTransparencyEffect: z.boolean(),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig, zControlLoRAConfig]),
});
export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;