fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)

* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
This commit is contained in:
xzuyn
2023-07-05 12:00:16 -04:00
committed by GitHub
parent a1b1ce935c
commit 043e5a5c7a
16 changed files with 350 additions and 182 deletions

View File

@@ -13,6 +13,10 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,

View File

@@ -17,6 +17,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -67,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
@@ -78,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,

View File

@@ -15,6 +15,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -67,7 +74,8 @@ class StableDiffusionPipeline:
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
@@ -82,7 +90,8 @@ class StableDiffusionPipeline:
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
@@ -310,6 +319,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
@@ -394,16 +407,21 @@ class StableDiffusionPipeline:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
The max multiple length of prompt embeddings compared to the
max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
@@ -422,9 +440,11 @@ class StableDiffusionPipeline:
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
@@ -437,14 +457,36 @@ class StableDiffusionPipeline:
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
@@ -486,7 +528,8 @@ re_attention = re.compile(
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12

View File

@@ -41,9 +41,28 @@ def get_schedulers(model_id):
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,

View File

@@ -33,9 +33,10 @@ p.add_argument(
"--prompts",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near mountains at "
"high speeds with smokes coming off the tires, front angle, front point of view, "
"trees in the mountains of the background, ((sharp focus))"
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
@@ -44,8 +45,8 @@ p.add_argument(
"--negative_prompts",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, "
"ugly, blur, oversaturated, cropped"
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
@@ -119,15 +120,16 @@ p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max output "
"length of text encoder.",
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for img2img.",
help="The strength of change applied on the given input image for "
"img2img.",
)
##############################################################################
@@ -228,7 +230,8 @@ p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0).",
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
@@ -254,16 +257,16 @@ p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise downloads "
"the model from shark_tank.",
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="Attempts to load the model from a precompiled flat-buffer and compiles "
"+ saves it if not found.",
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
@@ -291,16 +294,19 @@ p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, DPMSolverMultistep, "
"EulerDiscrete, EulerAncestralDiscrete, DEISMultistep, KDPM2AncestralDiscrete, "
"DPMSolverSinglestep, DDPM, HeunDiscrete]",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. Supported options: jpg / png.",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
@@ -314,7 +320,8 @@ p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batch to be generated with random seeds in single execution.",
help="Number of batch to be generated with random seeds in "
"single execution.",
)
p.add_argument(
@@ -328,7 +335,8 @@ p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
@@ -349,7 +357,8 @@ p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer).",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
@@ -362,16 +371,18 @@ p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB).",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. This is currently "
"in experimental phase. Currently, only runs the stable-diffusion-2-1-base "
"model in int8 quantization.",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
@@ -409,7 +420,8 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G.",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
p.add_argument(
@@ -433,34 +445,38 @@ p.add_argument(
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the "
"default is ~/.local/shark_tank/.",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
help="When enabled call amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='Dispatches to return benchmark data on. Use "All" for all, and None for none.',
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='Directory where you want to store dispatch data generated with "--dispatch_benchmarks".',
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations for use with rgp.",
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
@@ -474,38 +490,39 @@ p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for clip and vae [>= 0].",
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. Recompiling will take "
"several minutes.",
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information json file with the image.",
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in PNG chunk text to "
"generated images.",
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="If import_mlir is True, saves mlir via the debug option in shark importer. Does "
"nothing if import_mlir is false (the default).",
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
##############################################################################
# Web UI flags
@@ -515,14 +532,16 @@ p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during image generation.",
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI.",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
@@ -557,16 +576,16 @@ p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing images under "
"--output_dir in the UI.",
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should follow symlinks when "
"listing subdirectories under --output_dir.",
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)

View File

@@ -208,14 +208,15 @@ def get_device_mapping(driver, key_combination=3):
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination
of name/path.
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
@@ -242,10 +243,12 @@ def get_device_mapping(driver, key_combination=3):
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
@@ -253,8 +256,8 @@ def map_device_to_name_path(device, key_combination=3):
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device
depending on key_combination value
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
@@ -277,7 +280,8 @@ def set_init_device_flags():
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
f"Found device {device_name}. Using target triple "
f"{args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
@@ -288,7 +292,8 @@ def set_init_device_flags():
if triple is not None:
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -386,7 +391,8 @@ def set_init_device_flags():
if args.use_tuned:
print(
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
f"Using tuned models for {base_model_id}(fp16) on "
f"device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -537,10 +543,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
@@ -562,8 +568,8 @@ def convert_original_vae(vae_checkpoint):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/"
"stable-diffusion/v1-inference.yaml"
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
@@ -702,13 +708,15 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
# Update JSON data to contain an entry mapping model_to_run with
# base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
@@ -727,7 +735,8 @@ def clear_all():
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# Temporary workaround of deleting yaml files to incorporate
# diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
@@ -806,8 +815,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
f"[ERROR] Format {args.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append

View File

@@ -133,7 +133,8 @@ def img2img_inf(
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
f"Shark schedulers are not supported. Switching to EulerDiscrete "
f"scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -393,7 +394,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
@@ -507,9 +509,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
"on the left if you want to use a standalone "
"HuggingFace model ID for LoRA here "
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
@@ -633,7 +635,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,

View File

@@ -346,7 +346,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
@@ -401,9 +402,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
"on the left if you want to use a standalone "
"HuggingFace model ID for LoRA here "
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
@@ -534,7 +535,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,

View File

@@ -51,8 +51,8 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
placeholder="Select 'None' in the Models "
"dropdown on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
@@ -74,8 +74,8 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights "
"dropdown on the left if you want to use a "
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",

View File

@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
path = (
f"https://civitai.com/api/v1/models?limit="
f"{num_of_models}&types=Checkpoint"
)
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
model_number = gr.Slider(
1,
100,
value=10,
@@ -111,9 +114,9 @@ with gr.Blocks() as model_web:
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
def get_model_list(model_source, model_number):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
hf_model_list = get_hf_list(model_number)
models = []
for model in hf_model_list:
# TODO: add model info
@@ -124,7 +127,7 @@ with gr.Blocks() as model_web:
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
civit_model_list = get_civit_list(model_number)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
@@ -148,7 +151,7 @@ with gr.Blocks() as model_web:
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
inputs=[model_source, model_number],
outputs=[
hf_models,
civit_models,

View File

@@ -265,7 +265,9 @@ def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
@@ -352,7 +354,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
@@ -404,9 +407,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
"on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
@@ -559,7 +562,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,

View File

@@ -38,14 +38,14 @@ def output_subdirs() -> list[str]:
)
]
# It is less confusing to always including the subdir that will take any images generated
# today even if it doesn't exist yet
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that that the date named ones we probably created in this or
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
# after.
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
@@ -66,7 +66,8 @@ with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
@@ -194,14 +195,18 @@ with gr.Blocks() as outputgallery_web:
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most recent valid one
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown.update(
@@ -220,17 +225,22 @@ with gr.Blocks() as outputgallery_web:
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab has even been selected
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as new images only go there
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
@@ -245,11 +255,13 @@ with gr.Blocks() as outputgallery_web:
),
]
else:
# otherwise change nothing, (only untyped gradio gr.update() does this)
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for the current subdirectory
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
@@ -267,7 +279,8 @@ with gr.Blocks() as outputgallery_web:
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether an image is selected
            # disable or enable each of the sendto buttons based on whether
            # an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
@@ -276,14 +289,16 @@ with gr.Blocks() as outputgallery_web:
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh to populate
# the subdirectory select box and the images from the most recent subdirectory.
    # The first time our tab is selected we need to do an initial refresh
    # to populate the subdirectory select box and the images from the most
    # recent subdirectory.
#
# We do it at this point rather than setting this up in the controls' definitions
# as when you refresh the browser you always get what was *initially* set, which
# won't include any new subdirectories or images that might have created since
# the application was started. Doing it this way means a browser refresh/reload
# always gets the most up to date data.
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
    # *initially* set, which won't include any new subdirectories or images
    # that might have been created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths):
if len(subdir_paths) == 0:
return on_refresh("")
@@ -297,11 +312,11 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
# things set with .style, nor the elem_classes kwarg so we have to directly set
# things up via JavaScript if we want the client to take notice of any of our
# changes to the number of columns after it decides to put them back to the
# original number when we change something
# Unfortunately as of gradio 3.22.0 gr.update against Galleries
# doesn't support things set with .style, nor the elem_classes kwarg, so
# we have to directly set things up via JavaScript if we want the client
# to take notice of our changes to the number of columns after it
# decides to put them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
@@ -318,32 +333,36 @@ with gr.Blocks() as outputgallery_web:
# --- Wire handlers up to the actions
# - Many actions reset the number of columns shown in the gallery on the browser end,
# so we have to set them back to what we think they should be after the initial
# action.
# - None of the actions on this tab trigger inference, and we want the user to be able
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
# behind inference jobs would mean the UI can't fully respond until the inference tasks
# complete, hence queue=False on all of these.
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real timeout length for the
# number of columns to actually be applied. Not really sure why, maybe something has
# to finish animating?
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the gallery avoids current
# images being shown replacing piecemeal and prevents weirdness and errors if the user
# selects an image during the replacement phase.
    # clearing images when we need to completely change what's in the
    # gallery avoids the current images being replaced piecemeal, and
    # prevents weirdness and errors if the user selects an image during the
    # replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,

View File

@@ -8,13 +8,15 @@ from transformers import (
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = (
"<|SYSTEM|># StableLM Tuned (Alpha version)- StableLM is a helpful and "
"harmless open-source AI language model developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse to do "
"anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also able to "
"write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that could harm a human."
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
)
@@ -91,7 +93,8 @@ def chat(curr_system_message, history, model, device, precision):
"StableLM"
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the current system message and conversation history
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
@@ -111,7 +114,8 @@ def chat(curr_system_message, history, model, device, precision):
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated conversation history
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list

View File

@@ -87,8 +87,8 @@ def txt2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both "
"must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -316,10 +316,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the dropdown on the left and "
"enter model ID here.",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model download URL.",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
)
# janky fix for overflowing text
@@ -375,10 +376,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
"on the left if you want to use a standalone "
"HuggingFace model ID for LoRA here e.g: "
"sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -509,7 +510,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,

View File

@@ -88,8 +88,8 @@ def upscaler_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -227,10 +227,22 @@ def upscaler_inf(
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={steps}, "
f"noise_level={noise_level}, "
f"guidance_scale={guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={height}x{width}, "
f"batch_count={batch_count}, "
f"batch_size={batch_size}, "
f"max_length={args.max_length}"
)
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
@@ -271,7 +283,9 @@ def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
@@ -353,7 +367,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
@@ -397,7 +412,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
elem_id="lora_weights",
value="None",
@@ -405,9 +420,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
"on the left if you want to use a standalone "
"HuggingFace model ID for LoRA here "
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
@@ -539,7 +554,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,

View File

@@ -39,6 +39,9 @@ scheduler_list_cpu_only = [
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"DPMSolverMultistep++",
"DPMSolverMultistepKarras",
"DPMSolverMultistepKarras++",
"EulerDiscrete",
"EulerAncestralDiscrete",
"DEISMultistep",
@@ -86,7 +89,8 @@ def create_custom_models_folders():
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)