Files
SHARK-Studio/apps/stable_diffusion/web/api/utils.py
Stefan Kapusniak ad55cb696f SD/API: Add missing A1111 APIs to Shark to support koboldcpp image generation (#1924)
* SD/API: Add missing a1111 API features for Koboldcpp

* Refactors SD api functions into their own file
* Adds the following apis implemented by a1111 as needed by koboldcpp:
   - adds /sdapi/v1/sd-models (lists available models)
   - adds /sdapi/v1/options (only the bare minimum needed)
* Adds optional CORS support, use the '--api_accept_origin' command line
argument to activate and configure.
* Extends existing APIs to include optional sampler/scheduler selection
* Extends /sdapi/v1/txt2img to recognise the method used by koboldcpp
to select the model.
* Where possible take values not provided to the API in the request from
the existing relevant command line parameters rather than hardcoding
them.
* Return a 400 response when a request doesn't have required properties.
* Changed default schedulers and models for some apis to ones that
actually seem to work.
* Update api_test.py to include the new APIs.
* Update api_test.py to include a '--verbose' command line option.

* SD/API: Take more API values from args

* Take LoRA from '--use_lora' command line arg if specified
* Take device from '--device' command line arg if specified (substring
match, so a short name such as 'vulkan://0' should work)

* SD/API: add more endpoints and pydantic typing

* Mount the whole of /sdapi from index.py as a FastAPI application,
rather than each endpoint individually
* Add the following additional API endpoints:
  * /sdapi/v1/samplers
  * /sdapi/v1/cmd-flags
* Make scheduler/sampler selection checking and fallback much more
robust.
* Support aliasing some A1111 scheduler/sampler names to the diffusers
ones we are using.
* Expand response /sdapi/v1/options to add a few more things.
* Split non-api functions and variables into their own utils.py file.
* Support 'n_iter' request property and the return of multiple images
from generation endpoints. Equivalent of '--batch_count', batch_size
is still hardcoded at 1
* Include (some) hires_fix request properties in txt2img endpoint
* Rework endpoints using pydantic model classes for better request
validation and so we get much improved swagger api docs at
/sdapi/docs and redoc at /sdapi/redoc

* SD/API Delete commented out code from index.py

* Delete some code that is no longer needed by the SD API in index.py
(and one line in sdapi_v1.py) that I'd previously only commented out.

* SD/UI: Add shark_sd_koboldcpp.md document

* Add documentation on how to set up Koboldcpp with SHARK
* Link this and the existing blender set up document from the main
README.md

* SD/API Improve stencil options in img2img endpoint

In /sdapi/v1/img2img:
  * Add zoedepth to the controlnet use_stencil options
  * Require and use second image as stencil mask for controlnet scribble
2023-11-06 15:20:19 -06:00

212 lines
6.4 KiB
Python

import base64
import pickle
from argparse import Namespace
from fastapi.exceptions import HTTPException
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
get_custom_model_files,
predefined_models,
predefined_paint_models,
predefined_upscaler_models,
scheduler_list,
scheduler_list_cpu_only,
)
# Snapshot of the command line args taken when this module is loaded.
# Code elsewhere does `args.<whatever> = <changed_value>` in lots of places
# and, in testing, those changes seemed to leak into subsequent api calls,
# so the api reads its starting values from this copy instead. This is
# probably overly cautious.
# The deep copy is made by round-tripping the args' attributes through
# pickle; there is probably a better way.
_pickled_args = pickle.dumps(vars(args))
frozen_args = Namespace(**pickle.loads(_pickled_args))
del _pickled_args
# Best-effort translation of some A1111 sampler names into the scheduler
# names used here. The (not so obvious) pairings come from
# https://github.com/huggingface/diffusers/issues/4167
# The groups are merged in order, so the key order is: onnx names first,
# then the obvious k_diffusion names, then the less obvious ones.
sampler_aliases = (
    # a1111/onnx (these point to diffusers classes in A1111)
    {
        "pndm": "PNDM",
        "heun": "HeunDiscrete",
        "ddim": "DDIM",
        "ddpm": "DDPM",
        "euler": "EulerDiscrete",
        "euler-ancestral": "EulerAncestralDiscrete",
        "dpm": "DPMSolverMultistep",
    }
    # a1111/k_diffusion (the obvious ones)
    | {
        "Euler a": "EulerAncestralDiscrete",
        "Euler": "EulerDiscrete",
        "LMS": "LMSDiscrete",
        "Heun": "HeunDiscrete",
    }
    # a1111/k_diffusion (not so obvious)
    | {
        "DPM++ 2M": "DPMSolverMultistep",
        "DPM++ 2M Karras": "DPMSolverMultistepKarras",
        "DPM++ 2M SDE": "DPMSolverMultistep++",
        "DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
        "DPM2": "KDPM2Discrete",
        "DPM2 a": "KDPM2AncestralDiscrete",
    }
)
# For each api operation: the schedulers a request may select, and the
# scheduler to fall back to when the requested one is not in that list.
# Only plain txt2img uses the full scheduler list; every other operation
# is restricted to the cpu-only list and differs just in its fallback.
allowed_schedulers = {
    "txt2img": {
        "schedulers": scheduler_list,
        "fallback": "SharkEulerDiscrete",
    },
    **{
        operation: {
            "schedulers": scheduler_list_cpu_only,
            "fallback": fallback,
        }
        for operation, fallback in (
            ("txt2img_hires", "DEISMultistep"),
            ("img2img", "EulerDiscrete"),
            ("inpaint", "DDIM"),
            ("outpaint", "DDIM"),
            ("upscaler", "DDIM"),
        )
    },
}
# base pydantic model for sd generation apis
# Any value a request leaves out defaults to the matching command line
# argument, as captured in frozen_args.
class GenerationInputData(BaseModel):
    prompt: str = ""
    negative_prompt: str = ""
    hf_model_id: str | None = None
    # height and width must be multiples of 8 between 128 and 768 inclusive
    height: int = Field(
        multiple_of=8, ge=128, le=768, default=frozen_args.height
    )
    width: int = Field(
        multiple_of=8, ge=128, le=768, default=frozen_args.width
    )
    sampler_name: str = Field(default=frozen_args.scheduler)
    # guidance scale, at least 1
    cfg_scale: float = Field(ge=1, default=frozen_args.guidance_scale)
    # between 1 and 100 steps inclusive
    steps: int = Field(ge=1, le=100, default=frozen_args.steps)
    seed: int = Field(default=frozen_args.seed)
    # defaults to the batch count from the command line
    n_iter: int = frozen_args.batch_count
# pydantic model describing the response of a generation request
class GenerationResponseData(BaseModel):
    # required: the generated images as base64 text
    images: list[str] = Field(
        description="Generated images, Base64 encoded",
    )
    # optional, empty when not supplied
    properties: dict = Field(default={})
    # required
    info: str
# image encoding/decoding
def encode_pil_to_base64(images: list[Image.Image]):
    """Base64 encode images in the configured output image format.

    Each image is saved as PNG or JPEG, chosen by
    ``frozen_args.output_img_format`` (case-insensitive "png", "jpg" or
    "jpeg"), and the saved bytes are base64 encoded.

    Returns a list holding one ``base64.b64encode`` result (``bytes``) per
    input image, in input order.

    Raises HTTPException with status 500 when the configured format is
    not one of the above.
    """
    pil_format_by_setting = {"png": "PNG", "jpg": "JPEG", "jpeg": "JPEG"}
    results = []
    for img in images:
        # the setting is consulted per image, so an empty input list never
        # reads (or rejects) the configured format
        pil_format = pil_format_by_setting.get(
            frozen_args.output_img_format.lower()
        )
        if pil_format is None:
            raise HTTPException(
                status_code=500, detail="Invalid image format"
            )
        with BytesIO() as buffer:
            img.save(buffer, format=pil_format)
            # read the bytes out before the buffer is closed
            results.append(base64.b64encode(buffer.getvalue()))
    return results
def decode_base64_to_image(encoding: str):
    """Decode base64 image data, or a base64 ``data:image/`` URI, to an image.

    Args:
        encoding: base64 encoded image file contents, optionally wrapped
            in a data URI of the form ``data:image/<type>;<params>,<data>``.

    Returns:
        The image as returned by ``Image.open``.

    Raises:
        HTTPException: status 400 when the data URI header is malformed,
            or the payload cannot be base64 decoded or opened as an image.
    """
    try:
        if encoding.startswith("data:image/"):
            # Keep only the payload that follows the "<params>," part of
            # the header. This is done inside the try so that a truncated
            # header (no ";" or no ",") is reported as a bad request
            # instead of escaping as an unhandled IndexError.
            encoding = encoding.split(";", 1)[1].split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        print(err)
        # chain the cause so the underlying failure stays in the traceback
        raise HTTPException(
            status_code=400, detail="Invalid encoded image"
        ) from err
# get valid sd models/vaes/schedulers etc.
def get_predefined_models(custom_checkpoint_type: str):
    """Return the predefined models for a checkpoint type.

    "inpainting" selects the predefined paint models and "upscaler" the
    predefined upscaler models; any other value selects the general
    predefined models.
    """
    if custom_checkpoint_type == "inpainting":
        return predefined_paint_models
    if custom_checkpoint_type == "upscaler":
        return predefined_upscaler_models
    return predefined_models
def get_model_from_request(
    request_data=None,
    checkpoint_type: str = "",
    fallback_model: str = "",
):
    """Choose the model to use for an api request.

    Candidates are tried in this order, taking the first that is set:
      1. ``request_data.hf_model_id``
      2. ``request_data.override_settings.sd_model_checkpoint``
      3. the ``ckpt_loc`` command line argument
      4. the ``hf_model_id`` command line argument

    Args:
        request_data: the request model, or None to use only the command
            line arguments.
        checkpoint_type: checkpoint type used to look up both the custom
            model files and the predefined models that count as valid.
        fallback_model: returned when the chosen candidate is not valid
            for ``checkpoint_type``.

    Returns:
        The chosen candidate if it is among the custom model files or
        predefined models for ``checkpoint_type``, else ``fallback_model``.
    """
    requested = None
    if request_data:
        if request_data.hf_model_id:
            requested = request_data.hf_model_id
        else:
            # The base GenerationInputData model has no override_settings
            # field, so read it defensively rather than assuming every
            # request model declares it.
            override_settings = getattr(
                request_data, "override_settings", None
            )
            if override_settings:
                requested = override_settings.sd_model_checkpoint

    # if the request didn't specify a model try the command line args
    candidate = requested or frozen_args.ckpt_loc or frozen_args.hf_model_id

    # make sure whatever we have is a valid model for the checkpoint type,
    # if not return what was specified as the fallback
    valid_models = get_custom_model_files(
        custom_checkpoint_type=checkpoint_type
    ) + get_predefined_models(checkpoint_type)
    return candidate if candidate in valid_models else fallback_model
def get_scheduler_from_request(
    request_data: GenerationInputData, operation: str
):
    """Resolve the scheduler to use for a request.

    The request's ``sampler_name`` is first translated through
    ``sampler_aliases`` (names without an alias are used as given). The
    result is returned if it is allowed for ``operation``; otherwise the
    fallback scheduler configured for that operation is returned.

    ``operation`` must be a key of ``allowed_schedulers``.
    """
    rules = allowed_schedulers[operation]
    name = request_data.sampler_name
    scheduler = sampler_aliases.get(name, name)
    if scheduler in rules["schedulers"]:
        return scheduler
    return rules["fallback"]
def get_lora_params(use_lora: str):
    """Split a single lora setting into a two element tuple.

    Returns ``(use_lora, "")`` when ``use_lora`` is one of the custom
    "lora" model files, and ``("None", use_lora)`` otherwise.
    """
    # TODO: since the inference functions in the webui, which we are
    # still calling into for the api, jam these back together again before
    # handing them off to the pipeline, we should remove this nonsense
    # and unify their selection in the UI and command line args proper
    is_custom_lora_file = use_lora in get_custom_model_files("lora")
    return (use_lora, "") if is_custom_lora_file else ("None", use_lora)
def get_device(device_str: str):
    """Return the first available device whose name contains ``device_str``.

    Matching is by substring, in the order of ``available_devices``. When
    no device matches, the first available device is returned.
    """
    # resolved up front, before any matching is attempted
    default_device = available_devices[0]
    for candidate in available_devices:
        if device_str in candidate:
            return candidate
    return default_device