Compare commits
23 Commits
v4.2.8
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb5db32bb0 | ||
|
|
823c663e1b | ||
|
|
d40c9ff60a | ||
|
|
373b46867a | ||
|
|
dc66952491 | ||
|
|
1b80832b22 | ||
|
|
96b0450b20 | ||
|
|
45792cc152 | ||
|
|
f0baf880b5 | ||
|
|
a8a2fc106d | ||
|
|
d23ad1818d | ||
|
|
4181ab654b | ||
|
|
1c97360f9f | ||
|
|
74d6fceeb6 | ||
|
|
766ddc18dc | ||
|
|
e6ff7488a1 | ||
|
|
89a652cfcd | ||
|
|
b227b9059d | ||
|
|
3599a4a3e4 | ||
|
|
5dd619e137 | ||
|
|
7d447cbb88 | ||
|
|
3bbba7e4b1 | ||
|
|
b1845019fe |
2
.github/workflows/python-checks.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff==0.6.0
|
||||
run: pip install ruff
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
@@ -32,8 +31,6 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@@ -66,12 +63,7 @@ class ApiDependencies:
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(
|
||||
config: InvokeAIAppConfig,
|
||||
event_handler_id: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
logger: Logger = logger,
|
||||
) -> None:
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
|
||||
@@ -82,7 +74,6 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -93,7 +84,7 @@ class ApiDependencies:
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
@@ -118,8 +109,6 @@ class ApiDependencies:
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -145,8 +134,6 @@ class ApiDependencies:
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -218,8 +218,9 @@ async def get_image_workflow(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
@images_router.api_route(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -230,18 +231,6 @@ async def get_image_workflow(
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
@images_router.head(
|
||||
"/i/{image_name}/full",
|
||||
operation_id="get_image_full_head",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
@@ -253,7 +242,6 @@ async def get_image_full(
|
||||
content = f.read()
|
||||
response = Response(content, media_type="image/png")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@@ -1,274 +0,0 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
InvalidPresetImportDataError,
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
UnsupportedFileTypeError,
|
||||
parse_presets_from_file,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetFormData(BaseModel):
|
||||
name: str = Field(description="Preset name")
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
type: PresetType = Field(description="Preset type")
|
||||
|
||||
|
||||
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="get_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def get_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to get"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Gets a style preset"""
|
||||
try:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
||||
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
|
||||
except StylePresetNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Style preset not found")
|
||||
|
||||
|
||||
@style_presets_router.patch(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="update_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def update_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
style_preset_id: str = Path(description="The id of the style preset to update"),
|
||||
data: str = Form(description="The data of the style preset to update"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Updates a style preset"""
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
else:
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
||||
|
||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
||||
style_preset_id=style_preset_id, changes=changes
|
||||
)
|
||||
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.delete(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="delete_style_preset",
|
||||
)
|
||||
async def delete_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to delete"),
|
||||
) -> None:
|
||||
"""Deletes a style preset"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/",
|
||||
operation_id="create_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def create_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
data: str = Form(description="The data of the style preset to create"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Creates a style preset"""
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
|
||||
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
||||
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
|
||||
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/",
|
||||
operation_id="list_style_presets",
|
||||
responses={
|
||||
200: {"model": list[StylePresetRecordWithImage]},
|
||||
},
|
||||
)
|
||||
async def list_style_presets() -> list[StylePresetRecordWithImage]:
|
||||
"""Gets a page of style presets"""
|
||||
style_presets_with_image: list[StylePresetRecordWithImage] = []
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
|
||||
for preset in style_presets:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
|
||||
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
|
||||
style_presets_with_image.append(style_preset_with_image)
|
||||
|
||||
return style_presets_with_image
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}/image",
|
||||
operation_id="get_style_preset_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The style preset image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The style preset image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_style_preset_image(
|
||||
style_preset_id: str = Path(description="The id of the style preset image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=style_preset_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/export",
|
||||
operation_id="export_style_presets",
|
||||
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
|
||||
status_code=200,
|
||||
)
|
||||
async def export_style_presets():
|
||||
# Create an in-memory stream to store the CSV data
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write the header
|
||||
writer.writerow(["name", "prompt", "negative_prompt"])
|
||||
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
|
||||
|
||||
for preset in style_presets:
|
||||
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
|
||||
|
||||
csv_data = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return Response(
|
||||
content=csv_data,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
|
||||
)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/import",
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
try:
|
||||
style_presets = await parse_presets_from_file(file)
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
except InvalidPresetImportDataError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except UnsupportedFileTypeError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
@@ -30,7 +30,6 @@ from invokeai.app.api.routers import (
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
@@ -56,13 +55,11 @@ mimetypes.add_type("text/css", ".css")
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
@@ -109,7 +106,6 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -188,6 +184,8 @@ def invoke_api() -> None:
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
|
||||
@@ -21,8 +21,6 @@ from controlnet_aux import (
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -46,12 +44,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -593,14 +592,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -608,33 +600,28 @@ DEPTH_ANYTHING_MODELS = {
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
version="1.1.2",
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
default="small", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
135
invokeai/app/invocations/flux_text_encoder.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from optimum.quanto import qfloat8
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||
use_8bit: bool = InputField(
|
||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
||||
)
|
||||
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
||||
|
||||
# TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
|
||||
# compatible with other ConditioningOutputs.
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine the T5 max sequence length based on the model.
|
||||
if self.model == "flux-schnell":
|
||||
max_seq_len = 256
|
||||
# elif self.model == "flux-dev":
|
||||
# max_seq_len = 512
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {self.model}")
|
||||
|
||||
# Load the CLIP tokenizer.
|
||||
clip_tokenizer_path = flux_model_dir / "tokenizer"
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
# Load the T5 tokenizer.
|
||||
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
|
||||
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
||||
assert isinstance(t5_tokenizer, T5TokenizerFast)
|
||||
|
||||
clip_text_encoder_path = flux_model_dir / "text_encoder"
|
||||
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
|
||||
with (
|
||||
context.models.load_local_model(
|
||||
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
|
||||
) as clip_text_encoder,
|
||||
context.models.load_local_model(
|
||||
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
|
||||
) as t5_text_encoder,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
pipeline = FluxPipeline(
|
||||
scheduler=None,
|
||||
vae=None,
|
||||
text_encoder=clip_text_encoder,
|
||||
tokenizer=clip_tokenizer,
|
||||
text_encoder_2=t5_text_encoder,
|
||||
tokenizer_2=t5_tokenizer,
|
||||
transformer=None,
|
||||
)
|
||||
|
||||
# prompt_embeds: T5 embeddings
|
||||
# pooled_prompt_embeds: CLIP embeddings
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
prompt=self.positive_prompt,
|
||||
prompt_2=self.positive_prompt,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
max_sequence_length=max_seq_len,
|
||||
)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
|
||||
model = CLIPTextModel.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
return model
|
||||
|
||||
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
|
||||
if self.use_8bit:
|
||||
model_8bit_path = path / "quantized"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model exists, load it.
|
||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# something that we should be able to make much faster.
|
||||
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
|
||||
|
||||
# Access the underlying wrapped model.
|
||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# always returning a T5EncoderModel from this function.
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# TODO(ryand): dtype?
|
||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(model, T5EncoderModel)
|
||||
|
||||
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
q_model.save_pretrained(model_8bit_path)
|
||||
|
||||
# (See earlier comment about accessing the wrapped model.)
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||
|
||||
assert isinstance(model, T5EncoderModel)
|
||||
return model
|
||||
255
invokeai/app/invocations/flux_text_to_image.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from einops import rearrange, repeat
|
||||
from flux.model import Flux
|
||||
from flux.modules.autoencoder import AutoEncoder
|
||||
from flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||
from flux.util import configs as flux_configs
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers.models.auto import AutoModelForTextEncoding
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
TFluxModelKeys = Literal["flux-schnell"]
|
||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||
|
||||
|
||||
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
|
||||
base_class = FluxTransformer2DModel
|
||||
|
||||
|
||||
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
||||
auto_class = AutoModelForTextEncoding
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
title="FLUX Text to Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Text-to-image generation using a FLUX model."""
|
||||
|
||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
||||
default="raw", description="The type of quantization to use for the transformer model."
|
||||
)
|
||||
use_8bit: bool = InputField(
|
||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
||||
)
|
||||
positive_text_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(default=4, description="Number of diffusion steps.")
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
flux_transformer_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
|
||||
)
|
||||
flux_ae_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
|
||||
)
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(
|
||||
context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
|
||||
)
|
||||
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_transformer_path: Path,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
# Prepare input noise.
|
||||
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||
# CPU RNG?
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "schnell" in str(flux_transformer_path)
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with context.models.load_local_model(
|
||||
model_path=flux_transformer_path, loader=self._load_flux_transformer
|
||||
) as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_ids = img_ids.to(latent_img.device)
|
||||
|
||||
return img, img_ids
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_ae_path: Path,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||
img = vae.decode(latents)
|
||||
|
||||
img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
return img_pil
|
||||
|
||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
if self.quantization_type == "raw":
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params).to(inference_dtype)
|
||||
|
||||
state_dict = load_file(path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
elif self.quantization_type == "NF4":
|
||||
model_path = path.parent / "bnb_nf4.safetensors"
|
||||
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
|
||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||
# this on GPUs without bfloat16 support.
|
||||
state_dict = load_file(model_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
elif self.quantization_type == "llm_int8":
|
||||
raise NotImplementedError("LLM int8 quantization is not yet supported.")
|
||||
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
# with accelerate.init_empty_weights():
|
||||
# empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
# assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
# model_int8_path = path / "bnb_llm_int8"
|
||||
# assert model_int8_path.exists()
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
|
||||
# sd = load_file(model_int8_path / "model.safetensors")
|
||||
# model.load_state_dict(sd, strict=True, assign=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_flux_vae(path: Path) -> AutoEncoder:
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
ae_params = flux_configs["flux-schnell"].ae_params
|
||||
with accelerate.init_empty_weights():
|
||||
ae = AutoEncoder(ae_params)
|
||||
|
||||
state_dict = load_file(path)
|
||||
ae.load_state_dict(state_dict, strict=True, assign=True)
|
||||
return ae
|
||||
@@ -91,7 +91,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
style_presets_dir: Path to directory for style presets.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
@@ -154,7 +153,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
@@ -302,11 +300,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def style_presets_path(self) -> Path:
|
||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.style_presets_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
|
||||
@@ -1,44 +1,46 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = asyncio.Queue[EventBase | None]()
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
self._loop = loop
|
||||
|
||||
# We need to store a reference to the task so it doesn't get GC'd
|
||||
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.remove)
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
self._queue.put(None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
self._queue.put(event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = await self._queue.get()
|
||||
event = self._queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
|
||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
@@ -63,8 +61,6 @@ class InvocationServices:
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -89,5 +85,3 @@ class InvocationServices:
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
|
||||
@@ -16,7 +16,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -50,7 +49,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration14Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_style_presets(cursor)
|
||||
|
||||
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store style presets."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS style_presets (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
preset_data TEXT NOT NULL,
|
||||
type TEXT NOT NULL DEFAULT "user",
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS style_presets
|
||||
AFTER UPDATE
|
||||
ON style_presets FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_14() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 13 to 14..
|
||||
|
||||
This migration does the following:
|
||||
- Create the table used to store style presets.
|
||||
"""
|
||||
migration_14 = Migration(
|
||||
from_version=13,
|
||||
to_version=14,
|
||||
callback=Migration14Callback(),
|
||||
)
|
||||
|
||||
return migration_14
|
||||
|
Before Width: | Height: | Size: 98 KiB |
|
Before Width: | Height: | Size: 138 KiB |
|
Before Width: | Height: | Size: 122 KiB |
|
Before Width: | Height: | Size: 123 KiB |
|
Before Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 146 KiB |
|
Before Width: | Height: | Size: 119 KiB |
|
Before Width: | Height: | Size: 117 KiB |
|
Before Width: | Height: | Size: 110 KiB |
|
Before Width: | Height: | Size: 46 KiB |
|
Before Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 156 KiB |
|
Before Width: | Height: | Size: 141 KiB |
|
Before Width: | Height: | Size: 96 KiB |
|
Before Width: | Height: | Size: 91 KiB |
|
Before Width: | Height: | Size: 88 KiB |
|
Before Width: | Height: | Size: 107 KiB |
|
Before Width: | Height: | Size: 132 KiB |
@@ -1,33 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class StylePresetImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
"""Retrieves a style preset image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
"""Gets the internal path to a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
"""Gets the URL to fetch a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
"""Saves a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset image."""
|
||||
pass
|
||||
@@ -1,19 +0,0 @@
|
||||
class StylePresetImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not deleted"):
|
||||
super().__init__(message)
|
||||
@@ -1,88 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import (
|
||||
StylePresetImageFileDeleteException,
|
||||
StylePresetImageFileNotFoundException,
|
||||
StylePresetImageFileSaveException,
|
||||
)
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
|
||||
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, style_preset_images_folder: Path):
|
||||
self._style_preset_images_folder = style_preset_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileSaveException from e
|
||||
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
|
||||
if style_preset.type is PresetType.Default:
|
||||
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
|
||||
path = default_images_dir / (style_preset.name + ".png")
|
||||
else:
|
||||
path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
path = self.get_path(style_preset_id)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
|
||||
|
||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
||||
url += f"?{uuid_string()}"
|
||||
|
||||
return url
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise StylePresetImageFileNotFoundException
|
||||
|
||||
path.unlink()
|
||||
|
||||
except StylePresetImageFileNotFoundException as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)
|
||||
@@ -1,146 +0,0 @@
|
||||
[
|
||||
{
|
||||
"name": "Photography (General)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Studio Lighting)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Landscape)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Portrait)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Black and White)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
|
||||
"negative_prompt": "painting, digital art. sketch, colour+"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Architectural Visualization",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Fantasy)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Sci-Fi)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Character)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Painterly)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
|
||||
"negative_prompt": "photo. smooth. border. frame"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Environment Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
|
||||
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Interior Design (Visualization)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
|
||||
"negative_prompt": "photo, distorted. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Product Rendering",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
|
||||
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Sketch",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
|
||||
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Line Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
|
||||
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Anime",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
|
||||
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Illustration",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
|
||||
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vehicles",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
|
||||
"negative_prompt": "sketch. digital art. greyscale. painting"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -1,42 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetRecordsStorageBase(ABC):
|
||||
"""Base class for style preset storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Get style preset by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
"""Creates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
"""Creates many style presets."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
"""Updates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -1,139 +0,0 @@
|
||||
import codecs
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import UploadFile
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class StylePresetNotFoundError(Exception):
|
||||
"""Raised when a style preset is not found"""
|
||||
|
||||
|
||||
class PresetData(BaseModel, extra="forbid"):
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
|
||||
|
||||
PresetDataValidator = TypeAdapter(PresetData)
|
||||
|
||||
|
||||
class PresetType(str, Enum, metaclass=MetaEnum):
|
||||
User = "user"
|
||||
Default = "default"
|
||||
Project = "project"
|
||||
|
||||
|
||||
class StylePresetChanges(BaseModel, extra="forbid"):
|
||||
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
||||
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
||||
type: Optional[PresetType] = Field(description="The updated type of the style preset")
|
||||
|
||||
|
||||
class StylePresetWithoutId(BaseModel):
|
||||
name: str = Field(description="The name of the style preset.")
|
||||
preset_data: PresetData = Field(description="The preset data")
|
||||
type: PresetType = Field(description="The type of style preset")
|
||||
|
||||
|
||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||
id: str = Field(description="The style preset ID.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
||||
return StylePresetRecordDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
||||
|
||||
|
||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||
image: Optional[str] = Field(description="The path for image")
|
||||
|
||||
|
||||
class StylePresetImportRow(BaseModel):
|
||||
name: str = Field(min_length=1, description="The name of the preset.")
|
||||
positive_prompt: str = Field(
|
||||
default="",
|
||||
description="The positive prompt for the preset.",
|
||||
validation_alias=AliasChoices("positive_prompt", "prompt"),
|
||||
)
|
||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
||||
|
||||
|
||||
StylePresetImportList = list[StylePresetImportRow]
|
||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(ValueError):
|
||||
"""Raised when an unsupported file type is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPresetImportDataError(ValueError):
|
||||
"""Raised when invalid preset import data is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
||||
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
||||
|
||||
If CSV, the file must have the following columns:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
If JSON, the file must be a list of objects with the following keys:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
Args:
|
||||
file (UploadFile): The file to parse.
|
||||
|
||||
Returns:
|
||||
list[StylePresetWithoutId]: The parsed style presets.
|
||||
|
||||
Raises:
|
||||
UnsupportedFileTypeError: If the file type is not supported.
|
||||
InvalidPresetImportDataError: If the data in the file is invalid.
|
||||
"""
|
||||
if file.content_type not in ["text/csv", "application/json"]:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if file.content_type == "text/csv":
|
||||
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
||||
data = list(csv_reader)
|
||||
else: # file.content_type == "application/json":
|
||||
json_data = await file.read()
|
||||
data = json.loads(json_data)
|
||||
|
||||
try:
|
||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
for imported in imported_presets:
|
||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
except pydantic.ValidationError as e:
|
||||
if file.content_type == "text/csv":
|
||||
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
else: # file.content_type == "application/json":
|
||||
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
raise InvalidPresetImportDataError(msg) from e
|
||||
finally:
|
||||
file.file.close()
|
||||
|
||||
return style_presets
|
||||
@@ -1,215 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_style_presets()
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET name = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.name, style_preset_id),
|
||||
)
|
||||
|
||||
# Change the preset data for a style preset
|
||||
if changes.preset_data is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET preset_data = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
self._cursor.execute(main_query, (type,))
|
||||
else:
|
||||
self._cursor.execute(main_query)
|
||||
|
||||
rows = self._cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Next, parse and create the default style presets
|
||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
for preset in presets:
|
||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
||||
self.create(style_preset)
|
||||
@@ -13,8 +13,3 @@ class UrlServiceBase(ABC):
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
"""Gets the URL for a style preset image"""
|
||||
pass
|
||||
|
||||
@@ -19,6 +19,3 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_openapi_func(
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": dict(sorted(invocation_output_map_properties.items())),
|
||||
"properties": invocation_output_map_properties,
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
|
||||
517
invokeai/backend/bnb.py
Normal file
@@ -0,0 +1,517 @@
|
||||
from typing import Any, Optional, Set, Type
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# The utils in this file take ideas from
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# Patterns:
|
||||
# - Quantize:
|
||||
# - Initialize model on meta device
|
||||
# - Replace layers
|
||||
# - Load state_dict to cpu
|
||||
# - Load state_dict into model
|
||||
# - Quantize on GPU
|
||||
# - Extract state_dict
|
||||
# - Save
|
||||
|
||||
# - Load:
|
||||
# - Initialize model on meta device
|
||||
# - Replace layers
|
||||
# - Load state_dict to cpu
|
||||
# - Load state_dict into model on cpu
|
||||
# - Move to GPU
|
||||
|
||||
|
||||
# class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
# """Overrides `bnb.nn.Int8Params` to add the following functionality:
|
||||
# - Make it possible to load a quantized state dict without putting the weight on a "cuda" device.
|
||||
# """
|
||||
|
||||
# def quantize(self, device: Optional[torch.device] = None):
|
||||
# device = device or torch.device("cuda")
|
||||
# if device.type != "cuda":
|
||||
# raise RuntimeError(f"Int8Params quantization is only supported on CUDA devices ({device=}).")
|
||||
|
||||
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
|
||||
# B = self.data.contiguous().half().cuda(device)
|
||||
# if self.has_fp16_weights:
|
||||
# self.data = B
|
||||
# else:
|
||||
# # we store the 8-bit rows-major weight
|
||||
# # we convert this weight to the turning/ampere weight during the first inference pass
|
||||
# CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
# del CBt
|
||||
# del SCBt
|
||||
# self.data = CB
|
||||
# self.CB = CB
|
||||
# self.SCB = SCB
|
||||
|
||||
|
||||
class Invoke2Linear8bitLt(torch.nn.Linear):
|
||||
"""This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_features: int,
|
||||
output_features: int,
|
||||
bias=True,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize Linear8bitLt class.
|
||||
|
||||
Args:
|
||||
input_features (`int`):
|
||||
Number of input features of the linear layer.
|
||||
output_features (`int`):
|
||||
Number of output features of the linear layer.
|
||||
bias (`bool`, defaults to `True`):
|
||||
Whether the linear class uses the bias term as well.
|
||||
"""
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
self.state.memory_efficient_backward = memory_efficient_backward
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
scb_name = "SCB"
|
||||
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, scb_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, scb_name)
|
||||
# case 3: SCB is in self.state, weight layout reordered after first forward()
|
||||
layout_reordered = self.state.CxB is not None
|
||||
|
||||
key_name = prefix + f"{scb_name}"
|
||||
format_name = prefix + "weight_format"
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if param_from_weight is not None:
|
||||
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
|
||||
elif param_from_state is not None and not layout_reordered:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
|
||||
elif param_from_state is not None:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
weights_format = self.state.formatB
|
||||
# At this point `weights_format` is an str
|
||||
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
|
||||
raise ValueError(f"Unrecognized weights format {weights_format}")
|
||||
|
||||
weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
|
||||
|
||||
destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix) :]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError(
|
||||
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()",
|
||||
)
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
self.state.is_training = self.training
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
"""Wraps `bnb.nn.Linear8bitLt` and adds the following functionality:
|
||||
- enables instantiation directly on the device
|
||||
- re-quantizaton when loading the state dict
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, device: Optional[torch.device] = None, threshold: float = 6.0, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, device=device, threshold=threshold, **kwargs)
|
||||
# If the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
|
||||
# filling the device memory with float32 weights which could lead to OOM
|
||||
# if torch.tensor(0, device=device).device.type == "cuda":
|
||||
# self.quantize_()
|
||||
# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
|
||||
# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix) :]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError(
|
||||
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()",
|
||||
)
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
|
||||
"""Inplace quantize."""
|
||||
if weight is None:
|
||||
weight = self.weight.data
|
||||
if weight.data.dtype == torch.int8:
|
||||
# already quantized
|
||||
return
|
||||
assert isinstance(self.weight, bnb.nn.Int8Params)
|
||||
self.weight = self.quantize(self.weight, weight, device)
|
||||
|
||||
@staticmethod
|
||||
def quantize(
|
||||
int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
|
||||
) -> bnb.nn.Int8Params:
|
||||
device = device or torch.device("cuda")
|
||||
if device.type != "cuda":
|
||||
raise RuntimeError(f"Unexpected device type: {device.type}")
|
||||
# https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
|
||||
B = weight.contiguous().to(device=device, dtype=torch.float16)
|
||||
if int8params.has_fp16_weights:
|
||||
int8params.data = B
|
||||
else:
|
||||
CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
int8params.data = CB
|
||||
int8params.CB = CB
|
||||
int8params.SCB = SCB
|
||||
return int8params
|
||||
|
||||
|
||||
# class _Linear4bit(bnb.nn.Linear4bit):
|
||||
# """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
|
||||
# state dict, meta-device initialization, and materialization."""
|
||||
|
||||
# def __init__(self, *args: Any, device: Optional[torch.device] = None, **kwargs: Any) -> None:
|
||||
# super().__init__(*args, device=device, **kwargs)
|
||||
# self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type]
|
||||
# self.bias = cast(Optional[torch.nn.Parameter], self.bias) # type: ignore[has-type]
|
||||
# # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
|
||||
# # filling the device memory with float32 weights which could lead to OOM
|
||||
# if torch.tensor(0, device=device).device.type == "cuda":
|
||||
# self.quantize_()
|
||||
# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
|
||||
# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
|
||||
|
||||
# def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
|
||||
# """Inplace quantize."""
|
||||
# if weight is None:
|
||||
# weight = self.weight.data
|
||||
# if weight.data.dtype == torch.uint8:
|
||||
# # already quantized
|
||||
# return
|
||||
# assert isinstance(self.weight, bnb.nn.Params4bit)
|
||||
# self.weight = self.quantize(self.weight, weight, device)
|
||||
|
||||
# @staticmethod
|
||||
# def quantize(
|
||||
# params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
|
||||
# ) -> bnb.nn.Params4bit:
|
||||
# device = device or torch.device("cuda")
|
||||
# if device.type != "cuda":
|
||||
# raise RuntimeError(f"Unexpected device type: {device.type}")
|
||||
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
|
||||
# w = weight.contiguous().to(device=device, dtype=torch.half)
|
||||
# w_4bit, quant_state = bnb.functional.quantize_4bit(
|
||||
# w,
|
||||
# blocksize=params4bit.blocksize,
|
||||
# compress_statistics=params4bit.compress_statistics,
|
||||
# quant_type=params4bit.quant_type,
|
||||
# )
|
||||
# return _replace_param(params4bit, w_4bit, quant_state)
|
||||
|
||||
# def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
|
||||
# if self.weight.dtype == torch.uint8: # was quantized
|
||||
# # cannot init the quantized params directly
|
||||
# weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
|
||||
# else:
|
||||
# weight = torch.empty_like(self.weight.data, device=device)
|
||||
# device = torch.device(device)
|
||||
# if device.type == "cuda": # re-quantize
|
||||
# self.quantize_(weight, device)
|
||||
# else:
|
||||
# self.weight = _replace_param(self.weight, weight)
|
||||
# if self.bias is not None:
|
||||
# self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
|
||||
# return self
|
||||
|
||||
|
||||
def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[str]):
|
||||
linear_cls = InvokeLinear8bitLt
|
||||
_convert_linear_layers(model, linear_cls, ignore_modules)
|
||||
|
||||
# TODO(ryand): Is this necessary?
|
||||
# set the compute dtype if necessary
|
||||
# for m in model.modules():
|
||||
# if isinstance(m, bnb.nn.Linear4bit):
|
||||
# m.compute_dtype = self.dtype
|
||||
# m.compute_type_is_set = False
|
||||
|
||||
|
||||
# class BitsandbytesPrecision(Precision):
|
||||
# """Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__.
|
||||
|
||||
# .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
||||
|
||||
# .. note::
|
||||
# The optimizer is not automatically replaced with ``bitsandbytes.optim.Adam8bit`` or equivalent 8-bit optimizers.
|
||||
|
||||
# Args:
|
||||
# mode: The quantization mode to use.
|
||||
# dtype: The compute dtype to use.
|
||||
# ignore_modules: The submodules whose Linear layers should not be replaced, for example. ``{"lm_head"}``.
|
||||
# This might be desirable for numerical stability. The string will be checked in as a prefix, so a value like
|
||||
# "transformer.blocks" will ignore all linear layers in all of the transformer blocks.
|
||||
# """
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"],
|
||||
# dtype: Optional[torch.dtype] = None,
|
||||
# ignore_modules: Optional[Set[str]] = None,
|
||||
# ) -> None:
|
||||
# if dtype is None:
|
||||
# # try to be smart about the default selection
|
||||
# if mode.startswith("int8"):
|
||||
# dtype = torch.float16
|
||||
# else:
|
||||
# dtype = (
|
||||
# torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
||||
# )
|
||||
# if mode.startswith("int8") and dtype is not torch.float16:
|
||||
# # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
|
||||
# raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")
|
||||
|
||||
# globals_ = globals()
|
||||
# mode_to_cls = {
|
||||
# "nf4": globals_["_NF4Linear"],
|
||||
# "nf4-dq": globals_["_NF4DQLinear"],
|
||||
# "fp4": globals_["_FP4Linear"],
|
||||
# "fp4-dq": globals_["_FP4DQLinear"],
|
||||
# "int8-training": globals_["_Linear8bitLt"],
|
||||
# "int8": globals_["_Int8LinearInference"],
|
||||
# }
|
||||
# self._linear_cls = mode_to_cls[mode]
|
||||
# self.dtype = dtype
|
||||
# self.ignore_modules = ignore_modules or set()
|
||||
|
||||
# @override
|
||||
# def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
# # avoid naive users thinking they quantized their model
|
||||
# if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
|
||||
# raise TypeError(
|
||||
# "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
|
||||
# " won't work for your model."
|
||||
# )
|
||||
|
||||
# # convert modules if they haven't been converted already
|
||||
# if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
|
||||
# # this will not quantize the model but only replace the layer classes
|
||||
# _convert_layers(module, self._linear_cls, self.ignore_modules)
|
||||
|
||||
# # set the compute dtype if necessary
|
||||
# for m in module.modules():
|
||||
# if isinstance(m, bnb.nn.Linear4bit):
|
||||
# m.compute_dtype = self.dtype
|
||||
# m.compute_type_is_set = False
|
||||
# return module
|
||||
|
||||
|
||||
# def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
|
||||
# # There is only one key that ends with `*.weight`, the other one is the bias
|
||||
# weight_key = next((name for name in state_dict if name.endswith("weight")), None)
|
||||
# if weight_key is None:
|
||||
# return
|
||||
# # Load the weight from the state dict and re-quantize it
|
||||
# weight = state_dict.pop(weight_key)
|
||||
# quantize_fn(weight)
|
||||
|
||||
|
||||
# def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
|
||||
# # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
|
||||
# # positive
|
||||
# for key in reversed(incompatible_keys.missing_keys):
|
||||
# if key.endswith("weight"):
|
||||
# incompatible_keys.missing_keys.remove(key)
|
||||
|
||||
|
||||
def _convert_linear_layers(
|
||||
module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = ""
|
||||
) -> None:
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
# since we are going to copy over the child's data, the device doesn't matter. I chose CPU
|
||||
# to avoid spiking CUDA memory even though initialization is slower
|
||||
# 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
|
||||
_Linear4bit = globals()["_Linear4bit"]
|
||||
device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
|
||||
replacement = linear_cls(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
device=device,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
|
||||
state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
|
||||
|
||||
|
||||
# def _replace_linear_layers(
|
||||
# model: torch.nn.Module,
|
||||
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
||||
# modules_to_not_convert: set[str],
|
||||
# current_key_name: str | None = None,
|
||||
# ):
|
||||
# has_been_replaced = False
|
||||
# for name, module in model.named_children():
|
||||
# if current_key_name is None:
|
||||
# current_key_name = []
|
||||
# current_key_name.append(name)
|
||||
# if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
|
||||
# # Check if the current key is not in the `modules_to_not_convert`
|
||||
# current_key_name_str = ".".join(current_key_name)
|
||||
# proceed = True
|
||||
# for key in modules_to_not_convert:
|
||||
# if (
|
||||
# (key in current_key_name_str) and (key + "." in current_key_name_str)
|
||||
# ) or key == current_key_name_str:
|
||||
# proceed = False
|
||||
# break
|
||||
# if proceed:
|
||||
# # Load bnb module with empty weight and replace ``nn.Linear` module
|
||||
# if bnb_quantization_config.load_in_8bit:
|
||||
# bnb_module = bnb.nn.Linear8bitLt(
|
||||
# module.in_features,
|
||||
# module.out_features,
|
||||
# module.bias is not None,
|
||||
# has_fp16_weights=False,
|
||||
# threshold=bnb_quantization_config.llm_int8_threshold,
|
||||
# )
|
||||
# elif bnb_quantization_config.load_in_4bit:
|
||||
# bnb_module = bnb.nn.Linear4bit(
|
||||
# module.in_features,
|
||||
# module.out_features,
|
||||
# module.bias is not None,
|
||||
# bnb_quantization_config.bnb_4bit_compute_dtype,
|
||||
# compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
|
||||
# quant_type=bnb_quantization_config.bnb_4bit_quant_type,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("load_in_8bit and load_in_4bit can't be both False")
|
||||
# bnb_module.weight.data = module.weight.data
|
||||
# if module.bias is not None:
|
||||
# bnb_module.bias.data = module.bias.data
|
||||
# bnb_module.requires_grad_(False)
|
||||
# setattr(model, name, bnb_module)
|
||||
# has_been_replaced = True
|
||||
# if len(list(module.children())) > 0:
|
||||
# _, _has_been_replaced = _replace_with_bnb_layers(
|
||||
# module, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
# )
|
||||
# has_been_replaced = has_been_replaced | _has_been_replaced
|
||||
# # Remove the last key for recursion
|
||||
# current_key_name.pop(-1)
|
||||
# return model, has_been_replaced
|
||||
90
invokeai/backend/image_util/depth_anything/__init__.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
width=518,
|
||||
height=518,
|
||||
resize_target=False,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=14,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||
return image
|
||||
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
np_image = np_image[:, :, ::-1] / 255.0
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class DepthAnythingPipeline(RawModel):
|
||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||
for Invoke's Model Management System"""
|
||||
|
||||
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
||||
self._pipeline = pipeline
|
||||
|
||||
def generate_depth(self, image: Image.Image) -> Image.Image:
|
||||
depth_map = self._pipeline(image)["depth"]
|
||||
assert isinstance(depth_map, Image.Image)
|
||||
return depth_map
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
145
invokeai/backend/image_util/depth_anything/model/blocks.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape * 2
|
||||
out_shape3 = out_shape * 4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape * 8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module."""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
if self.bn:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block."""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand:
|
||||
out_features = features // 2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
self.size = size
|
||||
|
||||
def forward(self, *xs, size=None):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
if (size is None) and (self.size is None):
|
||||
modifier = {"scale_factor": 2}
|
||||
elif size is None:
|
||||
modifier = {"size": self.size}
|
||||
else:
|
||||
modifier = {"size": size}
|
||||
|
||||
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
183
invokeai/backend/image_util/depth_anything/model/dpt.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
|
||||
|
||||
torchhub_path = Path(__file__).parent.parent / "torchhub"
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
|
||||
return FeatureFusionBlock(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
|
||||
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||
super(DPTHead, self).__init__()
|
||||
|
||||
self.nclass = nclass
|
||||
self.use_clstoken = use_clstoken
|
||||
|
||||
self.projects = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
for out_channel in out_channels
|
||||
]
|
||||
)
|
||||
|
||||
self.resize_layers = nn.ModuleList(
|
||||
[
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
||||
),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if use_clstoken:
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
||||
|
||||
self.scratch = _make_scratch(
|
||||
out_channels,
|
||||
features,
|
||||
groups=1,
|
||||
expand=False,
|
||||
)
|
||||
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
if nclass > 1:
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.scratch.output_conv1 = nn.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, out_features, patch_h, patch_w):
|
||||
out = []
|
||||
for i, x in enumerate(out_features):
|
||||
if self.use_clstoken:
|
||||
x, cls_token = x[0], x[1]
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
else:
|
||||
x = x[0]
|
||||
|
||||
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
||||
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
|
||||
out.append(x)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = out
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv1(path_1)
|
||||
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
||||
out = self.scratch.output_conv2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPT_DINOv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
features,
|
||||
out_channels,
|
||||
encoder="vitl",
|
||||
use_bn=False,
|
||||
use_clstoken=False,
|
||||
):
|
||||
super(DPT_DINOv2, self).__init__()
|
||||
|
||||
assert encoder in ["vits", "vitb", "vitl"]
|
||||
|
||||
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||
# if use_local:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# torchhub_path / "facebookresearch_dinov2_main",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# source="local",
|
||||
# pretrained=False,
|
||||
# )
|
||||
# else:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# "facebookresearch/dinov2",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# )
|
||||
|
||||
self.pretrained = torch.hub.load(
|
||||
"facebookresearch/dinov2",
|
||||
"dinov2_{:}14".format(encoder),
|
||||
)
|
||||
|
||||
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||
|
||||
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
||||
|
||||
patch_h, patch_w = h // 14, w // 14
|
||||
|
||||
depth = self.depth_head(features, patch_h, patch_w)
|
||||
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
||||
depth = F.relu(depth)
|
||||
|
||||
return depth.squeeze(1)
|
||||
227
invokeai/backend/image_util/depth_anything/utilities/util.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
|
||||
|
||||
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
|
||||
than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
# sample["semseg_mask"] = cv2.resize(
|
||||
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
# )
|
||||
sample["semseg_mask"] = F.interpolate(
|
||||
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
|
||||
).numpy()[0, 0]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
# sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
# print(sample['image'].shape, sample['depth'].shape)
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std."""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
||||
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
||||
|
||||
return sample
|
||||
129
invokeai/backend/load_flux_model.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.model_loading_utils import load_state_dict
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
_get_checkpoint_shard_files,
|
||||
is_accelerate_available,
|
||||
)
|
||||
from optimum.quanto import qfloat8
|
||||
from optimum.quanto.models import QuantizedDiffusersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
|
||||
base_class = FluxTransformer2DModel
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
if cls.base_class is None:
|
||||
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
|
||||
# Look for original model config file.
|
||||
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
|
||||
if not os.path.exists(model_config_path):
|
||||
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
|
||||
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
original_model_cls_name = json.load(f)["_class_name"]
|
||||
configured_cls_name = cls.base_class.__name__
|
||||
if configured_cls_name != original_model_cls_name:
|
||||
raise ValueError(
|
||||
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||
)
|
||||
|
||||
# Create an empty model
|
||||
config = cls.base_class.load_config(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.base_class.from_config(config)
|
||||
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
|
||||
|
||||
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
model_8bit_path = path / "quantized"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model exists, load it.
|
||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# something that we should be able to make much faster.
|
||||
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
|
||||
|
||||
# Access the underlying wrapped model.
|
||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# always returning a FluxTransformer2DModel from this function.
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
|
||||
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
|
||||
# here.
|
||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
|
||||
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
q_model.save_pretrained(model_8bit_path)
|
||||
|
||||
# (See earlier comment about accessing the wrapped model.)
|
||||
model = q_model._wrapped
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time()
|
||||
model = load_flux_transformer(
|
||||
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
|
||||
)
|
||||
print(f"Time to load: {time.time() - start}s")
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
124
invokeai/backend/load_flux_model_bnb_llm_int8_old.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
|
||||
from accelerate.utils.bnb import get_keys_to_not_convert
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.backend.bnb import quantize_model_llm_int8
|
||||
|
||||
# Docs:
|
||||
# https://huggingface.co/docs/accelerate/usage_guides/quantization
|
||||
# https://huggingface.co/docs/bitsandbytes/v0.43.3/en/integrations#accelerate
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
|
||||
|
||||
# def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], llm_int8_threshold: int = 6):
|
||||
# """Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
# model_device = get_parameter_device(model)
|
||||
# if model_device.type != "meta":
|
||||
# # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
|
||||
# # meta device, so we enforce it for now.
|
||||
# raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
|
||||
|
||||
# bnb_quantization_config = BnbQuantizationConfig(
|
||||
# load_in_8bit=True,
|
||||
# llm_int8_threshold=llm_int8_threshold,
|
||||
# )
|
||||
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
# return model
|
||||
|
||||
|
||||
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
bnb_quantization_config = BnbQuantizationConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_threshold=6,
|
||||
)
|
||||
|
||||
model_8bit_path = path / "bnb_llm_int8"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
# Note that the model loading code is the same when loading from quantized vs original weights. The only
|
||||
# difference is the weights_location.
|
||||
# model = load_and_quantize_model(
|
||||
# empty_model,
|
||||
# weights_location=model_8bit_path,
|
||||
# bnb_quantization_config=bnb_quantization_config,
|
||||
# # device_map="auto",
|
||||
# device_map={"": "cpu"},
|
||||
# )
|
||||
|
||||
# TODO: Handle the keys that were not quantized (get_keys_to_not_convert).
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
|
||||
# model = quantize_model_llm_int8(empty_model, set())
|
||||
|
||||
# Load sharded state dict.
|
||||
files = list(path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
model = load_and_quantize_model(
|
||||
empty_model,
|
||||
weights_location=path,
|
||||
bnb_quantization_config=bnb_quantization_config,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
keys_to_not_convert = get_keys_to_not_convert(empty_model) # TODO
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
accl = accelerate.Accelerator()
|
||||
accl.save_model(model, model_8bit_path)
|
||||
|
||||
# ---------------------
|
||||
|
||||
# model = quantize_model_llm_int8(empty_model, set())
|
||||
|
||||
# # Load sharded state dict.
|
||||
# files = list(path.glob("*.safetensors"))
|
||||
# state_dict = dict()
|
||||
# for file in files:
|
||||
# sd = load_file(file)
|
||||
# state_dict.update(sd)
|
||||
|
||||
# # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
|
||||
# # non-quantized state dicts.
|
||||
# result = model.load_state_dict(state_dict, strict=True)
|
||||
# model = model.to("cuda")
|
||||
|
||||
# ---------------------
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time()
|
||||
model = load_flux_transformer(
|
||||
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
|
||||
)
|
||||
print(f"Time to load: {time.time() - start}s")
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -220,17 +220,11 @@ class LoKRLayer(LoRALayerBase):
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
@@ -378,39 +372,7 @@ class IA3Layer(LoRALayerBase):
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
@@ -551,10 +513,6 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
@@ -11,7 +11,6 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
@@ -46,7 +45,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
SpandrelImageToImageModel,
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingPipeline,
|
||||
DepthAnythingPipeline,
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
|
||||
@@ -54,6 +54,7 @@ def filter_files(
|
||||
"lora_weights.safetensors",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
|
||||
)
|
||||
):
|
||||
paths.append(file)
|
||||
@@ -62,7 +63,7 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
# Note: '.model' was added to support:
|
||||
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
result.add(path)
|
||||
|
||||
elif variant in [
|
||||
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
continue
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
if not at_least_one_fp16:
|
||||
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
|
||||
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
|
||||
# we'll simply keep all the candidates. An example of a model that hits this case is
|
||||
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
|
||||
for candidate in candidate_list:
|
||||
result.add(candidate.path)
|
||||
|
||||
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
125
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
|
||||
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
"""We override cuda() to avoid re-quantizing the weights in the following cases:
|
||||
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
|
||||
- We are moving the model back-and-forth between the cpu and gpu.
|
||||
"""
|
||||
|
||||
def cuda(self, device):
|
||||
if self.has_fp16_weights:
|
||||
return super().cuda(device)
|
||||
elif self.CB is not None and self.SCB is not None:
|
||||
self.data = self.data.cuda()
|
||||
self.CB = self.CB.cuda()
|
||||
self.SCB = self.SCB.cuda()
|
||||
else:
|
||||
# we store the 8-bit rows-major weight
|
||||
# we convert this weight to the turning/ampere weight during the first inference pass
|
||||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
self.data = CB
|
||||
self.CB = CB
|
||||
self.SCB = SCB
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert len(state_dict) == 0
|
||||
|
||||
if scb is not None:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
# Note: After quantization, CB is the same as weight.
|
||||
CB=weight,
|
||||
SCB=scb,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
CB=None,
|
||||
SCB=None,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
) -> None:
|
||||
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear8bitLt(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=outlier_threshold,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
|
||||
)
|
||||
|
||||
|
||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
|
||||
)
|
||||
|
||||
return model
|
||||
156
invokeai/backend/quantization/bnb_nf4.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
|
||||
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
|
||||
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
|
||||
"""
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
"""
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# We expect the remaining keys to be quant_state keys.
|
||||
quant_state_sd = state_dict
|
||||
|
||||
# During serialization, the quant_state is stored as subkeys of "weight." (See
|
||||
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _replace_param(
|
||||
param: torch.nn.Parameter | bnb.nn.Params4bit,
|
||||
data: torch.Tensor,
|
||||
) -> torch.nn.Parameter:
|
||||
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
|
||||
the "meta" device.
|
||||
|
||||
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
|
||||
"""
|
||||
if param.device.type == "meta":
|
||||
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
|
||||
# re-create the param instead of overwriting the data.
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
return bnb.nn.Params4bit(
|
||||
data,
|
||||
requires_grad=data.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
)
|
||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||
|
||||
param.data = data
|
||||
return param
|
||||
|
||||
|
||||
def _convert_linear_layers_to_nf4(
|
||||
module: torch.nn.Module,
|
||||
ignore_modules: set[str],
|
||||
compute_dtype: torch.dtype,
|
||||
compress_statistics: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Convert all linear layers in the model to NF4 quantized linear layers.
|
||||
|
||||
Args:
|
||||
module: All linear layers in this module will be converted.
|
||||
ignore_modules: A set of module prefixes to ignore when converting linear layers.
|
||||
compute_dtype: The dtype to use for computation in the quantized linear layers.
|
||||
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
|
||||
constants from the first quantization are quantized again.
|
||||
prefix: The prefix of the current module in the model. Used to call this function recursively.
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
compute_dtype=torch.float16,
|
||||
compress_statistics=compress_statistics,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
|
||||
|
||||
|
||||
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
|
||||
"""Apply bitsandbytes nf4 quantization to the model.
|
||||
|
||||
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
# Initialize the model from a config on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = ModelClass.from_config(...)
|
||||
|
||||
# Add NF4 quantization linear layers to the model - still on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
|
||||
|
||||
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
|
||||
# place.
|
||||
model.to("cuda")
|
||||
```
|
||||
"""
|
||||
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
|
||||
|
||||
return model
|
||||
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from diffusers.models.model_loading_utils import load_state_dict
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
_get_checkpoint_shard_files,
|
||||
is_accelerate_available,
|
||||
)
|
||||
from optimum.quanto.models import QuantizedDiffusersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||
if cls.base_class is None:
|
||||
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
|
||||
# Look for original model config file.
|
||||
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
|
||||
if not os.path.exists(model_config_path):
|
||||
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
|
||||
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
original_model_cls_name = json.load(f)["_class_name"]
|
||||
configured_cls_name = cls.base_class.__name__
|
||||
if configured_cls_name != original_model_cls_name:
|
||||
raise ValueError(
|
||||
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||
)
|
||||
|
||||
# Create an empty model
|
||||
config = cls.base_class.load_config(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.base_class.from_config(config)
|
||||
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
@@ -0,0 +1,61 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from optimum.quanto.models import QuantizedTransformersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
from transformers import AutoConfig
|
||||
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||
if cls.auto_class is None:
|
||||
raise ValueError(
|
||||
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
|
||||
)
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
# Create an empty model
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.auto_class.from_config(config)
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
if getattr(model.config, "tie_word_embeddings", True):
|
||||
# Tie output weight embeddings to input weight embeddings
|
||||
# Note that if they were quantized they would NOT be tied
|
||||
model.tie_weights()
|
||||
# Set model in evaluation mode as it is done in transformers
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
@@ -0,0 +1,89 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||
)
|
||||
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_int8_path = model_path / "bnb_llm_int8"
|
||||
if model_int8_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_int8_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_int8_path.mkdir(parents=True, exist_ok=True)
|
||||
output_path = model_int8_path / "model.safetensors"
|
||||
save_file(model.state_dict(), output_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
91
invokeai/backend/quantization/load_flux_model_bnb_nf4.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from flux.model import Flux
|
||||
from flux.util import configs as flux_configs
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
|
||||
if model_nf4_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_nf4_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
53
invokeai/backend/requantize.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from optimum.quanto.quantize import _quantize_submodule
|
||||
|
||||
# def custom_freeze(model: torch.nn.Module):
|
||||
# for name, m in model.named_modules():
|
||||
# if isinstance(m, QModuleMixin):
|
||||
# m.weight =
|
||||
# m.freeze()
|
||||
|
||||
|
||||
def requantize(
|
||||
model: torch.nn.Module,
|
||||
state_dict: Dict[str, Any],
|
||||
quantization_map: Dict[str, Dict[str, str]],
|
||||
device: torch.device = None,
|
||||
):
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
if device.type == "meta":
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Quantize the model with parameters from the quantization map
|
||||
for name, m in model.named_modules():
|
||||
qconfig = quantization_map.get(name, None)
|
||||
if qconfig is not None:
|
||||
weights = qconfig["weights"]
|
||||
if weights == "none":
|
||||
weights = None
|
||||
activations = qconfig["activations"]
|
||||
if activations == "none":
|
||||
activations = None
|
||||
_quantize_submodule(model, name, m, weights=weights, activations=activations)
|
||||
|
||||
# Move model parameters and buffers to CPU before materializing quantized weights
|
||||
for name, m in model.named_modules():
|
||||
|
||||
def move_tensor(t, device):
|
||||
if t.device.type == "meta":
|
||||
return torch.empty_like(t, device=device)
|
||||
return t.to(device)
|
||||
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
|
||||
for name, param in m.named_buffers(recurse=False):
|
||||
setattr(m, name, move_tensor(param, "cpu"))
|
||||
# Freeze model and move to target device
|
||||
# freeze(model)
|
||||
# model.to(device)
|
||||
|
||||
# Load the quantized model weights
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
@@ -25,11 +25,6 @@ class BasicConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
@@ -43,6 +38,17 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlNetMixin
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.embeddings import (
|
||||
@@ -32,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
@@ -53,63 +53,64 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react-use-size": "^2.1.0",
|
||||
"@dagrejs/dagre": "^1.1.3",
|
||||
"@dagrejs/graphlib": "^2.2.3",
|
||||
"@dagrejs/dagre": "^1.1.2",
|
||||
"@dagrejs/graphlib": "^2.2.2",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@fontsource-variable/inter": "^5.0.20",
|
||||
"@invoke-ai/ui-library": "^0.0.29",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@fontsource-variable/inter": "^5.0.18",
|
||||
"@invoke-ai/ui-library": "^0.0.25",
|
||||
"@nanostores/react": "^0.7.2",
|
||||
"@reduxjs/toolkit": "2.2.3",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"chakra-react-select": "^4.9.1",
|
||||
"compare-versions": "^6.1.1",
|
||||
"chakra-react-select": "^4.7.6",
|
||||
"compare-versions": "^6.1.0",
|
||||
"dateformat": "^5.0.3",
|
||||
"fracturedjsonjs": "^4.0.2",
|
||||
"framer-motion": "^11.3.24",
|
||||
"i18next": "^23.12.2",
|
||||
"i18next-http-backend": "^2.5.2",
|
||||
"fracturedjsonjs": "^4.0.1",
|
||||
"framer-motion": "^11.1.8",
|
||||
"i18next": "^23.11.3",
|
||||
"i18next-http-backend": "^2.5.1",
|
||||
"idb-keyval": "^6.2.1",
|
||||
"jsondiffpatch": "^0.6.0",
|
||||
"konva": "^9.3.14",
|
||||
"konva": "^9.3.6",
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanostores": "^0.11.2",
|
||||
"nanostores": "^0.10.3",
|
||||
"new-github-issue-url": "^1.0.0",
|
||||
"overlayscrollbars": "^2.10.0",
|
||||
"overlayscrollbars": "^2.7.3",
|
||||
"overlayscrollbars-react": "^0.5.6",
|
||||
"query-string": "^9.1.0",
|
||||
"query-string": "^9.0.0",
|
||||
"react": "^18.3.1",
|
||||
"react-colorful": "^5.6.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.2.3",
|
||||
"react-error-boundary": "^4.0.13",
|
||||
"react-hook-form": "^7.52.2",
|
||||
"react-hook-form": "^7.51.4",
|
||||
"react-hotkeys-hook": "4.5.0",
|
||||
"react-i18next": "^14.1.3",
|
||||
"react-icons": "^5.2.1",
|
||||
"react-i18next": "^14.1.1",
|
||||
"react-icons": "^5.2.0",
|
||||
"react-konva": "^18.2.10",
|
||||
"react-redux": "9.1.2",
|
||||
"react-resizable-panels": "^2.0.23",
|
||||
"react-resizable-panels": "^2.0.19",
|
||||
"react-select": "5.8.0",
|
||||
"react-use": "^17.5.1",
|
||||
"react-virtuoso": "^4.9.0",
|
||||
"reactflow": "^11.11.4",
|
||||
"react-use": "^17.5.0",
|
||||
"react-virtuoso": "^4.7.10",
|
||||
"reactflow": "^11.11.3",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^5.1.0",
|
||||
"redux-undo": "^1.1.0",
|
||||
"rfdc": "^1.4.1",
|
||||
"rfdc": "^1.3.1",
|
||||
"roarr": "^7.21.1",
|
||||
"serialize-error": "^11.0.3",
|
||||
"socket.io-client": "^4.7.5",
|
||||
"use-debounce": "^10.0.2",
|
||||
"use-debounce": "^10.0.0",
|
||||
"use-device-pixel-ratio": "^1.1.2",
|
||||
"use-image": "^1.1.1",
|
||||
"uuid": "^10.0.0",
|
||||
"zod": "^3.23.8",
|
||||
"zod-validation-error": "^3.3.1"
|
||||
"uuid": "^9.0.1",
|
||||
"zod": "^3.23.6",
|
||||
"zod-validation-error": "^3.2.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@chakra-ui/react": "^2.8.2",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"ts-toolbelt": "^9.6.0"
|
||||
@@ -117,38 +118,38 @@
|
||||
"devDependencies": {
|
||||
"@invoke-ai/eslint-config-react": "^0.0.14",
|
||||
"@invoke-ai/prettier-config-react": "^0.0.7",
|
||||
"@storybook/addon-essentials": "^8.2.8",
|
||||
"@storybook/addon-interactions": "^8.2.8",
|
||||
"@storybook/addon-links": "^8.2.8",
|
||||
"@storybook/addon-storysource": "^8.2.8",
|
||||
"@storybook/manager-api": "^8.2.8",
|
||||
"@storybook/react": "^8.2.8",
|
||||
"@storybook/react-vite": "^8.2.8",
|
||||
"@storybook/theming": "^8.2.8",
|
||||
"@storybook/addon-essentials": "^8.0.10",
|
||||
"@storybook/addon-interactions": "^8.0.10",
|
||||
"@storybook/addon-links": "^8.0.10",
|
||||
"@storybook/addon-storysource": "^8.0.10",
|
||||
"@storybook/manager-api": "^8.0.10",
|
||||
"@storybook/react": "^8.0.10",
|
||||
"@storybook/react-vite": "^8.0.10",
|
||||
"@storybook/theming": "^8.0.10",
|
||||
"@types/dateformat": "^5.0.2",
|
||||
"@types/lodash-es": "^4.17.12",
|
||||
"@types/node": "^20.14.15",
|
||||
"@types/react": "^18.3.3",
|
||||
"@types/node": "^20.12.10",
|
||||
"@types/react": "^18.3.1",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/uuid": "^10.0.0",
|
||||
"@vitejs/plugin-react-swc": "^3.7.0",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@vitejs/plugin-react-swc": "^3.6.0",
|
||||
"@vitest/coverage-v8": "^1.5.0",
|
||||
"@vitest/ui": "^1.5.0",
|
||||
"concurrently": "^8.2.2",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-plugin-i18next": "^6.0.9",
|
||||
"eslint-plugin-i18next": "^6.0.3",
|
||||
"eslint-plugin-path": "^1.3.0",
|
||||
"knip": "^5.27.2",
|
||||
"knip": "^5.12.3",
|
||||
"openapi-types": "^12.1.3",
|
||||
"openapi-typescript": "^7.3.0",
|
||||
"prettier": "^3.3.3",
|
||||
"openapi-typescript": "^6.7.5",
|
||||
"prettier": "^3.2.5",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"storybook": "^8.2.8",
|
||||
"storybook": "^8.0.10",
|
||||
"ts-toolbelt": "^9.6.0",
|
||||
"tsafe": "^1.7.2",
|
||||
"typescript": "^5.5.4",
|
||||
"vite": "^5.4.0",
|
||||
"tsafe": "^1.6.6",
|
||||
"typescript": "^5.4.5",
|
||||
"vite": "^5.2.11",
|
||||
"vite-plugin-css-injected-by-js": "^3.5.1",
|
||||
"vite-plugin-dts": "^3.9.1",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
|
||||
5816
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -200,7 +200,6 @@
|
||||
"delete": "Delete",
|
||||
"depthAnything": "Depth Anything",
|
||||
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
|
||||
"depthAnythingSmallV2": "Small V2",
|
||||
"depthMidas": "Depth (Midas)",
|
||||
"depthMidasDescription": "Depth map generation using Midas",
|
||||
"depthZoe": "Depth (Zoe)",
|
||||
@@ -1141,8 +1140,6 @@
|
||||
"imageSavingFailed": "Image Saving Failed",
|
||||
"imageUploaded": "Image Uploaded",
|
||||
"imageUploadFailed": "Image Upload Failed",
|
||||
"importFailed": "Import Failed",
|
||||
"importSuccessful": "Import Successful",
|
||||
"invalidUpload": "Invalid Upload",
|
||||
"loadedWithWarnings": "Workflow Loaded with Warnings",
|
||||
"maskSavedAssets": "Mask Saved to Assets",
|
||||
@@ -1675,7 +1672,6 @@
|
||||
"layers_other": "Layers"
|
||||
},
|
||||
"upscaling": {
|
||||
"upscale": "Upscale",
|
||||
"creativity": "Creativity",
|
||||
"exceedsMaxSize": "Upscale settings exceed max size limit",
|
||||
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
|
||||
@@ -1692,53 +1688,6 @@
|
||||
"missingUpscaleModel": "Missing upscale model",
|
||||
"missingTileControlNetModel": "No valid tile ControlNet models installed"
|
||||
},
|
||||
"stylePresets": {
|
||||
"active": "Active",
|
||||
"choosePromptTemplate": "Choose Prompt Template",
|
||||
"clearTemplateSelection": "Clear Template Selection",
|
||||
"copyTemplate": "Copy Template",
|
||||
"createPromptTemplate": "Create Prompt Template",
|
||||
"defaultTemplates": "Default Templates",
|
||||
"deleteImage": "Delete Image",
|
||||
"deleteTemplate": "Delete Template",
|
||||
"deleteTemplate2": "Are you sure you want to delete this template? This cannot be undone.",
|
||||
"exportPromptTemplates": "Export My Prompt Templates (CSV)",
|
||||
"editTemplate": "Edit Template",
|
||||
"exportDownloaded": "Export Downloaded",
|
||||
"exportFailed": "Unable to generate and download CSV",
|
||||
"flatten": "Flatten selected template into current prompt",
|
||||
"importTemplates": "Import Prompt Templates (CSV/JSON)",
|
||||
"acceptedColumnsKeys": "Accepted columns/keys:",
|
||||
"nameColumn": "'name'",
|
||||
"positivePromptColumn": "'prompt' or 'positive_prompt'",
|
||||
"negativePromptColumn": "'negative_prompt'",
|
||||
"insertPlaceholder": "Insert placeholder",
|
||||
"myTemplates": "My Templates",
|
||||
"name": "Name",
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"noTemplates": "No templates",
|
||||
"noMatchingTemplates": "No matching templates",
|
||||
"promptTemplatesDesc1": "Prompt templates add text to the prompts you write in the prompt box.",
|
||||
"promptTemplatesDesc2": "Use the placeholder string <Pre>{{placeholder}}</Pre> to specify where your prompt should be included in the template.",
|
||||
"promptTemplatesDesc3": "If you omit the placeholder, the template will be appended to the end of your prompt.",
|
||||
"positivePrompt": "Positive Prompt",
|
||||
"preview": "Preview",
|
||||
"private": "Private",
|
||||
"promptTemplateCleared": "Prompt Template Cleared",
|
||||
"searchByName": "Search by name",
|
||||
"shared": "Shared",
|
||||
"sharedTemplates": "Shared Templates",
|
||||
"templateActions": "Template Actions",
|
||||
"templateDeleted": "Prompt template deleted",
|
||||
"toggleViewMode": "Toggle View Mode",
|
||||
"type": "Type",
|
||||
"unableToDeleteTemplate": "Unable to delete prompt template",
|
||||
"updatePromptTemplate": "Update Prompt Template",
|
||||
"uploadImage": "Upload Image",
|
||||
"useForTemplate": "Use For Prompt Template",
|
||||
"viewList": "View Template List",
|
||||
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invite Teammates",
|
||||
"professional": "Professional",
|
||||
|
||||
@@ -90,7 +90,7 @@
|
||||
"disabled": "Disabilitato",
|
||||
"comparingDesc": "Confronta due immagini",
|
||||
"comparing": "Confronta",
|
||||
"dontShowMeThese": "Non mostrare più"
|
||||
"dontShowMeThese": "Non mostrarmi questi"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -701,9 +701,7 @@
|
||||
"baseModelChanged": "Modello base modificato",
|
||||
"sessionRef": "Sessione: {{sessionId}}",
|
||||
"somethingWentWrong": "Qualcosa è andato storto",
|
||||
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova.",
|
||||
"importFailed": "Importazione non riuscita",
|
||||
"importSuccessful": "Importazione riuscita"
|
||||
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -929,7 +927,7 @@
|
||||
"missingInvocationTemplate": "Modello di invocazione mancante",
|
||||
"missingFieldTemplate": "Modello di campo mancante",
|
||||
"singleFieldType": "{{name}} (Singola)",
|
||||
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
|
||||
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
|
||||
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
|
||||
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
|
||||
},
|
||||
@@ -1528,7 +1526,7 @@
|
||||
},
|
||||
"upscaleModel": {
|
||||
"paragraphs": [
|
||||
"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
|
||||
"Il modello di ampliamento ridimensiona l'immagine alle dimensioni di uscita prima che vengano aggiunti i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
|
||||
],
|
||||
"heading": "Modello di ampliamento"
|
||||
},
|
||||
@@ -1737,58 +1735,12 @@
|
||||
"missingUpscaleModel": "Modello per l’ampliamento mancante",
|
||||
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
|
||||
"postProcessingModel": "Modello di post-elaborazione",
|
||||
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
|
||||
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
|
||||
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
|
||||
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invita collaboratori",
|
||||
"shareAccess": "Condividi l'accesso",
|
||||
"professional": "Professionale",
|
||||
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
|
||||
},
|
||||
"stylePresets": {
|
||||
"active": "Attivo",
|
||||
"choosePromptTemplate": "Scegli un modello di prompt",
|
||||
"clearTemplateSelection": "Cancella selezione modello",
|
||||
"copyTemplate": "Copia modello",
|
||||
"createPromptTemplate": "Crea modello di prompt",
|
||||
"defaultTemplates": "Modelli predefiniti",
|
||||
"deleteImage": "Elimina immagine",
|
||||
"deleteTemplate": "Elimina modello",
|
||||
"editTemplate": "Modifica modello",
|
||||
"flatten": "Unisci il modello selezionato al prompt corrente",
|
||||
"insertPlaceholder": "Inserisci segnaposto",
|
||||
"myTemplates": "I miei modelli",
|
||||
"name": "Nome",
|
||||
"negativePrompt": "Prompt Negativo",
|
||||
"noMatchingTemplates": "Nessun modello corrispondente",
|
||||
"promptTemplatesDesc1": "I modelli di prompt aggiungono testo ai prompt che scrivi nelle caselle dei prompt.",
|
||||
"promptTemplatesDesc3": "Se si omette il segnaposto, il modello verrà aggiunto alla fine del prompt.",
|
||||
"positivePrompt": "Prompt Positivo",
|
||||
"preview": "Anteprima",
|
||||
"private": "Privato",
|
||||
"searchByName": "Cerca per nome",
|
||||
"shared": "Condiviso",
|
||||
"sharedTemplates": "Modelli condivisi",
|
||||
"templateDeleted": "Modello di prompt eliminato",
|
||||
"toggleViewMode": "Attiva/disattiva visualizzazione",
|
||||
"uploadImage": "Carica immagine",
|
||||
"useForTemplate": "Usa per modello di prompt",
|
||||
"viewList": "Visualizza l'elenco dei modelli",
|
||||
"viewModeTooltip": "Ecco come apparirà il tuo prompt con il modello attualmente selezionato. Per modificare il tuo prompt, clicca in un punto qualsiasi della casella di testo.",
|
||||
"deleteTemplate2": "Vuoi davvero eliminare questo modello? Questa operazione non può essere annullata.",
|
||||
"unableToDeleteTemplate": "Impossibile eliminare il modello di prompt",
|
||||
"updatePromptTemplate": "Aggiorna il modello di prompt",
|
||||
"type": "Tipo",
|
||||
"promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
|
||||
"importTemplates": "Importa modelli di prompt (CSV/JSON)",
|
||||
"exportDownloaded": "Esportazione completata",
|
||||
"exportFailed": "Impossibile generare e scaricare il file CSV",
|
||||
"exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
|
||||
"positivePromptColumn": "'prompt' o 'positive_prompt'",
|
||||
"noTemplates": "Nessun modello",
|
||||
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
|
||||
"templateActions": "Azioni modello"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,8 +91,7 @@
|
||||
"enabled": "Включено",
|
||||
"disabled": "Отключено",
|
||||
"comparingDesc": "Сравнение двух изображений",
|
||||
"comparing": "Сравнение",
|
||||
"dontShowMeThese": "Не показывай мне это"
|
||||
"comparing": "Сравнение"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@@ -154,11 +153,7 @@
|
||||
"showArchivedBoards": "Показать архивированные доски",
|
||||
"searchImages": "Поиск по метаданным",
|
||||
"displayBoardSearch": "Отобразить поиск досок",
|
||||
"displaySearch": "Отобразить поиск",
|
||||
"exitBoardSearch": "Выйти из поиска досок",
|
||||
"go": "Перейти",
|
||||
"exitSearch": "Выйти из поиска",
|
||||
"jump": "Пыгнуть"
|
||||
"displaySearch": "Отобразить поиск"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Горячие клавиши",
|
||||
@@ -381,10 +376,6 @@
|
||||
"toggleViewer": {
|
||||
"title": "Переключить просмотр изображений",
|
||||
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
|
||||
},
|
||||
"postProcess": {
|
||||
"desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
|
||||
"title": "Обработать изображение"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@@ -598,10 +589,7 @@
|
||||
"infillColorValue": "Цвет заливки",
|
||||
"globalSettings": "Глобальные настройки",
|
||||
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
|
||||
"globalPositivePromptPlaceholder": "Глобальный запрос",
|
||||
"postProcessing": "Постобработка (Shift + U)",
|
||||
"processImage": "Обработка изображения",
|
||||
"sendToUpscale": "Отправить на увеличение"
|
||||
"globalPositivePromptPlaceholder": "Глобальный запрос"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Модели",
|
||||
@@ -635,9 +623,7 @@
|
||||
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
|
||||
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
|
||||
"intermediatesClearedFailed": "Проблема очистки промежуточных",
|
||||
"reloadingIn": "Перезагрузка через",
|
||||
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
|
||||
"informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
|
||||
"reloadingIn": "Перезагрузка через"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Загрузка не удалась",
|
||||
@@ -708,9 +694,7 @@
|
||||
"sessionRef": "Сессия: {{sessionId}}",
|
||||
"outOfMemoryError": "Ошибка нехватки памяти",
|
||||
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
|
||||
"somethingWentWrong": "Что-то пошло не так",
|
||||
"importFailed": "Импорт неудачен",
|
||||
"importSuccessful": "Импорт успешен"
|
||||
"somethingWentWrong": "Что-то пошло не так"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -1033,8 +1017,7 @@
|
||||
"composition": "Только композиция",
|
||||
"hed": "HED",
|
||||
"beginEndStepPercentShort": "Начало/конец %",
|
||||
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
|
||||
"depthAnythingSmallV2": "Small V2"
|
||||
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Авто добавление Доски",
|
||||
@@ -1059,7 +1042,7 @@
|
||||
"downloadBoard": "Скачать доску",
|
||||
"deleteBoard": "Удалить доску",
|
||||
"deleteBoardAndImages": "Удалить доску и изображения",
|
||||
"deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
|
||||
"deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
|
||||
"assetsWithCount_one": "{{count}} ассет",
|
||||
"assetsWithCount_few": "{{count}} ассета",
|
||||
"assetsWithCount_many": "{{count}} ассетов",
|
||||
@@ -1074,11 +1057,7 @@
|
||||
"boards": "Доски",
|
||||
"addPrivateBoard": "Добавить личную доску",
|
||||
"private": "Личные доски",
|
||||
"shared": "Общие доски",
|
||||
"hideBoards": "Скрыть доски",
|
||||
"viewBoards": "Просмотреть доски",
|
||||
"noBoards": "Нет досок {{boardType}}",
|
||||
"deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
|
||||
"shared": "Общие доски"
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"seedBehaviour": {
|
||||
@@ -1438,30 +1417,6 @@
|
||||
"paragraphs": [
|
||||
"Метод, с помощью которого применяется текущий IP-адаптер."
|
||||
]
|
||||
},
|
||||
"structure": {
|
||||
"paragraphs": [
|
||||
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
|
||||
],
|
||||
"heading": "Структура"
|
||||
},
|
||||
"scale": {
|
||||
"paragraphs": [
|
||||
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
|
||||
],
|
||||
"heading": "Масштаб"
|
||||
},
|
||||
"creativity": {
|
||||
"paragraphs": [
|
||||
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
|
||||
],
|
||||
"heading": "Креативность"
|
||||
},
|
||||
"upscaleModel": {
|
||||
"heading": "Модель увеличения",
|
||||
"paragraphs": [
|
||||
"Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
|
||||
]
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
@@ -1738,78 +1693,7 @@
|
||||
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Очередь",
|
||||
"upscaling": "Увеличение",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||
"queue": "Очередь"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
"exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
|
||||
"exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
|
||||
"structure": "Структура",
|
||||
"missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
|
||||
"missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
|
||||
"missingUpscaleModel": "Отсутствует увеличивающая модель",
|
||||
"creativity": "Креативность",
|
||||
"upscaleModel": "Модель увеличения",
|
||||
"scale": "Масштаб",
|
||||
"mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
|
||||
"upscaleModelDesc": "Модель увеличения (img2img)",
|
||||
"postProcessingModel": "Модель постобработки",
|
||||
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
|
||||
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
|
||||
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
|
||||
},
|
||||
"stylePresets": {
|
||||
"noMatchingTemplates": "Нет подходящих шаблонов",
|
||||
"promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
|
||||
"sharedTemplates": "Общие шаблоны",
|
||||
"templateDeleted": "Шаблон запроса удален",
|
||||
"toggleViewMode": "Переключить режим просмотра",
|
||||
"type": "Тип",
|
||||
"unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
|
||||
"viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
|
||||
"viewList": "Просмотреть список шаблонов",
|
||||
"active": "Активно",
|
||||
"choosePromptTemplate": "Выберите шаблон запроса",
|
||||
"defaultTemplates": "Стандартные шаблоны",
|
||||
"deleteImage": "Удалить изображение",
|
||||
"deleteTemplate": "Удалить шаблон",
|
||||
"deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
|
||||
"editTemplate": "Редактировать шаблон",
|
||||
"exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
|
||||
"exportDownloaded": "Экспорт скачан",
|
||||
"exportFailed": "Невозможно сгенерировать и загрузить CSV",
|
||||
"flatten": "Объединить выбранный шаблон с текущим запросом",
|
||||
"acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
|
||||
"positivePromptColumn": "'prompt' или 'positive_prompt'",
|
||||
"insertPlaceholder": "Вставить заполнитель",
|
||||
"name": "Имя",
|
||||
"negativePrompt": "Негативный запрос",
|
||||
"promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
|
||||
"positivePrompt": "Позитивный запрос",
|
||||
"preview": "Предпросмотр",
|
||||
"private": "Приватный",
|
||||
"templateActions": "Действия с шаблоном",
|
||||
"updatePromptTemplate": "Обновить шаблон запроса",
|
||||
"uploadImage": "Загрузить изображение",
|
||||
"useForTemplate": "Использовать для шаблона запроса",
|
||||
"clearTemplateSelection": "Очистить выбор шаблона",
|
||||
"copyTemplate": "Копировать шаблон",
|
||||
"createPromptTemplate": "Создать шаблон запроса",
|
||||
"importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
|
||||
"nameColumn": "'name'",
|
||||
"negativePromptColumn": "'negative_prompt'",
|
||||
"myTemplates": "Мои шаблоны",
|
||||
"noTemplates": "Нет шаблонов",
|
||||
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
|
||||
"searchByName": "Поиск по имени",
|
||||
"shared": "Общий"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Пригласите членов команды",
|
||||
"professional": "Профессионал",
|
||||
"professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
|
||||
"shareAccess": "Поделиться доступом"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,8 +493,7 @@
|
||||
"defaultSettingsSaved": "默认设置已保存",
|
||||
"huggingFacePlaceholder": "所有者或模型名称",
|
||||
"huggingFaceRepoID": "HuggingFace仓库ID",
|
||||
"loraTriggerPhrases": "LoRA 触发词",
|
||||
"ipAdapters": "IP适配器"
|
||||
"loraTriggerPhrases": "LoRA 触发词"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "图像",
|
||||
@@ -1703,9 +1702,7 @@
|
||||
"upscaleModelDesc": "图像放大(图像到图像转换)模型",
|
||||
"postProcessingMissingModelWarning": "请访问 <LinkComponent>模型管理器</LinkComponent>来安装一个后处理(图像到图像转换)模型.",
|
||||
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
|
||||
"mainModelDesc": "主模型(SD1.5或SDXL架构)",
|
||||
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
|
||||
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
|
||||
"mainModelDesc": "主模型(SD1.5或SDXL架构)"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "邀请团队成员",
|
||||
|
||||
@@ -1,40 +1,26 @@
|
||||
/* eslint-disable no-console */
|
||||
import fs from 'node:fs';
|
||||
|
||||
import openapiTS, { astToString } from 'openapi-typescript';
|
||||
import ts from 'typescript';
|
||||
import openapiTS from 'openapi-typescript';
|
||||
|
||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||
const OUTPUT_FILE = 'src/services/api/schema.ts';
|
||||
|
||||
async function generateTypes(schema) {
|
||||
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
|
||||
|
||||
// Use https://ts-ast-viewer.com to figure out how to create these AST nodes - define a type and use the bottom-left pane's output
|
||||
// `Blob` type
|
||||
const BLOB = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Blob'));
|
||||
// `null` type
|
||||
const NULL = ts.factory.createLiteralTypeNode(ts.factory.createNull());
|
||||
// `Record<string, unknown>` type
|
||||
const RECORD_STRING_UNKNOWN = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Record'), [
|
||||
ts.factory.createKeywordTypeNode(ts.SyntaxKind.StringKeyword),
|
||||
ts.factory.createKeywordTypeNode(ts.SyntaxKind.UnknownKeyword),
|
||||
]);
|
||||
|
||||
const types = await openapiTS(schema, {
|
||||
exportType: true,
|
||||
transform: (schemaObject) => {
|
||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||
return schemaObject.nullable ? ts.factory.createUnionTypeNode([BLOB, NULL]) : BLOB;
|
||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||
}
|
||||
if (schemaObject.title === 'MetadataField') {
|
||||
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
|
||||
return RECORD_STRING_UNKNOWN;
|
||||
return 'Record<string, unknown>';
|
||||
}
|
||||
},
|
||||
defaultNonNullable: false,
|
||||
});
|
||||
fs.writeFileSync(OUTPUT_FILE, astToString(types));
|
||||
fs.writeFileSync(OUTPUT_FILE, types);
|
||||
process.stdout.write(`\nOK!\r\n`);
|
||||
}
|
||||
|
||||
|
||||
@@ -13,13 +13,11 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
|
||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
@@ -38,11 +36,10 @@ interface Props {
|
||||
imageName: string;
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
destination?: InvokeTabName | undefined;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -73,14 +70,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
|
||||
}
|
||||
}, [dispatch, config, logger]);
|
||||
|
||||
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedWorkflowId) {
|
||||
getAndLoadWorkflow(selectedWorkflowId);
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
@@ -115,7 +104,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
|
||||
<DeleteImageModal />
|
||||
<ChangeBoardModal />
|
||||
<DynamicPromptsModal />
|
||||
<StylePresetModal />
|
||||
<PreselectedImage selectedImage={selectedImage} />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
|
||||
@@ -44,7 +44,6 @@ interface Props extends PropsWithChildren {
|
||||
imageName: string;
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
destination?: InvokeTabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@@ -65,7 +64,6 @@ const InvokeAIUI = ({
|
||||
projectUrl,
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@@ -223,12 +221,7 @@ const InvokeAIUI = ({
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<AppDndContext>
|
||||
<App
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
destination={destination}
|
||||
/>
|
||||
<App config={config} selectedImage={selectedImage} destination={destination} />
|
||||
</AppDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
|
||||
@@ -11,9 +11,6 @@ import {
|
||||
promptsChanged,
|
||||
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
|
||||
import { utilitiesApi } from 'services/api/endpoints/utilities';
|
||||
import { socketConnected } from 'services/events/actions';
|
||||
|
||||
@@ -22,11 +19,7 @@ const matcher = isAnyOf(
|
||||
combinatorialToggled,
|
||||
maxPromptsChanged,
|
||||
maxPromptsReset,
|
||||
socketConnected,
|
||||
activeStylePresetIdChanged,
|
||||
stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
|
||||
stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
|
||||
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
|
||||
socketConnected
|
||||
);
|
||||
|
||||
export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {
|
||||
@@ -35,7 +28,7 @@ export const addDynamicPromptsListener = (startAppListening: AppStartListening)
|
||||
effect: async (action, { dispatch, getState, cancelActiveListeners, delay }) => {
|
||||
cancelActiveListeners();
|
||||
const state = getState();
|
||||
const { positivePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positivePrompt } = state.controlLayers.present;
|
||||
const { maxPrompts } = state.dynamicPrompts;
|
||||
|
||||
if (state.config.disabledFeatures.includes('dynamicPrompting')) {
|
||||
|
||||
@@ -28,7 +28,6 @@ import { generationPersistConfig, generationSlice } from 'features/parameters/st
|
||||
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
|
||||
@@ -70,7 +69,6 @@ const allReducers = {
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@@ -116,7 +114,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
@@ -167,8 +164,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
reducer: rememberedRootReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
serializableCheck: import.meta.env.MODE === 'development',
|
||||
immutableCheck: import.meta.env.MODE === 'development',
|
||||
serializableCheck: false,
|
||||
immutableCheck: false,
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
|
||||
@@ -71,7 +71,6 @@ export type AppConfig = {
|
||||
*/
|
||||
maxUpscaleDimension?: number;
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
disabledTabs: InvokeTabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -47,7 +47,6 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
|
||||
userSelect: 'none',
|
||||
opacity: 0.7,
|
||||
color: 'base.500',
|
||||
fontSize: 'md',
|
||||
...sx,
|
||||
}),
|
||||
[sx]
|
||||
@@ -56,7 +55,11 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
|
||||
return (
|
||||
<Flex sx={styles} {...rest}>
|
||||
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||
{props.label && (
|
||||
<Text textAlign="center" fontSize="md">
|
||||
{props.label}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { convertImageUrlToBlob } from 'common/util/convertImageUrlToBlob';
|
||||
import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob';
|
||||
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const useCopyImageToClipboard = () => {
|
||||
const { t } = useTranslation();
|
||||
const imageUrlToBlob = useImageUrlToBlob();
|
||||
|
||||
const isClipboardAPIAvailable = useMemo(() => {
|
||||
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
|
||||
@@ -22,7 +23,7 @@ export const useCopyImageToClipboard = () => {
|
||||
});
|
||||
}
|
||||
try {
|
||||
const blob = await convertImageUrlToBlob(image_url);
|
||||
const blob = await imageUrlToBlob(image_url);
|
||||
|
||||
if (!blob) {
|
||||
throw new Error('Unable to create Blob');
|
||||
@@ -44,7 +45,7 @@ export const useCopyImageToClipboard = () => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[isClipboardAPIAvailable, t]
|
||||
[imageUrlToBlob, isClipboardAPIAvailable, t]
|
||||
);
|
||||
|
||||
return { isClipboardAPIAvailable, copyImageToClipboard };
|
||||
|
||||
40
invokeai/frontend/web/src/common/hooks/useImageUrlToBlob.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
/**
|
||||
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
|
||||
* and then converting the canvas to a Blob.
|
||||
*
|
||||
* @returns A function that takes a URL and returns a Promise that resolves with a Blob
|
||||
*/
|
||||
export const useImageUrlToBlob = () => {
|
||||
const imageUrlToBlob = useCallback(
|
||||
async (url: string) =>
|
||||
new Promise<Blob | null>((resolve) => {
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = img.width;
|
||||
canvas.height = img.height;
|
||||
|
||||
const context = canvas.getContext('2d');
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
context.drawImage(img, 0, 0);
|
||||
resolve(
|
||||
new Promise<Blob | null>((resolve) => {
|
||||
canvas.toBlob(function (blob) {
|
||||
resolve(blob);
|
||||
}, 'image/png');
|
||||
})
|
||||
);
|
||||
};
|
||||
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
|
||||
img.src = url;
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
return imageUrlToBlob;
|
||||
};
|
||||
@@ -1,39 +0,0 @@
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
|
||||
/**
|
||||
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
|
||||
* and then converting the canvas to a Blob.
|
||||
*
|
||||
* @returns A function that takes a URL and returns a Promise that resolves with a Blob
|
||||
*/
|
||||
|
||||
export const convertImageUrlToBlob = async (url: string) =>
|
||||
new Promise<Blob | null>((resolve, reject) => {
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = img.width;
|
||||
canvas.height = img.height;
|
||||
|
||||
const context = canvas.getContext('2d');
|
||||
if (!context) {
|
||||
reject(new Error('Failed to get canvas context'));
|
||||
return;
|
||||
}
|
||||
context.drawImage(img, 0, 0);
|
||||
canvas.toBlob((blob) => {
|
||||
if (blob) {
|
||||
resolve(blob);
|
||||
} else {
|
||||
reject(new Error('Failed to convert image to blob'));
|
||||
}
|
||||
}, 'image/png');
|
||||
};
|
||||
|
||||
img.onerror = () => {
|
||||
reject(new Error('Image failed to load. The URL may be invalid or the object may not exist.'));
|
||||
};
|
||||
|
||||
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
|
||||
img.src = url;
|
||||
});
|
||||
@@ -42,7 +42,6 @@ const DepthAnythingProcessor = (props: Props) => {
|
||||
|
||||
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
|
||||
{ label: t('controlnet.small'), value: 'small' },
|
||||
{ label: t('controlnet.base'), value: 'base' },
|
||||
{ label: t('controlnet.large'), value: 'large' },
|
||||
|
||||
@@ -94,7 +94,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
buildDefaults: (baseModel?: BaseModelType) => ({
|
||||
id: 'depth_anything_image_processor',
|
||||
type: 'depth_anything_image_processor',
|
||||
model_size: 'small_v2',
|
||||
model_size: 'small',
|
||||
resolution: baseModel === 'sdxl' ? 1024 : 512,
|
||||
}),
|
||||
},
|
||||
|
||||
@@ -84,7 +84,7 @@ export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
|
||||
'type' | 'model_size' | 'resolution' | 'offload'
|
||||
>;
|
||||
|
||||
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
|
||||
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
|
||||
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
|
||||
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
|
||||
zDepthAnythingModelSize.safeParse(v).success;
|
||||
|
||||
@@ -24,7 +24,6 @@ export const DepthAnythingProcessor = memo(({ onChange, config }: Props) => {
|
||||
|
||||
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
|
||||
{ label: t('controlnet.small'), value: 'small' },
|
||||
{ label: t('controlnet.base'), value: 'base' },
|
||||
{ label: t('controlnet.large'), value: 'large' },
|
||||
|
||||
@@ -36,7 +36,7 @@ const zContentShuffleProcessorConfig = z.object({
|
||||
});
|
||||
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
|
||||
|
||||
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
|
||||
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
|
||||
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
|
||||
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
|
||||
zDepthAnythingModelSize.safeParse(v).success;
|
||||
@@ -298,7 +298,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
|
||||
buildDefaults: () => ({
|
||||
id: 'depth_anything_image_processor',
|
||||
type: 'depth_anything_image_processor',
|
||||
model_size: 'small_v2',
|
||||
model_size: 'small',
|
||||
}),
|
||||
buildNode: (image, config) => ({
|
||||
...config,
|
||||
|
||||
@@ -66,7 +66,7 @@ export const Gallery = () => {
|
||||
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
|
||||
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
|
||||
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2" wordBreak="break-all">
|
||||
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
|
||||
{boardName}
|
||||
</Text>
|
||||
<Spacer />
|
||||
|
||||
@@ -30,7 +30,6 @@ import {
|
||||
PiFlowArrowBold,
|
||||
PiFoldersBold,
|
||||
PiImagesBold,
|
||||
PiPaintBrushBold,
|
||||
PiPlantBold,
|
||||
PiQuotesBold,
|
||||
PiShareFatBold,
|
||||
@@ -56,17 +55,8 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const { downloadImage } = useDownloadImage();
|
||||
const templates = useStore($templates);
|
||||
|
||||
const {
|
||||
recallAll,
|
||||
remix,
|
||||
recallSeed,
|
||||
recallPrompts,
|
||||
hasMetadata,
|
||||
hasSeed,
|
||||
hasPrompts,
|
||||
isLoadingMetadata,
|
||||
createAsPreset,
|
||||
} = useImageActions(imageDTO?.image_name);
|
||||
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
|
||||
useImageActions(imageDTO?.image_name);
|
||||
|
||||
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } = useGetAndLoadEmbeddedWorkflow({});
|
||||
|
||||
@@ -192,13 +182,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoadingMetadata ? <SpinnerIcon /> : <PiPaintBrushBold />}
|
||||
onClickCapture={createAsPreset}
|
||||
isDisabled={isLoadingMetadata || !hasPrompts}
|
||||
>
|
||||
{t('stylePresets.useForTemplate')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem icon={<PiShareFatBold />} onClickCapture={handleSendToImageToImage} id="send-to-img2img">
|
||||
{t('parameters.sendToImg2Img')}
|
||||
|
||||
@@ -60,6 +60,7 @@ const ImageGalleryContent = () => {
|
||||
<GalleryHeader />
|
||||
<Flex alignItems="center" justifyContent="space-between" w="full">
|
||||
<Button
|
||||
w={112}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={handleToggleBoardPanel}
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
|
||||
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
|
||||
export const useImageActions = (image_name?: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
|
||||
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
|
||||
const [hasMetadata, setHasMetadata] = useState(false);
|
||||
const [hasSeed, setHasSeed] = useState(false);
|
||||
const [hasPrompts, setHasPrompts] = useState(false);
|
||||
const { data: imageDTO } = useGetImageDTOQuery(image_name ?? skipToken);
|
||||
|
||||
useEffect(() => {
|
||||
const parseMetadata = async () => {
|
||||
@@ -52,26 +42,14 @@ export const useImageActions = (image_name?: string) => {
|
||||
parseMetadata();
|
||||
}, [metadata]);
|
||||
|
||||
const clearStylePreset = useCallback(() => {
|
||||
if (activeStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(null));
|
||||
toast({
|
||||
status: 'info',
|
||||
title: t('stylePresets.promptTemplateCleared'),
|
||||
});
|
||||
}
|
||||
}, [dispatch, activeStylePresetId, t]);
|
||||
|
||||
const recallAll = useCallback(() => {
|
||||
parseAndRecallAllMetadata(metadata, activeTabName === 'generation');
|
||||
clearStylePreset();
|
||||
}, [activeTabName, metadata, clearStylePreset]);
|
||||
}, [activeTabName, metadata]);
|
||||
|
||||
const remix = useCallback(() => {
|
||||
// Recalls all metadata parameters except seed
|
||||
parseAndRecallAllMetadata(metadata, activeTabName === 'generation', ['seed']);
|
||||
clearStylePreset();
|
||||
}, [activeTabName, metadata, clearStylePreset]);
|
||||
}, [activeTabName, metadata]);
|
||||
|
||||
const recallSeed = useCallback(() => {
|
||||
handlers.seed.parse(metadata).then((seed) => {
|
||||
@@ -81,37 +59,7 @@ export const useImageActions = (image_name?: string) => {
|
||||
|
||||
const recallPrompts = useCallback(() => {
|
||||
parseAndRecallPrompts(metadata);
|
||||
clearStylePreset();
|
||||
}, [metadata, clearStylePreset]);
|
||||
}, [metadata]);
|
||||
|
||||
const createAsPreset = useCallback(async () => {
|
||||
if (image_name && metadata && imageDTO) {
|
||||
const positivePrompt = await handlers.positivePrompt.parse(metadata);
|
||||
const negativePrompt = await handlers.negativePrompt.parse(metadata);
|
||||
|
||||
$stylePresetModalState.set({
|
||||
prefilledFormData: {
|
||||
name: '',
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
imageUrl: imageDTO.image_url,
|
||||
type: 'user',
|
||||
},
|
||||
updatingStylePresetId: null,
|
||||
isModalOpen: true,
|
||||
});
|
||||
}
|
||||
}, [image_name, metadata, imageDTO]);
|
||||
|
||||
return {
|
||||
recallAll,
|
||||
remix,
|
||||
recallSeed,
|
||||
recallPrompts,
|
||||
hasMetadata,
|
||||
hasSeed,
|
||||
hasPrompts,
|
||||
isLoadingMetadata,
|
||||
createAsPreset,
|
||||
};
|
||||
return { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata };
|
||||
};
|
||||
|
||||
@@ -22,10 +22,11 @@ import {
|
||||
} from './constants';
|
||||
import { addLoRAs } from './generation/addLoRAs';
|
||||
import { addSDXLLoRas } from './generation/addSDXLLoRAs';
|
||||
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
|
||||
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
|
||||
export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise<GraphType> => {
|
||||
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
|
||||
|
||||
assert(model, 'No model found in state');
|
||||
@@ -98,8 +99,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
||||
let modelNode;
|
||||
|
||||
if (model.base === 'sdxl') {
|
||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
|
||||
getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
posCondNode = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
@@ -132,8 +132,6 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
});
|
||||
} else {
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
posCondNode = g.addNode({
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -59,7 +59,7 @@ export const addSDXLRefinerToGraph = async (
|
||||
const modelLoaderId = modelLoaderNodeId ? modelLoaderNodeId : SDXL_MODEL_LOADER;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
// Unplug SDXL Latents Generation To Latents To Image
|
||||
graph.edges = graph.edges.filter((e) => !(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field)));
|
||||
|
||||
@@ -16,11 +16,7 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -55,6 +51,7 @@ export const buildCanvasImageToImageGraph = async (
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
@@ -74,8 +71,6 @@ export const buildCanvasImageToImageGraph = async (
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
|
||||
@@ -19,11 +19,7 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -62,6 +58,7 @@ export const buildCanvasInpaintGraph = async (
|
||||
canvasCoherenceEdgeSize,
|
||||
maskBlur,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
@@ -82,8 +79,6 @@ export const buildCanvasInpaintGraph = async (
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: CANVAS_INPAINT_GRAPH,
|
||||
nodes: {
|
||||
|
||||
@@ -23,11 +23,7 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -74,6 +70,7 @@ export const buildCanvasOutpaintGraph = async (
|
||||
canvasCoherenceEdgeSize,
|
||||
maskBlur,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
@@ -94,8 +91,6 @@ export const buildCanvasOutpaintGraph = async (
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: CANVAS_OUTPAINT_GRAPH,
|
||||
nodes: {
|
||||
|
||||
@@ -16,11 +16,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -55,6 +51,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
|
||||
seamlessYAxis,
|
||||
img2imgStrength: strength,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
|
||||
@@ -78,7 +75,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
|
||||
@@ -19,11 +19,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -62,6 +58,7 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
canvasCoherenceEdgeSize,
|
||||
maskBlur,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
|
||||
@@ -86,7 +83,7 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_CANVAS_INPAINT_GRAPH,
|
||||
|
||||
@@ -23,11 +23,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -74,6 +70,7 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
canvasCoherenceEdgeSize,
|
||||
maskBlur,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
|
||||
@@ -97,7 +94,7 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
|
||||
@@ -14,11 +14,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -48,6 +44,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
@@ -70,7 +67,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
|
||||
@@ -14,11 +14,7 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import {
|
||||
getBoardField,
|
||||
getIsIntermediate,
|
||||
getPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@@ -48,6 +44,7 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
@@ -67,8 +64,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
|
||||
|
||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
|
||||
@@ -22,7 +22,7 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -40,12 +40,11 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
||||
seed,
|
||||
vae,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
|
||||
assert(model, 'No model found in state');
|
||||
|
||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||
|
||||
const g = new Graph(CONTROL_LAYERS_GRAPH);
|
||||
const modelLoader = g.addNode({
|
||||
type: 'main_model_loader',
|
||||
|
||||