Make loading custom inpainting models general (#1126)

This commit is contained in:
jinchen62
2023-03-01 22:14:04 -08:00
committed by GitHub
parent 7f3f92b9d5
commit 080350d311
5 changed files with 24 additions and 12 deletions

View File

@@ -188,14 +188,16 @@ if __name__ == "__main__":
if args.mask_path is None:
print("Flag --mask_path is required.")
exit()
if "inpaint" not in args.hf_model_id:
print("Please use inpainting model with --hf_model_id.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)

View File

@@ -200,14 +200,16 @@ if __name__ == "__main__":
if args.img_path is None:
print("Flag --img_path is required.")
exit()
if "inpaint" not in args.hf_model_id:
print("Please use inpainting model with --hf_model_id.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
model_id = (
args.hf_model_id
if "inpaint" in args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
image = Image.open(args.img_path)

View File

@@ -80,7 +80,8 @@ class SharkifyStableDiffusionModel:
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False
low_cpu_mem_usage: bool = False,
is_inpaint: bool = False
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -116,6 +117,7 @@ class SharkifyStableDiffusionModel:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
def get_extended_name_for_all_model(self, mask_to_fetch):
model_name = {}
@@ -484,7 +486,7 @@ class SharkifyStableDiffusionModel:
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights)
preprocessCKPT(self.custom_weights, self.is_inpaint)
else:
model_to_run = args.hf_model_id
# For custom Vae user can provide either the repo-id or a checkpoint file,

View File

@@ -319,6 +319,10 @@ class StableDiffusionPipeline:
low_cpu_mem_usage: bool = False,
use_stencil: bool = False,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
"OutpaintPipeline",
]
if import_mlir:
mlir_import = SharkifyStableDiffusionModel(
model_id,
@@ -332,6 +336,7 @@ class StableDiffusionPipeline:
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
)
if cls.__name__ in [
"Image2ImagePipeline",
@@ -386,6 +391,7 @@ class StableDiffusionPipeline:
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
is_inpaint=is_inpaint,
)
if cls.__name__ in [
"Image2ImagePipeline",

View File

@@ -416,7 +416,7 @@ def get_path_to_diffusers_checkpoint(custom_weights):
return path_to_diffusers
def preprocessCKPT(custom_weights):
def preprocessCKPT(custom_weights, is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
@@ -437,7 +437,7 @@ def preprocessCKPT(custom_weights):
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if "inpainting" in custom_weights else 4
num_in_channels = 9 if is_inpaint else 4
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,