(Studio) Fix controlnet switching. (#2026)

* Fix controlnet switching.

* Fix txt2img + control adapters
This commit is contained in:
Ean Garvey
2023-12-07 00:52:36 -06:00
committed by GitHub
parent 7e12d1782a
commit 7159698496
3 changed files with 75 additions and 22 deletions

View File

@@ -219,7 +219,6 @@ class SharkifyStableDiffusionModel:
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
@@ -241,7 +240,7 @@ class SharkifyStableDiffusionModel:
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
def get_extended_name_for_all_model(self, model_list=None):
model_name = {}
sub_model_list = [
"clip",
@@ -255,6 +254,8 @@ class SharkifyStableDiffusionModel:
"stencil_adapter",
"stencil_adapter_512",
]
if model_list is not None:
sub_model_list = model_list
index = 0
for model in sub_model_list:
sub_model = model
@@ -272,7 +273,7 @@ class SharkifyStableDiffusionModel:
if stencil is not None:
cnet_config = (
self.model_namedata
+ "_v1-5"
+ "_sd15_"
+ stencil.split("_")[-1]
)
stencil_names.append(
@@ -283,6 +284,7 @@ class SharkifyStableDiffusionModel:
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
@@ -765,7 +767,8 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
ext_model_name = self.model_name[model_name]
stencil_names = self.get_extended_name_for_all_model([model_name])
ext_model_name = stencil_names[model_name]
if isinstance(ext_model_name, list):
desired_name = None
print(ext_model_name)

View File

@@ -125,6 +125,31 @@ class StencilPipeline(StableDiffusionPipeline):
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_image_latents(
self,
image,
@@ -203,10 +228,16 @@ class StencilPipeline(StableDiffusionPipeline):
self.load_unet_512()
for i, name in enumerate(self.controlnet_names):
use_names = []
if name is not None:
use_names.append(name)
else:
continue
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
self.controlnet_names = use_names
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
@@ -461,6 +492,7 @@ class StencilPipeline(StableDiffusionPipeline):
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
self.sd_model.stencils = stencils
for i, hint in enumerate(preprocessed_hints):
if hint is not None:
hint = controlnet_hint_reshaping(
@@ -475,7 +507,7 @@ class StencilPipeline(StableDiffusionPipeline):
for i, stencil in enumerate(stencils):
if stencil == None:
continue
if len(stencil_hints) >= i:
if len(stencil_hints) > i:
if stencil_hints[i] is not None:
print(f"Using preprocessed controlnet hint for {stencil}")
continue
@@ -518,19 +550,30 @@ class StencilPipeline(StableDiffusionPipeline):
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
if image is not None:
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
else:
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(

View File

@@ -105,7 +105,7 @@ def img2img_inf(
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return
continue
if images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
@@ -120,6 +120,8 @@ def img2img_inf(
else:
# TODO: enable t2i + controlnets
image = None
if image:
image, _, _ = resize_stencil(image, width, height)
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -152,7 +154,6 @@ def img2img_inf(
stencil_count += 1
if stencil_count > 0:
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, _, _ = resize_stencil(image, width, height)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -162,6 +163,7 @@ def img2img_inf(
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
print(stencils)
new_config_obj = Config(
"img2img",
args.hf_model_id,
@@ -180,7 +182,12 @@ def img2img_inf(
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
or any(
global_obj.get_cfg_obj().stencils[idx] != stencil
for idx, stencil in enumerate(stencils)
)
):
print("clearing config because you changed something important")
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
@@ -632,7 +639,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
[cnet_1_image],
)
cnet_1_model.input(
cnet_1_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 0
@@ -739,7 +746,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Preprocessed Hint",
interactive=True,
)
cnet_2_model.select(
cnet_2_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 0