Merge branch 'main' into lstein/more-model-loading-fixes

This commit is contained in:
blessedcoolant
2023-07-07 00:32:22 +12:00
committed by GitHub
14 changed files with 429 additions and 75 deletions

View File

@@ -1,15 +1,16 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo } from 'react';
import { RefObject, memo } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
const { searchable = true, tooltip, inputRef, ...rest } = props;
const {
base50,
base100,
@@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
ref={inputRef}
searchable={searchable}
styles={() => ({
label: {

View File

@@ -0,0 +1,33 @@
import IAIIconButton from 'common/components/IAIIconButton';
import { memo } from 'react';
import { BiCode } from 'react-icons/bi';
type Props = {
onClick: () => void;
};
const AddEmbeddingButton = (props: Props) => {
const { onClick } = props;
return (
<IAIIconButton
size="sm"
aria-label="Add Embedding"
tooltip="Add Embedding"
icon={<BiCode />}
sx={{
p: 2,
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
}}
variant="link"
onClick={onClick}
/>
);
};
export default memo(AddEmbeddingButton);

View File

@@ -0,0 +1,151 @@
import {
Flex,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Text,
} from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import {
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
isOpen: boolean;
onClose: () => void;
};
const ParamEmbeddingPopover = (props: Props) => {
const { onSelect, isOpen, onClose, children } = props;
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null);
const data = useMemo(() => {
if (!embeddingQueryData) {
return [];
}
const data: EmbeddingSelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return;
data.push({
value: embedding.name,
label: embedding.name,
description: embedding.description,
});
});
return data;
}, [embeddingQueryData]);
const handleChange = useCallback(
(v: string[]) => {
if (v.length === 0) {
return;
}
onSelect(v[0]);
},
[onSelect]
);
return (
<Popover
initialFocusRef={inputRef}
isOpen={isOpen}
onClose={onClose}
placement="bottom"
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
>
<PopoverTrigger>{children}</PopoverTrigger>
<PopoverContent
sx={{
p: 0,
top: -1,
shadow: 'dark-lg',
borderColor: 'accent.300',
borderWidth: '2px',
borderStyle: 'solid',
_dark: { borderColor: 'accent.400' },
}}
>
<PopoverBody
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
>
{data.length === 0 ? (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
>
No Embeddings Loaded
</Text>
</Flex>
) : (
<IAIMantineMultiSelect
inputRef={inputRef}
placeholder={'Add Embedding'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No Matching Embeddings"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
)}
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
import {
Lora,
loraRemoved,
loraWeightChanged,
loraWeightReset,
} from '../store/loraSlice';
type Props = {
lora: Lora;
@@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
);
const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
dispatch(loraWeightReset(lora.id));
}, [dispatch, lora.id]);
const handleRemoveLora = useCallback(() => {

View File

@@ -1,4 +1,4 @@
import { Text } from '@chakra-ui/react';
import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
[dispatch, lorasQueryData?.entities]
);
if (lorasQueryData?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
No LoRAs Loaded
</Text>
</Flex>
);
}
return (
<IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}

View File

@@ -8,7 +8,7 @@ export type Lora = {
};
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
weight: 1,
weight: 0.75,
};
export type LoraState = {
@@ -38,9 +38,14 @@ export const loraSlice = createSlice({
const { id, weight } = action.payload;
state.loras[id].weight = weight;
},
loraWeightReset: (state, action: PayloadAction<string>) => {
const id = action.payload;
state.loras[id].weight = defaultLoRAConfig.weight;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
loraSlice.actions;
export default loraSlice.reducer;

View File

@@ -1,29 +1,107 @@
import { FormControl } from '@chakra-ui/react';
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
const ParamNegativeConditioning = () => {
const negativePrompt = useAppSelector(
(state: RootState) => state.generation.negativePrompt
);
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativePrompt(e.target.value));
},
[dispatch]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === '<') {
onOpen();
}
},
[onOpen]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = negativePrompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += negativePrompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativePrompt(newPrompt));
});
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, negativePrompt]
);
return (
<FormControl>
<IAITextarea
id="negativePrompt"
name="negativePrompt"
value={negativePrompt}
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
placeholder={t('parameters.negativePromptPlaceholder')}
fontSize="sm"
minH={16}
/>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="negativePrompt"
name="negativePrompt"
ref={promptRef}
value={negativePrompt}
placeholder={t('parameters.negativePromptPlaceholder')}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</FormControl>
);
};

View File

@@ -1,4 +1,4 @@
import { Box, FormControl } from '@chakra-ui/react';
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
@@ -11,12 +11,15 @@ import {
} from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
const promptInputSelector = createSelector(
[(state: RootState) => state.generation, activeTabNameSelector],
@@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const { t } = useTranslation();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositivePrompt(e.target.value));
};
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositivePrompt(e.target.value));
},
[dispatch]
);
useHotkeys(
'alt+a',
@@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
[]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositivePrompt(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
@@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (e.key === '<') {
onOpen();
}
},
[dispatch, activeTabName, isReady]
[isReady, dispatch, activeTabName, onOpen]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box>
<FormControl>
<IAITextarea
id="prompt"
name="prompt"
placeholder={t('parameters.positivePromptPlaceholder')}
value={prompt}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
ref={promptRef}
minH={32}
/>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder={t('parameters.positivePromptPlaceholder')}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={32}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 6,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};

View File

@@ -1,4 +1,3 @@
import { Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton, {
IAIIconButtonProps,
@@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
};
return (
<Tooltip label={t('common.pinOptionsPanel')}>
<IAIIconButton
{...props}
aria-label={t('common.pinOptionsPanel')}
onClick={handleClickPinOptionsPanel}
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
variant="ghost"
size="sm"
sx={{
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
...sx,
}}
/>
</Tooltip>
<IAIIconButton
{...props}
tooltip={t('common.pinOptionsPanel')}
aria-label={t('common.pinOptionsPanel')}
onClick={handleClickPinOptionsPanel}
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
variant="ghost"
size="sm"
sx={{
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
...sx,
}}
/>
);
};

View File

@@ -1,10 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
export const initialUIState: UIState = {
activeTab: 0,
@@ -19,6 +19,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true,
shouldHidePreview: false,
shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false,
favoriteSchedulers: [],
};
@@ -96,6 +97,9 @@ export const uiSlice = createSlice({
) => {
state.favoriteSchedulers = action.payload;
},
toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
},
},
extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => {
@@ -122,6 +126,7 @@ export const {
toggleGalleryPanel,
setShouldShowProgressInViewer,
favoriteSchedulersChanged,
toggleEmbeddingPicker,
} = uiSlice.actions;
export default uiSlice.reducer;

View File

@@ -27,5 +27,6 @@ export interface UIState {
shouldPinGallery: boolean;
shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[];
}

View File

@@ -1,18 +1,18 @@
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io, Socket } from 'socket.io-client';
import { Socket, io } from 'socket.io-client';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
import {
ClientToServerEvents,
ServerToClientEvents,
} from 'services/events/types';
import { socketSubscribed, socketUnsubscribed } from './actions';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
// import { OpenAPI } from 'services/api/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import { log } from 'app/logging/useLogger';
import { $authToken, $baseUrl } from 'services/api/client';
import { setEventListeners } from 'services/events/util/setEventListeners';
const socketioLog = log.child({ namespace: 'socketio' });
@@ -88,7 +88,7 @@ export const socketMiddleware = () => {
socketSubscribed({
sessionId: sessionId,
timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId,
boardId: getState().gallery.selectedBoardId,
})
);
}