remove annoying accelerate warning (#1056)

disables usage of low_cpu_mem_usage=True in from_pretrained() calls.
Can be re-enabled by using flag --low_cpu_mem_usage
defaults to False to avoid spam as we don't include accelerate in our
requirements.txt
This commit is contained in:
Daniel Garvey
2023-02-20 14:46:26 -06:00
committed by GitHub
parent 2ae047f1a8
commit 4e92304b89
5 changed files with 29 additions and 7 deletions

View File

@@ -136,6 +136,7 @@ def img2img_inf(
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
if not img2img_obj:
@@ -230,6 +231,7 @@ if __name__ == "__main__":
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
start_time = time.time()

View File

@@ -126,6 +126,7 @@ def txt2img_inf(
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
if not txt2img_obj:
@@ -199,6 +200,7 @@ if __name__ == "__main__":
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
for current_batch in range(args.batch_count):

View File

@@ -80,6 +80,7 @@ class SharkifyStableDiffusionModel:
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -114,6 +115,7 @@ class SharkifyStableDiffusionModel:
if use_tuned:
self.model_name = self.model_name + "_tuned"
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
self.low_cpu_mem_usage = low_cpu_mem_usage
def get_extended_name_for_all_model(self):
model_name = {}
@@ -139,11 +141,12 @@ class SharkifyStableDiffusionModel:
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
@@ -165,23 +168,26 @@ class SharkifyStableDiffusionModel:
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
@@ -196,7 +202,7 @@ class SharkifyStableDiffusionModel:
x = x * 255.0
return x.round()
vae = VaeModel()
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
shark_vae = compile_through_fx(
@@ -211,11 +217,12 @@ class SharkifyStableDiffusionModel:
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -234,7 +241,7 @@ class SharkifyStableDiffusionModel:
)
return noise_pred
unet = UnetModel()
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
@@ -251,17 +258,18 @@ class SharkifyStableDiffusionModel:
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),

View File

@@ -201,6 +201,7 @@ class StableDiffusionPipeline:
width: int,
use_base_vae: bool,
use_tuned: bool,
low_cpu_mem_usage: bool = False,
):
if import_mlir:
mlir_import = SharkifyStableDiffusionModel(
@@ -214,6 +215,7 @@ class StableDiffusionPipeline:
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
clip, unet, vae, vae_encode = mlir_import()
@@ -248,6 +250,7 @@ class StableDiffusionPipeline:
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
clip, unet, vae, vae_encode = mlir_import()

View File

@@ -193,6 +193,13 @@ p.add_argument(
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################