mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-14 00:08:01 -05:00
* SD/API: Add missing a1111 API features for Koboldcpp * Refactors SD api functions into their own file * Adds the following apis implemented by a1111 as needed by koboldcpp: - adds /sdapi/v1/sd-models (lists available models) - adds /sdapi/v1/options (only the bare minimum needed) * Adds optional CORS support, use the '--api_accept_origin' command line argument to activate and configure. * Extends existing APIs to include optional sampler/scheduler selection * Extends /sdapi/v1/textimg to recognise the method used by koboldcpp to select the model. * Where possible take values not provided to the API in the request from the existing relevant command line parameters rather than hardcoding them. * return a 400 response when a request doesn't have required properties. * changed default schedulers and models for some apis to ones that actually seem to work. * Update api_test.py to include the new APIs. * Update api_test.py to include a '--verbose' command line option. * SD/API: Take more API values from args * Take LoRA from '--use_lora' command line arg if specified * Take device from '--device' command line arg if specified (substring match, so a short name such as 'vulkan://0' should work) * SD/API: add more endpoints and pydantic typing * Mount the whole of /sdapi from index.py as a FastAPI application, rather than each endpoint individually * Add the following additional API endpoints: * /sdapi/v1/samplers * /sdapi/v1/cmd-flags * Make scheduler/sampler selection checking and fallback much more robust. * Support aliasing some A1111 scheduler/sampler names to the diffusers ones we are using. * Expand response /sdapi/v1/options to add a few more things. * Split non-api functions and variables into their own utils.py file. * Support 'n_iter' request property and the return of multiple images from generation endpoints. Equivalent of '--batch_count', batch_size is stil hardcoded at 1 * Include (some) hires_fix request properties in txt2img endpoint * Rework endpoints using pydantic model classes for better request validation and so we get much improved swagger api docs at /sdapi/docs and redoc at /sdapi/redoc * SD/API Delete commented out code from index.py * Delete some code that is no longer needed by the SD API in index.py (and one line sdapi_v1.py) that I'd previously only commented out. * SD/UI: Add shark_sd_koboldcpp.md document * Add documentation on how to set up Koboldcpp with SHARK * Link this and the existing blender set up document from the main README.md * SD/API Improve stencil options in img2img endpoint In /sdapi/v1/img2img: * Add zoedepth to the controlnet use_stencil options * Require and use second image as stencil mask for controlnet scribble
212 lines
6.4 KiB
Python
212 lines
6.4 KiB
Python
import base64
|
|
import pickle
|
|
|
|
from argparse import Namespace
|
|
from fastapi.exceptions import HTTPException
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
from pydantic import BaseModel, Field
|
|
|
|
from apps.stable_diffusion.src import args
|
|
from apps.stable_diffusion.web.ui.utils import (
|
|
available_devices,
|
|
get_custom_model_files,
|
|
predefined_models,
|
|
predefined_paint_models,
|
|
predefined_upscaler_models,
|
|
scheduler_list,
|
|
scheduler_list_cpu_only,
|
|
)
|
|
|
|
|
|
# Probably overly cautious, but try to ensure we only use the starting
|
|
# args in each api call, as the code does `args.<whatever> = <changed_value>`
|
|
# in lots of places and in testing, it seemed to me, these changes leaked
|
|
# into subsequent api calls.
|
|
|
|
# Roundtripping through pickle for deepcopy, there is probably a better way
|
|
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
|
|
|
|
# an attempt to map some of the A1111 sampler names to scheduler names
|
|
# https://github.com/huggingface/diffusers/issues/4167 is where the
|
|
# (not so obvious) ones come from
|
|
sampler_aliases = {
|
|
# a1111/onnx (these point to diffusers classes in A1111)
|
|
"pndm": "PNDM",
|
|
"heun": "HeunDiscrete",
|
|
"ddim": "DDIM",
|
|
"ddpm": "DDPM",
|
|
"euler": "EulerDiscrete",
|
|
"euler-ancestral": "EulerAncestralDiscrete",
|
|
"dpm": "DPMSolverMultistep",
|
|
# a1111/k_diffusion (the obvious ones)
|
|
"Euler a": "EulerAncestralDiscrete",
|
|
"Euler": "EulerDiscrete",
|
|
"LMS": "LMSDiscrete",
|
|
"Heun": "HeunDiscrete",
|
|
# a1111/k_diffusion (not so obvious)
|
|
"DPM++ 2M": "DPMSolverMultistep",
|
|
"DPM++ 2M Karras": "DPMSolverMultistepKarras",
|
|
"DPM++ 2M SDE": "DPMSolverMultistep++",
|
|
"DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
|
|
"DPM2": "KDPM2Discrete",
|
|
"DPM2 a": "KDPM2AncestralDiscrete",
|
|
}
|
|
|
|
allowed_schedulers = {
|
|
"txt2img": {
|
|
"schedulers": scheduler_list,
|
|
"fallback": "SharkEulerDiscrete",
|
|
},
|
|
"txt2img_hires": {
|
|
"schedulers": scheduler_list_cpu_only,
|
|
"fallback": "DEISMultistep",
|
|
},
|
|
"img2img": {
|
|
"schedulers": scheduler_list_cpu_only,
|
|
"fallback": "EulerDiscrete",
|
|
},
|
|
"inpaint": {
|
|
"schedulers": scheduler_list_cpu_only,
|
|
"fallback": "DDIM",
|
|
},
|
|
"outpaint": {
|
|
"schedulers": scheduler_list_cpu_only,
|
|
"fallback": "DDIM",
|
|
},
|
|
"upscaler": {
|
|
"schedulers": scheduler_list_cpu_only,
|
|
"fallback": "DDIM",
|
|
},
|
|
}
|
|
|
|
# base pydantic model for sd generation apis
|
|
|
|
|
|
class GenerationInputData(BaseModel):
|
|
prompt: str = ""
|
|
negative_prompt: str = ""
|
|
hf_model_id: str | None = None
|
|
height: int = Field(
|
|
default=frozen_args.height, ge=128, le=768, multiple_of=8
|
|
)
|
|
width: int = Field(
|
|
default=frozen_args.width, ge=128, le=768, multiple_of=8
|
|
)
|
|
sampler_name: str = frozen_args.scheduler
|
|
cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
|
|
steps: int = Field(default=frozen_args.steps, ge=1, le=100)
|
|
seed: int = frozen_args.seed
|
|
n_iter: int = Field(default=frozen_args.batch_count)
|
|
|
|
|
|
class GenerationResponseData(BaseModel):
|
|
images: list[str] = Field(description="Generated images, Base64 encoded")
|
|
properties: dict = {}
|
|
info: str
|
|
|
|
|
|
# image encoding/decoding
|
|
|
|
|
|
def encode_pil_to_base64(images: list[Image.Image]):
|
|
encoded_imgs = []
|
|
for image in images:
|
|
with BytesIO() as output_bytes:
|
|
if frozen_args.output_img_format.lower() == "png":
|
|
image.save(output_bytes, format="PNG")
|
|
|
|
elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
|
|
image.save(output_bytes, format="JPEG")
|
|
else:
|
|
raise HTTPException(
|
|
status_code=500, detail="Invalid image format"
|
|
)
|
|
bytes_data = output_bytes.getvalue()
|
|
encoded_imgs.append(base64.b64encode(bytes_data))
|
|
return encoded_imgs
|
|
|
|
|
|
def decode_base64_to_image(encoding: str):
|
|
if encoding.startswith("data:image/"):
|
|
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
|
try:
|
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
|
return image
|
|
except Exception as err:
|
|
print(err)
|
|
raise HTTPException(status_code=400, detail="Invalid encoded image")
|
|
|
|
|
|
# get valid sd models/vaes/schedulers etc.
|
|
|
|
|
|
def get_predefined_models(custom_checkpoint_type: str):
|
|
match custom_checkpoint_type:
|
|
case "inpainting":
|
|
return predefined_paint_models
|
|
case "upscaler":
|
|
return predefined_upscaler_models
|
|
case _:
|
|
return predefined_models
|
|
|
|
|
|
def get_model_from_request(
|
|
request_data=None,
|
|
checkpoint_type: str = "",
|
|
fallback_model: str = "",
|
|
):
|
|
model = None
|
|
if request_data:
|
|
if request_data.hf_model_id:
|
|
model = request_data.hf_model_id
|
|
elif request_data.override_settings:
|
|
model = request_data.override_settings.sd_model_checkpoint
|
|
|
|
# if the request didn't specify a model try the command line args
|
|
result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id
|
|
|
|
# make sure whatever we have is a valid model for the checkpoint type
|
|
if result in get_custom_model_files(
|
|
custom_checkpoint_type=checkpoint_type
|
|
) + get_predefined_models(checkpoint_type):
|
|
return result
|
|
# if not return what was specified as the fallback
|
|
else:
|
|
return fallback_model
|
|
|
|
|
|
def get_scheduler_from_request(
|
|
request_data: GenerationInputData, operation: str
|
|
):
|
|
allowed = allowed_schedulers[operation]
|
|
|
|
requested = request_data.sampler_name
|
|
requested = sampler_aliases.get(requested, requested)
|
|
|
|
return (
|
|
requested
|
|
if requested in allowed["schedulers"]
|
|
else allowed["fallback"]
|
|
)
|
|
|
|
|
|
def get_lora_params(use_lora: str):
|
|
# TODO: since the inference functions in the webui, which we are
|
|
# still calling into for the api, jam these back together again before
|
|
# handing them off to the pipeline, we should remove this nonsense
|
|
# and unify their selection in the UI and command line args proper
|
|
if use_lora in get_custom_model_files("lora"):
|
|
return (use_lora, "")
|
|
|
|
return ("None", use_lora)
|
|
|
|
|
|
def get_device(device_str: str):
|
|
# first substring match in the list available devices, with first
|
|
# device when none are matched
|
|
return next(
|
|
(device for device in available_devices if device_str in device),
|
|
available_devices[0],
|
|
)
|