mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Make loading custom inpainting models general (#1126)
This commit is contained in:
@@ -188,14 +188,16 @@ if __name__ == "__main__":
|
||||
if args.mask_path is None:
|
||||
print("Flag --mask_path is required.")
|
||||
exit()
|
||||
if "inpaint" not in args.hf_model_id:
|
||||
print("Please use inpainting model with --hf_model_id.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if "inpaint" in args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
@@ -200,14 +200,16 @@ if __name__ == "__main__":
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
if "inpaint" not in args.hf_model_id:
|
||||
print("Please use inpainting model with --hf_model_id.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if "inpaint" in args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
@@ -80,7 +80,8 @@ class SharkifyStableDiffusionModel:
|
||||
batch_size: int = 1,
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False
|
||||
low_cpu_mem_usage: bool = False,
|
||||
is_inpaint: bool = False
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -116,6 +117,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
|
||||
def get_extended_name_for_all_model(self, mask_to_fetch):
|
||||
model_name = {}
|
||||
@@ -484,7 +486,7 @@ class SharkifyStableDiffusionModel:
|
||||
assert self.custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
preprocessCKPT(self.custom_weights)
|
||||
preprocessCKPT(self.custom_weights, self.is_inpaint)
|
||||
else:
|
||||
model_to_run = args.hf_model_id
|
||||
# For custom Vae user can provide either the repo-id or a checkpoint file,
|
||||
|
||||
@@ -319,6 +319,10 @@ class StableDiffusionPipeline:
|
||||
low_cpu_mem_usage: bool = False,
|
||||
use_stencil: bool = False,
|
||||
):
|
||||
is_inpaint = cls.__name__ in [
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
if import_mlir:
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
@@ -332,6 +336,7 @@ class StableDiffusionPipeline:
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
is_inpaint=is_inpaint,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
@@ -386,6 +391,7 @@ class StableDiffusionPipeline:
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
is_inpaint=is_inpaint,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
|
||||
@@ -416,7 +416,7 @@ def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def preprocessCKPT(custom_weights):
|
||||
def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
if next(Path(path_to_diffusers).iterdir(), None):
|
||||
print("Checkpoint already loaded at : ", path_to_diffusers)
|
||||
@@ -437,7 +437,7 @@ def preprocessCKPT(custom_weights):
|
||||
print(
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if "inpainting" in custom_weights else 4
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
|
||||
Reference in New Issue
Block a user