diff --git a/backend/server.py b/backend/server.py index 02fc0b487f..3d24656af0 100644 --- a/backend/server.py +++ b/backend/server.py @@ -147,15 +147,12 @@ def handle_request_all_images(): paths.sort(key=lambda x: os.path.getmtime(x)) image_array = [] for path in paths: - image = Image.open(path) - metadata = {} - if 'Dream' in image.info: - try: - metadata = vars(parser.parse_args(shlex.split(image.info['Dream']))) - except SystemExit: - # TODO: Unable to parse metadata, ignore it for now, - # this can happen when metadata is missing a prompt - pass + # image = Image.open(path) + all_metadata = retrieve_metadata(path) + if 'Dream' in all_metadata and not all_metadata['sd-metadata']: + metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream']))) + else: + metadata = all_metadata['sd-metadata'] image_array.append({'path': path, 'metadata': metadata}) return make_response("OK", data=image_array) @@ -308,7 +305,7 @@ def save_image(image, parameters, output_dir, step_index=None, postprocessing=Fa command = parameters_to_command(parameters) - path = pngwriter.save_image_and_prompt_to_png(image, command, filename) + path = pngwriter.save_image_and_prompt_to_png(image, command, parameters, filename) return path diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index a8d2425b91..ecbc3c0e15 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -47,7 +47,8 @@ class PngWriter: metadata stored there, as a dict ''' path = os.path.join(self.outdir,img_basename) - return retrieve_metadata(path) + all_metadata = retrieve_metadata(path) + return all_metadata['sd-metadata'] def retrieve_metadata(img_path): ''' @@ -55,6 +56,7 @@ def retrieve_metadata(img_path): metadata stored there, as a dict ''' im = Image.open(img_path) - md = im.text.get('sd-metadata',{}) - return json.loads(md) + md = im.text.get('sd-metadata', '{}') + dream_prompt = im.text.get('Dream', '') + return {'sd-metadata': json.loads(md), 'Dream': dream_prompt} diff --git a/scripts/sd-metadata.py b/scripts/sd-metadata.py index a3438fa078..02d5002d60 100644 --- a/scripts/sd-metadata.py +++ b/scripts/sd-metadata.py @@ -13,7 +13,7 @@ filenames = sys.argv[1:] for f in filenames: try: metadata = retrieve_metadata(f) - print(f'{f}:\n',json.dumps(metadata, indent=4)) + print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4)) except FileNotFoundError: sys.stderr.write(f'{f} not found\n') continue