Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 13:57:54 -05:00)
Add controlmode (#1957)
@@ -105,6 +105,7 @@ def main():
         cpu_scheduling,
         args.max_embeddings_multiples,
         use_stencil=use_stencil,
+        control_mode=args.control_mode,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"

@@ -380,25 +380,38 @@ class SharkifyStableDiffusionModel:
             control11,
             control12,
             control13,
+            scale1,
+            scale2,
+            scale3,
+            scale4,
+            scale5,
+            scale6,
+            scale7,
+            scale8,
+            scale9,
+            scale10,
+            scale11,
+            scale12,
+            scale13,
         ):
             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
             db_res_samples = tuple(
                 [
-                    control1,
-                    control2,
-                    control3,
-                    control4,
-                    control5,
-                    control6,
-                    control7,
-                    control8,
-                    control9,
-                    control10,
-                    control11,
-                    control12,
+                    control1 * scale1,
+                    control2 * scale2,
+                    control3 * scale3,
+                    control4 * scale4,
+                    control5 * scale5,
+                    control6 * scale6,
+                    control7 * scale7,
+                    control8 * scale8,
+                    control9 * scale9,
+                    control10 * scale10,
+                    control11 * scale11,
+                    control12 * scale12,
                 ]
             )
-            mb_res_samples = control13
+            mb_res_samples = control13 * scale13
             latents = torch.cat([latent] * 2)
             unet_out = self.unet.forward(
                 latents,

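In effect, the wrapper's forward now weights each ControlNet residual by its matching scale before the tuple reaches the UNet. A minimal standalone sketch of that step, assuming hypothetical names and plain PyTorch rather than the compiled SHARK module:

    import torch

    def scale_controlnet_residuals(down_block_residuals, mid_block_residual, scales):
        # One scalar per residual: the first entries weight the down-block
        # residuals (control1..control12 above), the last weights the
        # mid-block residual (control13).
        assert len(scales) == len(down_block_residuals) + 1
        db_res_samples = tuple(r * s for r, s in zip(down_block_residuals, scales[:-1]))
        mb_res_samples = mid_block_residual * scales[-1]
        return db_res_samples, mb_res_samples
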
@@ -446,6 +459,19 @@ class SharkifyStableDiffusionModel:
             True,
             True,
             True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
         ]
         shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
             unet,

@@ -113,6 +113,7 @@ class StencilPipeline(StableDiffusionPipeline):
        cpu_scheduling,
        controlnet_hint=None,
        controlnet_conditioning_scale: float = 1.0,
+        control_mode="Balanced",  # Prompt, Balanced, or Controlnet
        mask=None,
        masked_image_latents=None,
        return_all_latents=False,

@@ -121,6 +122,7 @@ class StencilPipeline(StableDiffusionPipeline):
        latent_history = [latents]
        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
        text_embeddings_numpy = text_embeddings.detach().numpy()
+        assert control_mode in ["Prompt", "Balanced", "Controlnet"]
        if text_embeddings.shape[1] <= self.model_max_length:
            self.load_unet()
            self.load_controlnet()

@@ -176,6 +178,22 @@ class StencilPipeline(StableDiffusionPipeline):
            profile_device = start_profiling(file_path="unet.rdc")
            # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
+
+            dtype = latents.dtype
+            if control_mode == "Balanced":
+                control_scale = [
+                    torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
+                ]
+            elif control_mode == "Prompt":
+                control_scale = [
+                    torch.tensor(0.825**x, dtype=dtype)
+                    for x in range(len(control))
+                ]
+            elif control_mode == "Controlnet":
+                control_scale = [
+                    torch.tensor(float(guidance_scale), dtype=dtype)
+                    for _ in range(len(control))
+                ]

            if text_embeddings.shape[1] <= self.model_max_length:
                noise_pred = self.unet(
                    "forward",

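The three modes above map to a simple scale schedule: "Balanced" weights every ControlNet residual at 1.0, "Prompt" decays deeper residuals geometrically by 0.825**x, and "Controlnet" weights every residual by the guidance scale. A standalone sketch of that schedule, using an illustrative helper name rather than the pipeline's API:

    import torch

    def make_control_scale(control_mode, num_residuals, guidance_scale, dtype=torch.float32):
        # Mirrors the branches added in the hunk above.
        if control_mode == "Balanced":
            return [torch.tensor(1.0, dtype=dtype) for _ in range(num_residuals)]
        if control_mode == "Prompt":
            return [torch.tensor(0.825**x, dtype=dtype) for x in range(num_residuals)]
        if control_mode == "Controlnet":
            return [torch.tensor(float(guidance_scale), dtype=dtype) for _ in range(num_residuals)]
        raise ValueError(f"Unexpected control_mode: {control_mode}")
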
@@ -197,6 +215,19 @@ class StencilPipeline(StableDiffusionPipeline):
                        control[10],
                        control[11],
                        control[12],
+                        control_scale[0],
+                        control_scale[1],
+                        control_scale[2],
+                        control_scale[3],
+                        control_scale[4],
+                        control_scale[5],
+                        control_scale[6],
+                        control_scale[7],
+                        control_scale[8],
+                        control_scale[9],
+                        control_scale[10],
+                        control_scale[11],
+                        control_scale[12],
                    ),
                    send_to_host=False,
                )

@@ -222,6 +253,19 @@ class StencilPipeline(StableDiffusionPipeline):
                        control[10],
                        control[11],
                        control[12],
+                        control_scale[0],
+                        control_scale[1],
+                        control_scale[2],
+                        control_scale[3],
+                        control_scale[4],
+                        control_scale[5],
+                        control_scale[6],
+                        control_scale[7],
+                        control_scale[8],
+                        control_scale[9],
+                        control_scale[10],
+                        control_scale[11],
+                        control_scale[12],
                    ),
                    send_to_host=False,
                )

@@ -274,6 +318,7 @@ class StencilPipeline(StableDiffusionPipeline):
        max_embeddings_multiples,
        use_stencil,
        resample_type,
+        control_mode,
    ):
        # Control Embedding check & conversion
        # TODO: 1. Change `num_images_per_prompt`.

@@ -328,6 +373,7 @@ class StencilPipeline(StableDiffusionPipeline):
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            controlnet_hint=controlnet_hint,
+            control_mode=control_mode,
        )

        # Img latents -> PIL images

@@ -290,6 +290,58 @@
            "control13": {
                "shape": [2, 1280, "height/8", "width/8"],
                "dtype": "f32"
            },
+            "scale1": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale2": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale3": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale4": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale5": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale6": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale7": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale8": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale9": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale10": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale11": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale12": {
+                "shape": 1,
+                "dtype": "f32"
+            },
+            "scale13": {
+                "shape": 1,
+                "dtype": "f32"
+            }
        }
    }

@@ -420,6 +420,13 @@ p.add_argument(
    help="Enable the stencil feature.",
)

+p.add_argument(
+    "--control_mode",
+    choices=["Prompt", "Balanced", "Controlnet"],
+    default="Balanced",
+    help="How Controlnet injection should be prioritized.",
+)
+
p.add_argument(
    "--use_lora",
    type=str,

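With the flag registered, the mode can be selected on the command line; the script path below is an assumed example and not part of this diff:

    python apps/stable_diffusion/scripts/img2img.py --control_mode Prompt
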
@@ -68,6 +68,7 @@ def img2img_inf(
    ondemand: bool,
    repeatable_seeds: bool,
    resample_type: str,
+    control_mode: str,
):
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,

@@ -253,6 +254,7 @@ def img2img_inf(
        args.max_embeddings_multiples,
        use_stencil=use_stencil,
        resample_type=resample_type,
+        control_mode=control_mode,
    )
    total_time = time.time() - start_time
    text_output = get_generation_text_info(

@@ -412,6 +414,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                        inputs=use_stencil,
                        outputs=[canvas_width, canvas_height, create_button],
                    )
+                    control_mode = gr.Radio(
+                        choices=["Prompt", "Balanced", "Controlnet"],
+                        value="Balanced",
+                        label="Control Mode",
+                    )

                    with gr.Accordion(label="LoRA Options", open=False):
                        with gr.Row():

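For illustration only (not the SHARK-Studio UI code), a gr.Radio value reaches the inference callback by listing the component among the click handler's inputs, which is how control_mode is threaded into img2img_inf in the next hunk; fake_inf here is a stand-in function:

    import gradio as gr

    def fake_inf(prompt, control_mode):
        # Stand-in for the real inference entry point.
        return f"prompt={prompt}, control_mode={control_mode}"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        control_mode = gr.Radio(
            choices=["Prompt", "Balanced", "Controlnet"],
            value="Balanced",
            label="Control Mode",
        )
        out = gr.Textbox(label="Output")
        gr.Button("Generate").click(fake_inf, inputs=[prompt, control_mode], outputs=out)
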
@@ -625,6 +632,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                ondemand,
                repeatable_seeds,
                resample_type,
+                control_mode,
            ],
            outputs=[img2img_gallery, std_output, img2img_status],
            show_progress="minimal" if args.progress_bar else "none",