Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-09 13:57:54 -05:00
(Studio) Fix controlnet switching. (#2026)
* Fix controlnet switching.
* Fix txt2img + control adapters
@@ -219,7 +219,6 @@ class SharkifyStableDiffusionModel:
             self.model_name = self.model_name + "_" + get_path_stem(use_lora)
         self.use_lora = use_lora

-        print(self.model_name)
         self.model_name = self.get_extended_name_for_all_model()
         self.debug = debug
         self.sharktank_dir = sharktank_dir
@@ -241,7 +240,7 @@ class SharkifyStableDiffusionModel:
             args.hf_model_id = self.base_model_id
         self.return_mlir = return_mlir

-    def get_extended_name_for_all_model(self):
+    def get_extended_name_for_all_model(self, model_list=None):
         model_name = {}
         sub_model_list = [
             "clip",
@@ -255,6 +254,8 @@ class SharkifyStableDiffusionModel:
             "stencil_adapter",
             "stencil_adapter_512",
         ]
+        if model_list is not None:
+            sub_model_list = model_list
         index = 0
         for model in sub_model_list:
             sub_model = model
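For context, the new model_list argument lets a caller rebuild the extended name for a single sub-model instead of the full default set (the stencil-adapter hunk further down does exactly that). A minimal standalone sketch of the pattern, with hypothetical names rather than the repository's helpers:

def extended_names(base, model_list=None):
    # Default set mirrors sub_model_list above; an explicit model_list overrides it.
    sub_model_list = ["clip", "unet", "vae", "stencil_adapter", "stencil_adapter_512"]
    if model_list is not None:
        sub_model_list = model_list
    return {m: f"{m}_{base}" for m in sub_model_list}

print(extended_names("sd15"))                       # every sub-model
print(extended_names("sd15", ["stencil_adapter"]))  # single-adapter lookup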
@@ -272,7 +273,7 @@ class SharkifyStableDiffusionModel:
                 if stencil is not None:
                     cnet_config = (
-                        self.model_name
+                        self.model_data
                         + "_v1-5"
                         + "_sd15_"
                         + stencil.split("_")[-1]
                     )
                     stencil_names.append(
@@ -283,6 +284,7 @@ class SharkifyStableDiffusionModel:
             else:
                 model_name[model] = get_extended_name(sub_model + model_config)
             index += 1

         return model_name

     def check_params(self, max_len, width, height):
@@ -765,7 +767,8 @@ class SharkifyStableDiffusionModel:

         inputs = tuple(self.inputs["stencil_adapter"])
         model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
-        ext_model_name = self.model_name[model_name]
+        stencil_names = self.get_extended_name_for_all_model([model_name])
+        ext_model_name = stencil_names[model_name]
         if isinstance(ext_model_name, list):
             desired_name = None
             print(ext_model_name)
@@ -125,6 +125,31 @@ class StencilPipeline(StableDiffusionPipeline):
             self.controlnet_512_id[index] = None
             self.controlnet_512[index] = None

+    def prepare_latents(
+        self,
+        batch_size,
+        height,
+        width,
+        generator,
+        num_inference_steps,
+        dtype,
+    ):
+        latents = torch.randn(
+            (
+                batch_size,
+                4,
+                height // 8,
+                width // 8,
+            ),
+            generator=generator,
+            dtype=torch.float32,
+        ).to(dtype)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+        self.scheduler.is_scale_input_called = True
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     def prepare_image_latents(
         self,
         image,
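For reference, a minimal standalone sketch of what the new prepare_latents does for the txt2img path: pure noise at 1/8 spatial resolution with 4 latent channels, scaled by the scheduler's initial sigma (the sigma here is a placeholder; the pipeline reads scheduler.init_noise_sigma):

import torch

batch_size, height, width = 1, 512, 512
generator = torch.manual_seed(0)
init_noise_sigma = 1.0  # placeholder value, not the pipeline's scheduler

latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
    dtype=torch.float32,
)
latents = latents * init_noise_sigma
print(latents.shape)  # torch.Size([1, 4, 64, 64])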
@@ -203,10 +228,16 @@ class StencilPipeline(StableDiffusionPipeline):
         self.load_unet_512()

         for i, name in enumerate(self.controlnet_names):
+            use_names = []
+            if name is not None:
+                use_names.append(name)
+            else:
+                continue
             if text_embeddings.shape[1] <= self.model_max_length:
                 self.load_controlnet(i, name)
             else:
                 self.load_controlnet_512(i, name)
+        self.controlnet_names = use_names

         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
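The loop above skips slots with no ControlNet assigned and keeps only the active names before loading. A standalone sketch of that filtering intent (simplified; the hunk also picks the 512 variant based on embedding length):

controlnet_names = ["canny", None, "openpose"]
use_names = [name for name in controlnet_names if name is not None]
print(use_names)  # ['canny', 'openpose'] -> only active adapters are kept and loaded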
@@ -461,6 +492,7 @@ class StencilPipeline(StableDiffusionPipeline):
         # image, use_stencil, height, width, dtype, num_images_per_prompt=1
         # )
         stencil_hints = []
+        self.sd_model.stencils = stencils
         for i, hint in enumerate(preprocessed_hints):
             if hint is not None:
                 hint = controlnet_hint_reshaping(
@@ -475,7 +507,7 @@ class StencilPipeline(StableDiffusionPipeline):
         for i, stencil in enumerate(stencils):
             if stencil == None:
                 continue
-            if len(stencil_hints) >= i:
+            if len(stencil_hints) > i:
                 if stencil_hints[i] is not None:
                     print(f"Using preprocessed controlnet hint for {stencil}")
                     continue
@@ -518,19 +550,30 @@ class StencilPipeline(StableDiffusionPipeline):

         # guidance scale as a float32 tensor.
         guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
-
-        # Prepare input image latent
-        init_latents, final_timesteps = self.prepare_image_latents(
-            image=image,
-            batch_size=batch_size,
-            height=height,
-            width=width,
-            generator=generator,
-            num_inference_steps=num_inference_steps,
-            strength=strength,
-            dtype=dtype,
-            resample_type=resample_type,
-        )
+        if image is not None:
+            # Prepare input image latent
+            init_latents, final_timesteps = self.prepare_image_latents(
+                image=image,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                strength=strength,
+                dtype=dtype,
+                resample_type=resample_type,
+            )
+        else:
+            # Prepare initial latent.
+            init_latents = self.prepare_latents(
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                dtype=dtype,
+            )
+            final_timesteps = self.scheduler.timesteps

         # Get Image latents
         latents = self.produce_stencil_latents(
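Note on the else branch: the txt2img path keeps the scheduler's full timestep schedule, while the img2img path takes final_timesteps from prepare_image_latents, which typically truncates the schedule by strength. A standalone sketch of that conventional truncation (illustrative numbers, not taken from this file):

num_inference_steps = 50
strength = 0.75
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(t_start)  # 13 -> img2img starts part-way in; txt2img starts at step 0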
@@ -105,7 +105,7 @@ def img2img_inf(

     for i, stencil in enumerate(stencils):
         if images[i] is None and stencil is not None:
-            return
+            continue
         if images[i] is not None:
             if isinstance(images[i], dict):
                 images[i] = images[i]["composite"]
@@ -120,6 +120,8 @@ def img2img_inf(
         else:
+            # TODO: enable t2i + controlnets
             image = None
+        if image:
             image, _, _ = resize_stencil(image, width, height)

     # set ckpt_loc and hf_model_id.
     args.ckpt_loc = ""
@@ -152,7 +154,6 @@ def img2img_inf(
             stencil_count += 1
     if stencil_count > 0:
         args.hf_model_id = "runwayml/stable-diffusion-v1-5"
-        image, _, _ = resize_stencil(image, width, height)
     elif "Shark" in args.scheduler:
         print(
             f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -162,6 +163,7 @@ def img2img_inf(
     cpu_scheduling = not args.scheduler.startswith("Shark")
     args.precision = precision
     dtype = torch.float32 if precision == "fp32" else torch.half
+    print(stencils)
     new_config_obj = Config(
         "img2img",
         args.hf_model_id,
@@ -180,7 +182,12 @@ def img2img_inf(
     if (
         not global_obj.get_sd_obj()
         or global_obj.get_cfg_obj() != new_config_obj
+        or any(
+            global_obj.get_cfg_obj().stencils[idx] != stencil
+            for idx, stencil in enumerate(stencils)
+        )
     ):
+        print("clearing config because you changed something important")
         global_obj.clear_cache()
         global_obj.set_cfg_obj(new_config_obj)
     args.batch_count = batch_count
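The new any(...) clause compares the requested stencils element-wise against the cached config, so changing a ControlNet forces the cached pipeline to be rebuilt. A standalone sketch of that check:

cached_stencils = ["canny", None]
requested_stencils = ["canny", "openpose"]
changed = any(
    cached_stencils[idx] != stencil
    for idx, stencil in enumerate(requested_stencils)
)
print(changed)  # True -> clear the cached pipeline and apply the new config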
@@ -632,7 +639,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                        [cnet_1_image],
                    )

-                    cnet_1_model.input(
+                    cnet_1_model.change(
                        fn=(
                            lambda m, w, h, s, i, p: update_cn_input(
                                m, w, h, s, i, p, 0
@@ -739,7 +746,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                        label="Preprocessed Hint",
                        interactive=True,
                    )
-                    cnet_2_model.select(
+                    cnet_2_model.change(
                        fn=(
                            lambda m, w, h, s, i, p: update_cn_input(
                                m, w, h, s, i, p, 0
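Both ControlNet model dropdowns now use .change instead of .input/.select. In Gradio, .change fires when the value changes for any reason, including programmatic updates, whereas .input and .select only fire on direct user interaction, which is presumably what the switching fix relies on. A minimal standalone Gradio sketch of that wiring (hypothetical handler names, assumes gradio is installed):

import gradio as gr

def on_model(model):
    # Placeholder for the app's update_cn_input-style handler.
    return f"selected: {model}"

with gr.Blocks() as demo:
    model = gr.Dropdown(["canny", "openpose"], label="ControlNet model")
    status = gr.Textbox(label="status")
    model.change(fn=on_model, inputs=model, outputs=status)

# demo.launch()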