[SD] Use dynamic stencil HF repo id

-- This commit removes the hardcoded HF ID for Stencil and instead
   utilizes a dynamic instantiation of HF model.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-03-10 11:58:56 +00:00
committed by Abhishek Varma
parent b23d3aa584
commit b8f4b18951
4 changed files with 29 additions and 7 deletions

View File

@@ -16,6 +16,7 @@ from apps.stable_diffusion.src.utils import (
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
get_stencil_model_id,
)
@@ -81,7 +82,8 @@ class SharkifyStableDiffusionModel:
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
is_inpaint: bool = False
is_inpaint: bool = False,
use_stencil: str = None
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -118,6 +120,7 @@ class SharkifyStableDiffusionModel:
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.use_stencil = get_stencil_model_id(use_stencil)
def get_extended_name_for_all_model(self, mask_to_fetch):
model_name = {}
@@ -229,7 +232,7 @@ class SharkifyStableDiffusionModel:
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
"takuma104/control_sd15_canny", # TODO: ADD with model ID
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
@@ -277,12 +280,11 @@ class SharkifyStableDiffusionModel:
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
"takuma104/control_sd15_canny", # TODO: ADD with model ID
subfolder="controlnet",
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
@@ -454,7 +456,7 @@ class SharkifyStableDiffusionModel:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
need_vae_encode, need_stencil = False, False
if args.img_path is not None:
if args.use_stencil is not None:
if self.use_stencil is not None:
need_stencil = True
else:
need_vae_encode = True

View File

@@ -317,7 +317,7 @@ class StableDiffusionPipeline:
use_base_vae: bool,
use_tuned: bool,
low_cpu_mem_usage: bool = False,
use_stencil: bool = False,
use_stencil: str = None,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
@@ -337,6 +337,7 @@ class StableDiffusionPipeline:
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
use_stencil=use_stencil,
)
if cls.__name__ in [
"Image2ImagePipeline",

View File

@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,

View File

@@ -134,6 +134,24 @@ def controlnet_hint_conversion(
return controlnet_hint
stencil_to_model_id_map = {
"canny": "lllyasviel/sd-controlnet-canny",
"depth": "lllyasviel/sd-controlnet-depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/sd-controlnet-mlsd",
"normal": "lllyasviel/sd-controlnet-normal",
"openpose": "lllyasviel/sd-controlnet-openpose",
"scribble": "lllyasviel/sd-controlnet-scribble",
"seg": "lllyasviel/sd-controlnet-seg",
}
def get_stencil_model_id(use_stencil):
if use_stencil in stencil_to_model_id_map:
return stencil_to_model_id_map[use_stencil]
return None
# Stencil 1. Canny
def hint_canny(
image: Image.Image,