feat(ui,api): support for HF tokens in UI, handle Unauthorized and Forbidden errors

This commit is contained in:
Mary Hipp
2024-10-25 15:52:47 -04:00
committed by psychedelicious
parent 6f0f53849b
commit bcb41399ca
11 changed files with 423 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import contextlib
import io
import pathlib
import shutil
@@ -10,6 +11,7 @@ from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
import huggingface_hub
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter
@@ -27,6 +29,7 @@ from invokeai.app.services.model_records import (
ModelRecordChanges,
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -923,3 +926,51 @@ async def get_stats() -> Optional[CacheStats]:
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"
UNKNOWN = "unknown"
class HFTokenHelper:
@classmethod
def get_status(cls) -> HFTokenStatus:
try:
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
# Valid token!
return HFTokenStatus.VALID
# No token set
return HFTokenStatus.INVALID
except Exception:
return HFTokenStatus.UNKNOWN
@classmethod
def set_token(cls, token: str) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
token_status = HFTokenHelper.get_status()
if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
HFTokenHelper.set_token(token)
token_status = HFTokenHelper.get_status()
if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status

View File

@@ -733,7 +733,17 @@
"huggingFacePlaceholder": "owner/model-name",
"huggingFaceRepoID": "HuggingFace Repo ID",
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
"hfToken": "HuggingFace Token",
"hfTokenLabel": "HuggingFace Token (Required for some models)",
"hfTokenHelperText": "A HF token is required to use some models. Click here to create or get your token.",
"hfTokenInvalid": "Invalid or Missing HF Token",
"hfForbidden": "You do not have access to this HF model",
"hfForbiddenErrorMessage": "We recommend visiting the repo page on HuggingFace.com. The owner may require acceptance of terms in order to download.",
"hfTokenInvalidErrorMessage": "Invalid or missing HuggingFace token.",
"hfTokenRequired": "You are trying to download a model that requires a valid HuggingFace Token.",
"hfTokenInvalidErrorMessage2": "Update it in the ",
"hfTokenUnableToVerify": "Unable to Verify HF Token",
"hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
"hfTokenSaved": "HF Token Saved",
"imageEncoderModelId": "Image Encoder Model ID",
"includesNModels": "Includes {{n}} models and their dependencies",
"installQueue": "Install Queue",

View File

@@ -0,0 +1,49 @@
import { Button, ExternalLink, Text, useToast } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
const FEATURE_ID = 'hfToken';
const TOAST_ID = 'hfForbidden';
/**
* Tracks whether or not the HF Login toast is showing
*/
export const $isHFForbiddenToastOpen = atom<{ isEnabled: boolean; source?: string }>({ isEnabled: false });
export const useHFForbiddenToast = () => {
const { t } = useTranslation();
const toast = useToast();
const isHFForbiddenToastOpen = useStore($isHFForbiddenToastOpen);
useEffect(() => {
if (!isHFForbiddenToastOpen.isEnabled) {
toast.close(TOAST_ID);
return
}
if (isHFForbiddenToastOpen.isEnabled) {
toast({
id: TOAST_ID,
title: t('modelManager.hfForbidden'),
description: (
<Text fontSize="md">
{t('modelManager.hfForbiddenErrorMessage')}
<ExternalLink
label={isHFForbiddenToastOpen.source || ''}
href={`https://huggingface.co/${isHFForbiddenToastOpen.source}`}
/>
</Text>
),
status: 'error',
isClosable: true,
duration: null,
onCloseComplete: () => $isHFForbiddenToastOpen.set({ isEnabled: false }),
});
}
}, [isHFForbiddenToastOpen, t]);
};

View File

@@ -0,0 +1,93 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery } from 'services/api/endpoints/models';
import type { S } from 'services/api/types';
const FEATURE_ID = 'hfToken';
const TOAST_ID = 'hfTokenLogin';
/**
* Tracks whether or not the HF Login toast is showing
*/
export const $isHFLoginToastOpen = atom<boolean>(false);
const getTitle = (token_status: S['HFTokenStatus']) => {
switch (token_status) {
case 'invalid':
return t('modelManager.hfTokenInvalid');
case 'unknown':
return t('modelManager.hfTokenUnableToVerify');
}
};
export const useHFLoginToast = () => {
const isEnabled = useFeatureStatus(FEATURE_ID);
const { data } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
const toast = useToast();
const isHFLoginToastOpen = useStore($isHFLoginToastOpen);
useEffect(() => {
if (!isHFLoginToastOpen) {
toast.close(TOAST_ID);
return
}
if (isHFLoginToastOpen && data) {
const title = getTitle(data);
toast({
id: TOAST_ID,
title,
description: <ToastDescription token_status={data} />,
status: 'error',
isClosable: true,
duration: null,
onCloseComplete: () => $isHFLoginToastOpen.set(false),
});
}
}, [isHFLoginToastOpen, data]);
};
type Props = {
token_status: S['HFTokenStatus'];
};
const ToastDescription = ({ token_status }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const toast = useToast();
const onClick = useCallback(() => {
dispatch(setActiveTab('models'));
toast.close(FEATURE_ID);
}, [dispatch, toast]);
if (token_status === 'invalid') {
return (
<Text fontSize="md">
{t('modelManager.hfTokenInvalidErrorMessage')} {t('modelManager.hfTokenRequired')}{' '}
{t('modelManager.hfTokenInvalidErrorMessage2')}
<Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
{t('modelManager.modelManager')}.
</Button>
</Text>
);
}
if (token_status === 'unknown') {
return (
<Text fontSize="md">
{t('modelManager.hfTokenUnableToErrorMessage')}{' '}
<Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
{t('modelManager.modelManager')}.
</Button>
</Text>
);
}
};

View File

@@ -0,0 +1,81 @@
import {
Button,
ExternalLink,
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Input,
useToast,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';
import { $isHFLoginToastOpen } from '../../../hooks/useHFLoginToast';
export const HFToken = () => {
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const [token, setToken] = useState('');
const [tokenStatus, setTokenStatus] = useState()
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [trigger, { isLoading }] = useSetHFTokenMutation();
const toast = useToast();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
if (res === 'valid') {
setToken('');
toast({
title: t('modelManager.hfTokenSaved'),
status: 'success',
duration: 3000,
});
$isHFLoginToastOpen.set(false)
}
});
}, [t, toast, token, trigger]);
const error = useMemo(() => {
if (!currentData || isLoading) {
return null;
}
if (token.length && currentData === 'invalid') {
return t('modelManager.hfTokenInvalidErrorMessage');
}
if (token.length && currentData === 'unknown') {
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
}
return null;
}, [currentData, isLoading, t, token.length]);
if (!currentData || currentData === 'valid') {
return null;
}
return (
<Flex borderRadius="base" w="full">
<FormControl isInvalid={Boolean(error)} orientation="vertical">
<FormLabel>{t('modelManager.hfTokenLabel')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
<FormHelperText>
<ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
</FormHelperText>
<FormErrorMessage>{error}</FormErrorMessage>
</FormControl>
</Flex>
);
};

View File

@@ -1,17 +1,22 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
import { HFToken } from './HFToken';
import { skipToken } from '@reduxjs/toolkit/query';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
export const HuggingFaceForm = memo(() => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModel();
@@ -41,7 +46,7 @@ export const HuggingFaceForm = memo(() => {
}, []);
return (
<Flex flexDir="column" height="100%" gap={3}>
<Flex flexDir="column" height="100%" gap={4}>
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
@@ -63,6 +68,7 @@ export const HuggingFaceForm = memo(() => {
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{currentData !== 'valid' && <HFToken />}
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);

View File

@@ -7,6 +7,8 @@ import { PiPlusBold } from 'react-icons/pi';
import ModelList from './ModelManagerPanel/ModelList';
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
import { useHFLoginToast } from '../hooks/useHFLoginToast';
import { useHFForbiddenToast } from '../hooks/useHFForbiddenToast';
export const ModelManager = memo(() => {
const { t } = useTranslation();
@@ -16,6 +18,9 @@ export const ModelManager = memo(() => {
}, [dispatch]);
const selectedModelKey = useAppSelector(selectSelectedModelKey);
useHFLoginToast();
useHFForbiddenToast()
return (
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">

View File

@@ -3,7 +3,7 @@ import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfig, GetHFTokenStatusResponse, SetHFTokenArg, SetHFTokenResponse } from 'services/api/types';
import type { ApiTagDescription } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
@@ -259,6 +259,22 @@ export const modelsApi = api.injectEndpoints({
query: () => buildModelsUrl('starter_models'),
providesTags: [{ type: 'ModelConfig', id: LIST_TAG }],
}),
getHFTokenStatus: build.query<GetHFTokenStatusResponse, void>({
query: () => buildModelsUrl('hf_login'),
providesTags: ['HFTokenStatus'],
}),
setHFToken: build.mutation<SetHFTokenResponse, SetHFTokenArg>({
query: (body) => ({ url: buildModelsUrl('hf_login'), method: 'POST', body }),
invalidatesTags: ['HFTokenStatus'],
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
try {
const { data } = await queryFulfilled;
dispatch(modelsApi.util.updateQueryData('getHFTokenStatus', undefined, () => data));
} catch {
// no-op
}
},
}),
}),
});
@@ -277,6 +293,8 @@ export const {
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
useGetStarterModelsQuery,
useGetHFTokenStatusQuery,
useSetHFTokenMutation
} = modelsApi;
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();

View File

@@ -344,6 +344,24 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v2/models/hf_login": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/** Get Hf Login Status */
get: operations["get_hf_login_status"];
put?: never;
/** Do Hf Login */
post: operations["do_hf_login"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/download_queue/": {
parameters: {
query?: never;
@@ -2157,6 +2175,14 @@ export type components = {
*/
image_names: string[];
};
/** Body_do_hf_login */
Body_do_hf_login: {
/**
* Token
* @description Hugging Face token to use for login
*/
token: string;
};
/** Body_download */
Body_download: {
/**
@@ -7322,6 +7348,11 @@ export type components = {
*/
type: "hf";
};
/**
* HFTokenStatus
* @enum {string}
*/
HFTokenStatus: "valid" | "invalid" | "unknown";
/** HTTPValidationError */
HTTPValidationError: {
/** Detail */
@@ -18274,6 +18305,59 @@ export interface operations {
};
};
};
get_hf_login_status: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HFTokenStatus"];
};
};
};
};
do_hf_login: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["Body_do_hf_login"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HFTokenStatus"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
list_downloads: {
parameters: {
query?: never;

View File

@@ -244,3 +244,13 @@ export type PostUploadAction =
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
export type StarterModel = S['StarterModel'];
export type GetHFTokenStatusResponse =
paths['/api/v2/models/hf_login']['get']['responses']['200']['content']['application/json'];
export type SetHFTokenResponse = NonNullable<
paths['/api/v2/models/hf_login']['post']['responses']['200']['content']['application/json']
>;
export type SetHFTokenArg = NonNullable<
paths['/api/v2/models/hf_login']['post']['requestBody']['content']['application/json']
>;

View File

@@ -22,6 +22,8 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events
import type { Socket } from 'socket.io-client';
import { $lastProgressEvent } from './stores';
import { $isHFLoginToastOpen } from '../../features/modelManagerV2/hooks/useHFLoginToast';
import { $isHFForbiddenToastOpen } from '../../features/modelManagerV2/hooks/useHFForbiddenToast';
const log = logger('events');
@@ -294,6 +296,15 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
const { id, error, error_type } = data;
const installs = selectModelInstalls(getState()).data;
if (error === "Unauthorized") {
$isHFLoginToastOpen.set(true)
}
if (error === "Forbidden") {
$isHFForbiddenToastOpen.set({isEnabled: true, source: data.source})
}
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));