Fix stencil pipline to use input image (#2027)

This commit is contained in:
gpetters94
2023-12-07 01:25:18 -05:00
committed by GitHub
parent bb5f133e1c
commit 7e12d1782a

View File

@@ -33,7 +33,14 @@ from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class StencilPipeline(StableDiffusionPipeline):
@@ -66,6 +73,24 @@ class StencilPipeline(StableDiffusionPipeline):
self.controlnet_id = [str] * len(controlnet_names)
self.controlnet_512_id = [str] * len(controlnet_names)
self.controlnet_names = controlnet_names
self.vae_encode = None
def load_vae_encode(self):
if self.vae_encode is not None:
return
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def load_controlnet(self, index, model_name):
if model_name is None:
@@ -100,30 +125,57 @@ class StencilPipeline(StableDiffusionPipeline):
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
resample_type,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre-process image
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
def produce_stencil_latents(
self,
@@ -370,6 +422,17 @@ class StencilPipeline(StableDiffusionPipeline):
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
@@ -456,16 +519,18 @@ class StencilPipeline(StableDiffusionPipeline):
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(