[SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
This commit is contained in:
Ean Garvey
2023-06-20 12:26:36 -05:00
committed by GitHub
parent 3fb72e192e
commit 0def74f520
2 changed files with 2 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
@@ -48,7 +49,6 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("PIL") if "tests" not in x]
a = Analysis(
['web/index.py'],

View File

@@ -426,7 +426,7 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):