mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message` - also more formatting changes * fix variable spelling mistake * revert some formatting cause black wants it different * one less line, still less than 79 * add ++, karras, and karras++ types of dpmsolver. * black line length 79 --------- Co-authored-by: powderluv <powderluv@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,10 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
|
||||
@@ -17,6 +17,9 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -67,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
low_res_scheduler: Union[
|
||||
DDIMScheduler,
|
||||
@@ -78,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
|
||||
@@ -15,6 +15,9 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -67,7 +74,8 @@ class StableDiffusionPipeline:
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.ondemand = ondemand
|
||||
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
# enough for CLIPTokenizer.
|
||||
try:
|
||||
self.tokenizer = get_tokenizer()
|
||||
except:
|
||||
@@ -82,7 +90,8 @@ class StableDiffusionPipeline:
|
||||
if self.import_mlir or self.use_lora:
|
||||
if not self.import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
|
||||
"Warning: LoRA provided but import_mlir not specified. "
|
||||
"Importing MLIR anyways."
|
||||
)
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
else:
|
||||
@@ -310,6 +319,10 @@ class StableDiffusionPipeline:
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
import_mlir: bool,
|
||||
model_id: str,
|
||||
@@ -394,16 +407,21 @@ class StableDiffusionPipeline:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
The prompt or prompts not to guide the image generation.
|
||||
Ignored when not using guidance
|
||||
(i.e., ignored if `guidance_scale` is less than `1`).
|
||||
model_max_length (int):
|
||||
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
|
||||
SHARK: pass the max length instead of relying on
|
||||
pipe.tokenizer.model_max_length
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not,
|
||||
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
|
||||
SHARK: must be set to True as we always expect neg embeddings
|
||||
(defaulted to True)
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
|
||||
The max multiple length of prompt embeddings compared to the
|
||||
max output length of text encoder.
|
||||
SHARK: max_embeddings_multiples>1 produce a tensor shape error
|
||||
(defaulted to 1)
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
SHARK: num_images_per_prompt is not used (defaulted to 1)
|
||||
@@ -422,9 +440,11 @@ class StableDiffusionPipeline:
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
if batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
f"`negative_prompt`: "
|
||||
f"{negative_prompt} has batch size {len(negative_prompt)}, "
|
||||
f"but `prompt`: {prompt} has batch size {batch_size}. "
|
||||
f"Please make sure that passed `negative_prompt` matches "
|
||||
"the batch size of `prompt`."
|
||||
)
|
||||
|
||||
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
||||
@@ -437,14 +457,36 @@ class StableDiffusionPipeline:
|
||||
)
|
||||
# SHARK: we are not using num_images_per_prompt
|
||||
# bs_embed, seq_len, _ = text_embeddings.shape
|
||||
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
# text_embeddings = text_embeddings.repeat(
|
||||
# 1,
|
||||
# num_images_per_prompt,
|
||||
# 1
|
||||
# )
|
||||
# text_embeddings = (
|
||||
# text_embeddings.view(
|
||||
# bs_embed * num_images_per_prompt,
|
||||
# seq_len,
|
||||
# -1
|
||||
# )
|
||||
# )
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# SHARK: we are not using num_images_per_prompt
|
||||
# bs_embed, seq_len, _ = uncond_embeddings.shape
|
||||
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
# uncond_embeddings = (
|
||||
# uncond_embeddings.repeat(
|
||||
# 1,
|
||||
# num_images_per_prompt,
|
||||
# 1
|
||||
# )
|
||||
# )
|
||||
# uncond_embeddings = (
|
||||
# uncond_embeddings.view(
|
||||
# bs_embed * num_images_per_prompt,
|
||||
# seq_len,
|
||||
# -1
|
||||
# )
|
||||
# )
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
if text_embeddings.shape[1] > model_max_length:
|
||||
@@ -486,7 +528,8 @@ re_attention = re.compile(
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Parses a string with attention tokens and returns a list of pairs:
|
||||
text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
|
||||
@@ -41,9 +41,28 @@ def get_schedulers(model_id):
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
algorithm_type="dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
|
||||
@@ -33,9 +33,10 @@ p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near mountains at "
|
||||
"high speeds with smokes coming off the tires, front angle, front point of view, "
|
||||
"trees in the mountains of the background, ((sharp focus))"
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smokes coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
@@ -44,8 +45,8 @@ p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=[
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, "
|
||||
"ugly, blur, oversaturated, cropped"
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
|
||||
"blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
help="Text you don't want to see in the generated image.",
|
||||
)
|
||||
@@ -119,15 +120,16 @@ p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max output "
|
||||
"length of text encoder.",
|
||||
help="The max multiple length of prompt embeddings compared to the max "
|
||||
"output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="The strength of change applied on the given input image for img2img.",
|
||||
help="The strength of change applied on the given input image for "
|
||||
"img2img.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
@@ -228,7 +230,8 @@ p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0).",
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) "
|
||||
"(min=0.0, max=4.0).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -254,16 +257,16 @@ p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise downloads "
|
||||
"the model from shark_tank.",
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Attempts to load the model from a precompiled flat-buffer and compiles "
|
||||
"+ saves it if not found.",
|
||||
help="Attempts to load the model from a precompiled flat-buffer "
|
||||
"and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -291,16 +294,19 @@ p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="SharkEulerDiscrete",
|
||||
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, DPMSolverMultistep, "
|
||||
"EulerDiscrete, EulerAncestralDiscrete, DEISMultistep, KDPM2AncestralDiscrete, "
|
||||
"DPMSolverSinglestep, DDPM, HeunDiscrete]",
|
||||
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
|
||||
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
|
||||
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
|
||||
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
|
||||
"HeunDiscrete].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="Specify the format in which output image is save. Supported options: jpg / png.",
|
||||
help="Specify the format in which output image is save. "
|
||||
"Supported options: jpg / png.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -314,7 +320,8 @@ p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of batch to be generated with random seeds in single execution.",
|
||||
help="Number of batch to be generated with random seeds in "
|
||||
"single execution.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -328,7 +335,8 @@ p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
|
||||
"needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -349,7 +357,8 @@ p.add_argument(
|
||||
"--attention_slicing",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer).",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
|
||||
"or an integer).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -362,16 +371,18 @@ p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
default="",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB).",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint "
|
||||
"file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Runs the quantized version of stable diffusion model. This is currently "
|
||||
"in experimental phase. Currently, only runs the stable-diffusion-2-1-base "
|
||||
"model in int8 quantization.",
|
||||
help="Runs the quantized version of stable diffusion model. "
|
||||
"This is currently in experimental phase. "
|
||||
"Currently, only runs the stable-diffusion-2-1-base model in "
|
||||
"int8 quantization.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -409,7 +420,8 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G.",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -433,34 +445,38 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the "
|
||||
"default is ~/.local/shark_tank/.",
|
||||
help="Specify where to save downloaded shark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
help="When enabled call amdllpc to get ISA dumps. "
|
||||
"Use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='Dispatches to return benchmark data on. Use "All" for all, and None for none.',
|
||||
help="Dispatches to return benchmark data on. "
|
||||
'Use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='Directory where you want to store dispatch data generated with "--dispatch_benchmarks".',
|
||||
help="Directory where you want to store dispatch data "
|
||||
'generated with "--dispatch_benchmarks".',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for inserting debug frames between iterations for use with rgp.",
|
||||
help="Flag for inserting debug frames between iterations "
|
||||
"for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -474,38 +490,39 @@ p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Flag setting warmup count for clip and vae [>= 0].",
|
||||
help="Flag setting warmup count for CLIP and VAE [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to clear all mlir and vmfb from common locations. Recompiling will take "
|
||||
"several minutes.",
|
||||
help="Flag to clear all mlir and vmfb from common locations. "
|
||||
"Recompiling will take several minutes.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_metadata_to_json",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save a generation information json file with the image.",
|
||||
help="Flag for whether or not to save a generation information "
|
||||
"json file with the image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save generation information in PNG chunk text to "
|
||||
"generated images.",
|
||||
help="Flag for whether or not to save generation information in "
|
||||
"PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If import_mlir is True, saves mlir via the debug option in shark importer. Does "
|
||||
"nothing if import_mlir is false (the default).",
|
||||
help="If import_mlir is True, saves mlir via the debug option "
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
@@ -515,14 +532,16 @@ p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the progress bar animation during image generation.",
|
||||
help="Flag for removing the progress bar animation during "
|
||||
"image generation.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate them in the web UI.",
|
||||
help="Path to directory where all .ckpts are stored in order to populate "
|
||||
"them in the web UI.",
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
@@ -557,16 +576,16 @@ p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the output gallery tab, and avoid exposing images under "
|
||||
"--output_dir in the UI.",
|
||||
help="Flag for removing the output gallery tab, and avoid exposing "
|
||||
"images under --output_dir in the UI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether the output gallery tab in the UI should follow symlinks when "
|
||||
"listing subdirectories under --output_dir.",
|
||||
help="Flag for whether the output gallery tab in the UI should "
|
||||
"follow symlinks when listing subdirectories under --output_dir.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -208,14 +208,15 @@ def get_device_mapping(driver, key_combination=3):
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination
|
||||
of name/path.
|
||||
dict: map to possible device names user can input mapped to desired
|
||||
combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
@@ -242,10 +243,12 @@ def get_device_mapping(driver, key_combination=3):
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
@@ -253,8 +256,8 @@ def map_device_to_name_path(device, key_combination=3):
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device
|
||||
depending on key_combination value
|
||||
str / tuple: returns the mapping str or tuple of mapping str for
|
||||
the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
@@ -277,7 +280,8 @@ def set_init_device_flags():
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
@@ -288,7 +292,8 @@ def set_init_device_flags():
|
||||
if triple is not None:
|
||||
args.iree_metal_target_platform = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{args.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
@@ -386,7 +391,8 @@ def set_init_device_flags():
|
||||
|
||||
if args.use_tuned:
|
||||
print(
|
||||
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
|
||||
f"Using tuned models for {base_model_id}(fp16) on "
|
||||
f"device {args.device}."
|
||||
)
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
@@ -537,10 +543,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
from_safetensors = (
|
||||
True if custom_weights.lower().endswith(".safetensors") else False
|
||||
)
|
||||
# EMA weights usually yield higher quality images for inference but non-EMA weights have
|
||||
# been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
|
||||
# weight extraction or not.
|
||||
# EMA weights usually yield higher quality images for inference but
|
||||
# non-EMA weights have been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
|
||||
# they want to go for EMA weight extraction or not.
|
||||
extract_ema = False
|
||||
print(
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
@@ -562,8 +568,8 @@ def convert_original_vae(vae_checkpoint):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/"
|
||||
"stable-diffusion/v1-inference.yaml"
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
|
||||
"main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
@@ -702,13 +708,15 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
# Update JSON data to contain an entry mapping model_to_run with
|
||||
# base_model.
|
||||
json_data.update(data)
|
||||
with open(variants_path, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the supported range (including -1)
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
@@ -727,7 +735,8 @@ def clear_all():
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# Temporary workaround of deleting yaml files to incorporate
|
||||
# diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight updation logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
@@ -806,8 +815,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
f"[ERROR] Format {args.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
|
||||
@@ -133,7 +133,8 @@ def img2img_inf(
|
||||
image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete "
|
||||
f"scheduler"
|
||||
)
|
||||
args.scheduler = "EulerDiscrete"
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
@@ -393,7 +394,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
@@ -507,9 +509,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
|
||||
"on the left if you want to use a standalone "
|
||||
"HuggingFace model ID for LoRA here "
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -633,7 +635,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
|
||||
@@ -346,7 +346,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
@@ -401,9 +402,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
|
||||
"on the left if you want to use a standalone "
|
||||
"HuggingFace model ID for LoRA here "
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -534,7 +535,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
|
||||
@@ -51,8 +51,8 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
placeholder="Select 'None' in the Models "
|
||||
"dropdown on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -74,8 +74,8 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights "
|
||||
"dropdown on the left if you want to use a "
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use a "
|
||||
"standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
|
||||
@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
|
||||
|
||||
|
||||
def get_civit_list(num_of_models=50):
|
||||
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
|
||||
path = (
|
||||
f"https://civitai.com/api/v1/models?limit="
|
||||
f"{num_of_models}&types=Checkpoint"
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
raw_json = requests.get(path, headers=headers).json()
|
||||
models = list(raw_json.items())[0][1]
|
||||
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
|
||||
type="value",
|
||||
label="Model Source",
|
||||
)
|
||||
model_numebr = gr.Slider(
|
||||
model_number = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=10,
|
||||
@@ -111,9 +114,9 @@ with gr.Blocks() as model_web:
|
||||
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
|
||||
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
|
||||
|
||||
def get_model_list(model_source, model_numebr):
|
||||
def get_model_list(model_source, model_number):
|
||||
if model_source == "Hugging Face":
|
||||
hf_model_list = get_hf_list(model_numebr)
|
||||
hf_model_list = get_hf_list(model_number)
|
||||
models = []
|
||||
for model in hf_model_list:
|
||||
# TODO: add model info
|
||||
@@ -124,7 +127,7 @@ with gr.Blocks() as model_web:
|
||||
gr.Row.update(visible=True),
|
||||
)
|
||||
elif model_source == "Civitai":
|
||||
civit_model_list = get_civit_list(model_numebr)
|
||||
civit_model_list = get_civit_list(model_number)
|
||||
models = []
|
||||
for model in civit_model_list:
|
||||
image = get_image_from_model(model)
|
||||
@@ -148,7 +151,7 @@ with gr.Blocks() as model_web:
|
||||
|
||||
get_model_btn.click(
|
||||
fn=get_model_list,
|
||||
inputs=[model_source, model_numebr],
|
||||
inputs=[model_source, model_number],
|
||||
outputs=[
|
||||
hf_models,
|
||||
civit_models,
|
||||
|
||||
@@ -265,7 +265,9 @@ def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
@@ -352,7 +354,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
@@ -404,9 +407,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
|
||||
"on the left if you want to use a "
|
||||
"standalone HuggingFace model ID for LoRA here "
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -559,7 +562,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
|
||||
@@ -38,14 +38,14 @@ def output_subdirs() -> list[str]:
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always including the subdir that will take any images generated
|
||||
# today even if it doesn't exist yet
|
||||
# It is less confusing to always including the subdir that will take any
|
||||
# images generated today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that that the date named ones we probably created in this or
|
||||
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
|
||||
# after.
|
||||
# sort subdirectories so that the date named ones we probably
|
||||
# created in this or previous sessions come first, sorted with the most
|
||||
# recent first. Other subdirs are listed after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
@@ -66,7 +66,8 @@ with gr.Blocks() as outputgallery_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
|
||||
# needed to workaround gradio issue:
|
||||
# https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
@@ -194,14 +195,18 @@ with gr.Blocks() as outputgallery_web:
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up-to-date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most recent valid one
|
||||
# get the images using either the current subdirectory or the most
|
||||
# recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, new_subdir)}"
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
@@ -220,17 +225,22 @@ with gr.Blocks() as outputgallery_web:
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab has even been selected
|
||||
# prevent error triggered when an image generates before the tab
|
||||
# has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
# only update if the current subdir is the most recent one as
|
||||
# new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, subdir)} - {status}"
|
||||
)
|
||||
|
||||
return [
|
||||
new_images,
|
||||
@@ -245,11 +255,13 @@ with gr.Blocks() as outputgallery_web:
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing, (only untyped gradio gr.update() does this)
|
||||
# otherwise change nothing,
|
||||
# (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for the current subdirectory
|
||||
# evt.index is an index into the full list of filenames for
|
||||
# the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
@@ -267,7 +279,8 @@ with gr.Blocks() as outputgallery_web:
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether an image is selected
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
@@ -276,14 +289,16 @@ with gr.Blocks() as outputgallery_web:
|
||||
gr.Button.update(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh to populate
|
||||
# the subdirectory select box and the images from the most recent subdirectory.
|
||||
# The time first our tab is selected we need to do an initial refresh
|
||||
# to populate the subdirectory select box and the images from the most
|
||||
# recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls' definitions
|
||||
# as when you refresh the browser you always get what was *initially* set, which
|
||||
# won't include any new subdirectories or images that might have created since
|
||||
# the application was started. Doing it this way means a browser refresh/reload
|
||||
# always gets the most up to date data.
|
||||
# We do it at this point rather than setting this up in the controls'
|
||||
# definitions as when you refresh the browser you always get what was
|
||||
# *initially* set, which won't include any new subdirectories or images
|
||||
# that might have created since the application was started. Doing it
|
||||
# this way means a browser refresh/reload always gets the most
|
||||
# up-to-date data.
|
||||
def on_select_tab(subdir_paths):
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("")
|
||||
@@ -297,11 +312,11 @@ with gr.Blocks() as outputgallery_web:
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
|
||||
# things set with .style, nor the elem_classes kwarg so we have to directly set
|
||||
# things up via JavaScript if we want the client to take notice of any of our
|
||||
# changes to the number of columns after it decides to put them back to the
|
||||
# original number when we change something
|
||||
# Unfortunately as of gradio 3.22.0 gr.update against Galleries
|
||||
# doesn't support things set with .style, nor the elem_classes kwarg, so
|
||||
# we have to directly set things up via JavaScript if we want the client
|
||||
# to take notice of our changes to the number of columns after it
|
||||
# decides to put them back to the original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
@@ -318,32 +333,36 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# - Many actions reset the number of columns shown in the gallery on the browser end,
|
||||
# so we have to set them back to what we think they should be after the initial
|
||||
# action.
|
||||
# - None of the actions on this tab trigger inference, and we want the user to be able
|
||||
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
|
||||
# behind inference jobs would mean the UI can't fully respond until the inference tasks
|
||||
# complete, hence queue=False on all of these.
|
||||
# Many actions reset the number of columns shown in the gallery on the
|
||||
# browser end, so we have to set them back to what we think they should
|
||||
# be after the initial action.
|
||||
#
|
||||
# None of the actions on this tab trigger inference, and we want the
|
||||
# user to be able to do them whilst other tabs have ongoing inference
|
||||
# running. Waiting in the queue behind inference jobs would mean the UI
|
||||
# can't fully respond until the inference tasks complete,
|
||||
# hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if
|
||||
# I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real timeout length for the
|
||||
# number of columns to actually be applied. Not really sure why, maybe something has
|
||||
# to finish animating?
|
||||
# setting columns after selecting a gallery item needs a real
|
||||
# timeout length for the number of columns to actually be applied.
|
||||
# Not really sure why, maybe something has to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the gallery avoids current
|
||||
# images being shown replacing piecemeal and prevents weirdness and errors if the user
|
||||
# selects an image during the replacement phase.
|
||||
# clearing images when we need to completely change what's in the
|
||||
# gallery avoids current images being shown replacing piecemeal and
|
||||
# prevents weirdness and errors if the user selects an image during the
|
||||
# replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
|
||||
@@ -8,13 +8,15 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
start_message = (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)- StableLM is a helpful and "
|
||||
"harmless open-source AI language model developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse to do "
|
||||
"anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also able to "
|
||||
"write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that could harm a human."
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
)
|
||||
|
||||
|
||||
@@ -91,7 +93,8 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
"StableLM"
|
||||
) # pass elements from UI as required
|
||||
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
# Construct the input message string for the model by concatenating the
|
||||
# current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
curr_system_message = start_message
|
||||
@@ -111,7 +114,8 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
# print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated conversation history
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
# conversation history
|
||||
yield history
|
||||
return words_list
|
||||
|
||||
|
||||
@@ -87,8 +87,8 @@ def txt2img_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both "
|
||||
"must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -316,10 +316,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the dropdown on the left and "
|
||||
"enter model ID here.",
|
||||
placeholder="Select 'None' in the dropdown "
|
||||
"on the left and enter model ID here.",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL.",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL.",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
@@ -375,10 +376,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
|
||||
"on the left if you want to use a standalone "
|
||||
"HuggingFace model ID for LoRA here e.g: "
|
||||
"sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -509,7 +510,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
|
||||
@@ -88,8 +88,8 @@ def upscaler_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be "
|
||||
"empty.",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -227,10 +227,22 @@ def upscaler_inf(
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, "
|
||||
f"noise_level={noise_level}, "
|
||||
f"guidance_scale={guidance_scale}, "
|
||||
f"seed={seeds}"
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={height}x{width}, "
|
||||
f"batch_count={batch_count}, "
|
||||
f"batch_size={batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
)
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
@@ -271,7 +283,9 @@ def upscaler_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = upscaler_inf(
|
||||
@@ -353,7 +367,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
@@ -397,7 +412,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=upscaler_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
@@ -405,9 +420,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA weights dropdown "
|
||||
"on the left if you want to use a standalone "
|
||||
"HuggingFace model ID for LoRA here "
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -539,7 +554,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
|
||||
@@ -39,6 +39,9 @@ scheduler_list_cpu_only = [
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"DPMSolverMultistep++",
|
||||
"DPMSolverMultistepKarras",
|
||||
"DPMSolverMultistepKarras++",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"DEISMultistep",
|
||||
@@ -86,7 +89,8 @@ def create_custom_models_folders():
|
||||
else:
|
||||
if not os.path.isdir(args.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
|
||||
f"Invalid --ckpt_dir argument, "
|
||||
f"{args.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user