InvokeAI/invokeai/app/api_app.py
Lincoln Stein 01c67c5468 Fix (multiuser): Ask user to log back in when security token has expired (#9017)
* Initial plan

* Warn user when credentials have expired in multiuser mode

Agent-Logs-Url: https://github.com/lstein/InvokeAI/sessions/f0947cda-b15c-475d-b7f4-2d553bdf2cd6

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Address code review: avoid multiple localStorage reads in base query

Agent-Logs-Url: https://github.com/lstein/InvokeAI/sessions/f0947cda-b15c-475d-b7f4-2d553bdf2cd6

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* bugfix(multiuser): ask user to log back in when authentication token expires

* feat: sliding window session expiry with token refresh

Backend:
- SlidingWindowTokenMiddleware refreshes JWT on each mutating request
  (POST/PUT/PATCH/DELETE), returning a new token in X-Refreshed-Token
  response header. GET requests don't refresh (they're often background
  fetches that shouldn't reset the inactivity timer).
- CORS expose_headers updated to allow X-Refreshed-Token.

Frontend:
- dynamicBaseQuery picks up X-Refreshed-Token from responses and
  updates localStorage so subsequent requests use the fresh expiry.
- 401 handler only triggers sessionExpiredLogout when a token was
  actually sent (not for unauthenticated background requests).
- ProtectedRoute polls localStorage every 5s and listens for storage
  events to detect token removal (e.g. manual deletion, other tabs).

Result: session expires after TOKEN_EXPIRATION_NORMAL (1 day) of
inactivity, not a fixed time after login. Any user-initiated action
resets the clock.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

* fix: address review feedback on auth token handling

Bug fixes:
- ProtectedRoute: only treat 401 errors as session expiry, not
  transient 500/network errors that should not force logout
- Token refresh: use explicit remember_me claim in JWT instead of
  inferring from remaining lifetime, preventing silent downgrade of
  7-day tokens to 1-day when <24h remains
- TokenData: add remember_me field, set during login

Tests (6 new):
- Mutating requests (POST/PUT/DELETE) return X-Refreshed-Token
- GET requests do not return X-Refreshed-Token
- Unauthenticated requests do not return X-Refreshed-Token
- Remember-me token refreshes to 7-day duration even near expiry
- Normal token refreshes to 1-day duration
- remember_me claim preserved through refresh cycle

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-04-05 23:11:44 -04:00
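The sliding-window rule in this commit can be illustrated without a web stack. What follows is a minimal, standard-library-only sketch, not InvokeAI's implementation. `Claims` and `refreshed_expiry` are hypothetical names, and a token is represented by its already-verified claims instead of a signed JWT. The 1-day and 7-day durations are assumed from the commit message. The sketch captures the two behaviours the commit's tests describe. First, only successful mutating requests that carry a token slide the window. Second, the `remember_me` claim, not the remaining lifetime, chooses the new duration.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

# Assumed values, taken from the commit message (1 day normal, 7 days with remember-me).
TOKEN_EXPIRATION_NORMAL_DAYS = 1
TOKEN_EXPIRATION_REMEMBER_ME_DAYS = 7

# Methods treated as genuine user activity; GET and others never refresh.
MUTATING_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})


@dataclass(frozen=True)
class Claims:
    """Hypothetical stand-in for a verified token's payload."""

    user_id: str
    remember_me: bool
    expires_at: datetime


def refreshed_expiry(
    method: str, status_code: int, claims: Optional[Claims], now: datetime
) -> Optional[datetime]:
    """Return the expiry a refreshed token should carry, or None if no refresh applies."""
    if status_code >= 400 or method not in MUTATING_METHODS:
        return None
    # No token, or one that has already expired (treated as invalid, as a verifier checking `exp` would).
    if claims is None or claims.expires_at <= now:
        return None
    # The explicit claim decides the duration, so a 7-day token with only a few
    # hours left is renewed for 7 days rather than downgraded to 1 day.
    days = TOKEN_EXPIRATION_REMEMBER_ME_DAYS if claims.remember_me else TOKEN_EXPIRATION_NORMAL_DAYS
    return now + timedelta(days=days)


# Usage: a remember-me token with two hours left, refreshed by a POST.
now = datetime(2026, 4, 5, 12, 0, tzinfo=timezone.utc)
claims = Claims(user_id="alice", remember_me=True, expires_at=now + timedelta(hours=2))
print(refreshed_expiry("POST", 200, claims, now))  # 2026-04-12 12:00:00+00:00
print(refreshed_expiry("GET", 200, claims, now))  # None
```

Keeping the decision in a pure function like this makes the edge cases easy to unit-test, including near-expiry remember-me tokens, failed requests, and GETs. `SlidingWindowTokenMiddleware.dispatch` in this file performs the same method and status checks inline and delegates token validation to `verify_token`.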

224 lines
8.5 KiB
Python

import asyncio
import logging
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint

import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.api.routers import (
    app_info,
    auth,
    board_images,
    boards,
    client_state,
    download_queue,
    images,
    model_manager,
    model_relationships,
    recall_parameters,
    session_queue,
    style_presets,
    utilities,
    workflows,
)
from invokeai.app.api.sockets import SocketIO
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.logging import InvokeAILogger

app_config = get_config()
logger = InvokeAILogger.get_logger(config=app_config)

loop = asyncio.new_event_loop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Add startup event to load dependencies
    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)

    # Log the server address when it starts - in case the network log level is not high enough to see the startup log
    proto = "https" if app_config.ssl_certfile else "http"
    msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"

    # Logging this way ignores the logger's log level and _always_ logs the message
    record = logger.makeRecord(
        name=logger.name,
        level=logging.INFO,
        fn="",
        lno=0,
        msg=msg,
        args=(),
        exc_info=None,
    )
    logger.handle(record)

    yield

    # Shut down threads
    ApiDependencies.shutdown()


# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
    title="Invoke - Community Edition",
    docs_url=None,
    redoc_url=None,
    separate_input_output_schemas=False,
    lifespan=lifespan,
)


class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
    """Refresh the JWT on each successful, authenticated mutating request.

    When a POST, PUT, PATCH, or DELETE request carries a valid Bearer token and the
    response status is below 400, the response includes an X-Refreshed-Token header
    containing a new token with a fresh expiry. This implements sliding-window session
    expiry: the session expires only after a period of *inactivity*, not at a fixed
    time after login.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        response = await call_next(request)

        # Only refresh on mutating requests (POST/PUT/PATCH/DELETE), since these indicate
        # genuine user activity. GET requests are often background fetches (RTK Query
        # cache revalidation, refetch-on-focus, etc.) and should not reset the
        # inactivity timer.
        if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
            auth_header = request.headers.get("authorization", "")
            if auth_header.startswith("Bearer "):
                token = auth_header.removeprefix("Bearer ")
                try:
                    from datetime import timedelta

                    from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
                    from invokeai.app.services.auth.token_service import create_access_token, verify_token

                    token_data = verify_token(token)
                    if token_data is not None:
                        # Use the remember_me claim from the token to determine the
                        # correct refresh duration. This avoids the bug where a 7-day
                        # token with <24h remaining would be silently downgraded to 1 day.
                        if token_data.remember_me:
                            expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
                        else:
                            expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)
                        new_token = create_access_token(token_data, expires_delta)
                        response.headers["X-Refreshed-Token"] = new_token
                except Exception:
                    pass  # Don't fail the request if token refresh fails

        return response


class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
    """When a request is made to the root path with a query string, redirect to the root path without the query string.

    For example, to force a Gradio app to use dark mode, users may append `?__theme=dark` to the URL. Their browser may
    have this query string saved in history or a bookmark, so when the user navigates to `http://127.0.0.1:9090/`, the
    browser takes them to `http://127.0.0.1:9090/?__theme=dark`.

    This breaks the static file serving in the UI, so we redirect the user to the root path without the query string.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        if request.url.path == "/" and request.url.query:
            return RedirectResponse(url="/")

        response = await call_next(request)
        return response


# Add the middleware
app.add_middleware(RedirectRootWithQueryStringMiddleware)
app.add_middleware(SlidingWindowTokenMiddleware)

# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[local_handler],  # TODO: consider doing this in services to support different configurations
    middleware_id=event_handler_id,
)

socket_io = SocketIO(app)

app.add_middleware(
    CORSMiddleware,
    allow_origins=app_config.allow_origins,
    allow_credentials=app_config.allow_credentials,
    allow_methods=app_config.allow_methods,
    allow_headers=app_config.allow_headers,
    # Browsers hide non-safelisted response headers from JavaScript on cross-origin
    # requests unless they are listed here, so the refreshed token must be exposed.
    expose_headers=["X-Refreshed-Token"],
)

app.add_middleware(GZipMiddleware, minimum_size=1000)

# Include all routers
# Authentication router should be first so it's registered before protected routes
app.include_router(auth.auth_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.include_router(client_state.client_state_router, prefix="/api")
app.include_router(recall_parameters.recall_parameters_router, prefix="/api")

app.openapi = get_openapi_func(app)


@app.get("/docs", include_in_schema=False)
def overridden_swagger() -> HTMLResponse:
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
        title=f"{app.title} - Swagger UI",
        swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
    )


@app.get("/redoc", include_in_schema=False)
def overridden_redoc() -> HTMLResponse:
    return get_redoc_html(
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
        title=f"{app.title} - Redoc",
        redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
    )


web_root_path = Path(list(web_dir.__path__)[0])

if app_config.unsafe_disable_picklescan:
    logger.warning(
        "The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and "
        "loading models, which may allow malicious code to be executed. Use at your own risk."
    )

try:
    app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
    logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")

app.mount(
    "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
)  # docs favicon is in here
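The order of the `add_middleware` calls above determines the order in which each layer sees a request. In Starlette, on which FastAPI is built, `add_middleware` inserts each new middleware at the front of the application's list. The stack is then built so that the first entry in that list is outermost, which means the middleware added last runs first. The toy model below is a hypothetical illustration, not Starlette code. `ToyApp` and the name-appending handlers are invented for the example. It assumes the five calls in this file are the only user middleware, and it omits Starlette's own error-handling layers.

```python
from typing import Callable, List

# A "handler" here just appends its name to a trace so we can see call order.
Handler = Callable[[List[str]], None]


class ToyApp:
    """Hypothetical model of Starlette's middleware bookkeeping (not the real class)."""

    def __init__(self) -> None:
        self.user_middleware: List[str] = []

    def add_middleware(self, name: str) -> None:
        # Mirrors Starlette: new middleware goes to the front of the list.
        self.user_middleware.insert(0, name)

    def build(self, endpoint: Handler) -> Handler:
        # Wrap from the end of the list inward, so the list's first entry is outermost.
        app = endpoint
        for name in reversed(self.user_middleware):
            app = self._wrap(name, app)
        return app

    @staticmethod
    def _wrap(name: str, inner: Handler) -> Handler:
        def handler(trace: List[str]) -> None:
            trace.append(name)  # this layer sees the request first...
            inner(trace)  # ...then hands it to the next layer in

        return handler


def endpoint(trace: List[str]) -> None:
    trace.append("route")


app = ToyApp()
# Same order as the add_middleware calls in api_app.py.
for name in ["RedirectRoot", "SlidingWindowToken", "EventHandler", "CORS", "GZip"]:
    app.add_middleware(name)

trace: List[str] = []
app.build(endpoint)(trace)
print(trace)
# ['GZip', 'CORS', 'EventHandler', 'SlidingWindowToken', 'RedirectRoot', 'route']
```

Responses travel back through the same layers in reverse. As a result, the X-Refreshed-Token header set by `SlidingWindowTokenMiddleware` is already present when `CORSMiddleware` and `GZipMiddleware` process the outgoing response.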