Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-02-19 11:56:43 -05:00
* Merge LoRA selection in the UI into a single selector, rather than one for LoRAs under ./models and another for a Hugging Face ID.
* Add LoRA strength to UI and pipeline parameters.
* Add a `--lora_strength` command line argument.
* Bake LoRA strength into .vmfb naming when a LoRA is specified.
* Use the LoRA's embedded alpha values and (up tensor dimension * LoRA strength) for the final alpha when applying LoRA weights, rather than a hardcoded value of 0.75.
* Add the additional LoRA weight application cases that are present in the Kohya scripts.
* Include LoRA strength when reading and writing PNG metadata.
* Allow lora_strength to be set above 1.0 in the UI, so effects similar to the prior (overdriven alpha) implementation can be obtained.
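A minimal sketch of the revised alpha computation, assuming the Kohya convention and that "up tensor dimension" means the rank shared by the up and down matrices; the names below are illustrative, not the pipeline's own:

    def lora_delta(up, down, embedded_alpha, lora_strength):
        # this scale replaces the previously hardcoded alpha of 0.75
        rank = up.shape[1]  # the "up tensor dimension" shared with `down`
        scale = (embedded_alpha / rank) * lora_strength
        return scale * (up @ down)  # works for torch or numpy matrices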
201 lines
5.9 KiB
Python
import base64
import pickle

from argparse import Namespace
from fastapi.exceptions import HTTPException
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field

from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
    available_devices,
    get_custom_model_files,
    predefined_models,
    predefined_paint_models,
    predefined_upscaler_models,
    scheduler_list,
    scheduler_list_cpu_only,
)


# Probably overly cautious, but try to ensure we only use the starting
# args in each api call, as the code does `args.<whatever> = <changed_value>`
# in lots of places, and in testing those changes seemed to leak into
# subsequent api calls.

# Roundtripping through pickle to get a deep copy; there is probably a better way
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
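# (copy.deepcopy(vars(args)) would likely do the same job)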

# an attempt to map some of the A1111 sampler names to scheduler names
# https://github.com/huggingface/diffusers/issues/4167 is where the
# (not so obvious) ones come from
sampler_aliases = {
    # a1111/onnx (these point to diffusers classes in A1111)
    "pndm": "PNDM",
    "heun": "HeunDiscrete",
    "ddim": "DDIM",
    "ddpm": "DDPM",
    "euler": "EulerDiscrete",
    "euler-ancestral": "EulerAncestralDiscrete",
    "dpm": "DPMSolverMultistep",
    # a1111/k_diffusion (the obvious ones)
    "Euler a": "EulerAncestralDiscrete",
    "Euler": "EulerDiscrete",
    "LMS": "LMSDiscrete",
    "Heun": "HeunDiscrete",
    # a1111/k_diffusion (not so obvious)
    "DPM++ 2M": "DPMSolverMultistep",
    "DPM++ 2M Karras": "DPMSolverMultistepKarras",
    "DPM++ 2M SDE": "DPMSolverMultistep++",
    "DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
    "DPM2": "KDPM2Discrete",
    "DPM2 a": "KDPM2AncestralDiscrete",
}
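
# Unrecognised names pass through unchanged, e.g.:
#   sampler_aliases.get("Euler a", "Euler a")  # -> "EulerAncestralDiscrete"
#   sampler_aliases.get("DDIM", "DDIM")        # -> "DDIM" (no alias needed)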

allowed_schedulers = {
    "txt2img": {
        "schedulers": scheduler_list,
        "fallback": "SharkEulerDiscrete",
    },
    "txt2img_hires": {
        "schedulers": scheduler_list_cpu_only,
        "fallback": "DEISMultistep",
    },
    "img2img": {
        "schedulers": scheduler_list_cpu_only,
        "fallback": "EulerDiscrete",
    },
    "inpaint": {
        "schedulers": scheduler_list_cpu_only,
        "fallback": "DDIM",
    },
    "outpaint": {
        "schedulers": scheduler_list_cpu_only,
        "fallback": "DDIM",
    },
    "upscaler": {
        "schedulers": scheduler_list_cpu_only,
        "fallback": "DDIM",
    },
}
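
# Each operation accepts only the schedulers in its list, e.g.:
#   allowed_schedulers["img2img"]["fallback"]  # -> "EulerDiscrete"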

# base pydantic model for sd generation apis


class GenerationInputData(BaseModel):
    prompt: str = ""
    negative_prompt: str = ""
    hf_model_id: str | None = None
    height: int = Field(
        default=frozen_args.height, ge=128, le=768, multiple_of=8
    )
    width: int = Field(
        default=frozen_args.width, ge=128, le=768, multiple_of=8
    )
    sampler_name: str = frozen_args.scheduler
    cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
    steps: int = Field(default=frozen_args.steps, ge=1, le=100)
    seed: int = frozen_args.seed
    n_iter: int = Field(default=frozen_args.batch_count)
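
# e.g. GenerationInputData(prompt="an astronaut riding a horse", steps=20)
# leaves every other field at the frozen command line defaults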


class GenerationResponseData(BaseModel):
    images: list[str] = Field(description="Generated images, Base64 encoded")
    properties: dict = {}
    info: str


# image encoding/decoding


def encode_pil_to_base64(images: list[Image.Image]):
    encoded_imgs = []
    for image in images:
        with BytesIO() as output_bytes:
            if frozen_args.output_img_format.lower() == "png":
                image.save(output_bytes, format="PNG")
            elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
                image.save(output_bytes, format="JPEG")
            else:
                raise HTTPException(
                    status_code=500, detail="Invalid image format"
                )
            bytes_data = output_bytes.getvalue()
            encoded_imgs.append(base64.b64encode(bytes_data))
    return encoded_imgs


def decode_base64_to_image(encoding: str):
    if encoding.startswith("data:image/"):
        encoding = encoding.split(";", 1)[1].split(",", 1)[1]
    try:
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as err:
        print(err)
        raise HTTPException(status_code=400, detail="Invalid encoded image")
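
# Accepts either a raw base64 string or a "data:image/png;base64,..." data
# URL, so e.g. decode_base64_to_image(encode_pil_to_base64([img])[0].decode())
# round trips a PIL image.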


# get valid sd models/vaes/schedulers etc.


def get_predefined_models(custom_checkpoint_type: str):
    match custom_checkpoint_type:
        case "inpainting":
            return predefined_paint_models
        case "upscaler":
            return predefined_upscaler_models
        case _:
            return predefined_models


def get_model_from_request(
    request_data=None,
    checkpoint_type: str = "",
    fallback_model: str = "",
):
    model = None
    if request_data:
        if request_data.hf_model_id:
            model = request_data.hf_model_id
        elif request_data.override_settings:
            model = request_data.override_settings.sd_model_checkpoint

    # if the request didn't specify a model, try the command line args
    result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id

    # make sure whatever we have is a valid model for the checkpoint type
    if result in get_custom_model_files(
        custom_checkpoint_type=checkpoint_type
    ) + get_predefined_models(checkpoint_type):
        return result
    # if not, return what was specified as the fallback
    else:
        return fallback_model


def get_scheduler_from_request(
    request_data: GenerationInputData, operation: str
):
    allowed = allowed_schedulers[operation]

    requested = request_data.sampler_name
    requested = sampler_aliases.get(requested, requested)

    return (
        requested
        if requested in allowed["schedulers"]
        else allowed["fallback"]
    )
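
# e.g. sampler_name="Euler a" resolves to "EulerAncestralDiscrete" via
# sampler_aliases and is kept if the operation's scheduler list allows it;
# any name not in the list falls back to the operation's fallback scheduler.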


def get_device(device_str: str):
    # first substring match in the list of available devices, falling back
    # to the first device when none match
    return next(
        (device for device in available_devices if device_str in device),
        available_devices[0],
    )
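
# e.g. get_device("vulkan") returns the first available device whose name
# contains "vulkan", or available_devices[0] when nothing matches.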