mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Use dynamic stencil HF repo id
-- This commit removes the hardcoded HF ID for Stencil and instead utilizes a dynamic instantiation of HF model. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
committed by
Abhishek Varma
parent
b23d3aa584
commit
b8f4b18951
@@ -16,6 +16,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,7 +82,8 @@ class SharkifyStableDiffusionModel:
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
is_inpaint: bool = False
|
||||
is_inpaint: bool = False,
|
||||
use_stencil: str = None
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -118,6 +120,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
self.use_stencil = get_stencil_model_id(use_stencil)
|
||||
|
||||
def get_extended_name_for_all_model(self, mask_to_fetch):
|
||||
model_name = {}
|
||||
@@ -229,7 +232,7 @@ class SharkifyStableDiffusionModel:
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"takuma104/control_sd15_canny", # TODO: ADD with model ID
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
@@ -277,12 +280,11 @@ class SharkifyStableDiffusionModel:
|
||||
def get_control_net(self):
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
):
|
||||
super().__init__()
|
||||
self.cnet = ControlNetModel.from_pretrained(
|
||||
"takuma104/control_sd15_canny", # TODO: ADD with model ID
|
||||
subfolder="controlnet",
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
@@ -454,7 +456,7 @@ class SharkifyStableDiffusionModel:
|
||||
# -- Fetch all vmfbs for the model, if present, else delete the lot.
|
||||
need_vae_encode, need_stencil = False, False
|
||||
if args.img_path is not None:
|
||||
if args.use_stencil is not None:
|
||||
if self.use_stencil is not None:
|
||||
need_stencil = True
|
||||
else:
|
||||
need_vae_encode = True
|
||||
|
||||
@@ -317,7 +317,7 @@ class StableDiffusionPipeline:
|
||||
use_base_vae: bool,
|
||||
use_tuned: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
use_stencil: bool = False,
|
||||
use_stencil: str = None,
|
||||
):
|
||||
is_inpaint = cls.__name__ in [
|
||||
"InpaintPipeline",
|
||||
@@ -337,6 +337,7 @@ class StableDiffusionPipeline:
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
is_inpaint=is_inpaint,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
|
||||
@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
get_shark_model,
|
||||
|
||||
@@ -134,6 +134,24 @@ def controlnet_hint_conversion(
|
||||
return controlnet_hint
|
||||
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/sd-controlnet-canny",
|
||||
"depth": "lllyasviel/sd-controlnet-depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/sd-controlnet-mlsd",
|
||||
"normal": "lllyasviel/sd-controlnet-normal",
|
||||
"openpose": "lllyasviel/sd-controlnet-openpose",
|
||||
"scribble": "lllyasviel/sd-controlnet-scribble",
|
||||
"seg": "lllyasviel/sd-controlnet-seg",
|
||||
}
|
||||
|
||||
|
||||
def get_stencil_model_id(use_stencil):
|
||||
if use_stencil in stencil_to_model_id_map:
|
||||
return stencil_to_model_id_map[use_stencil]
|
||||
return None
|
||||
|
||||
|
||||
# Stencil 1. Canny
|
||||
def hint_canny(
|
||||
image: Image.Image,
|
||||
|
||||
Reference in New Issue
Block a user