From 4df581811ea77ac6081ebe55fc4b59fd494a140c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 21:01:48 -0400 Subject: [PATCH] add template verification script --- scripts/verify_checkpoint_template.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100755 scripts/verify_checkpoint_template.py diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py new file mode 100755 index 0000000000..42c7acca3a --- /dev/null +++ b/scripts/verify_checkpoint_template.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +""" +Read a checkpoint/safetensors file and compare it to a template .json. +Returns True if their metadata match. +""" + +import sys +import argparse +import json + +from pathlib import Path + +from invokeai.backend.model_management.models.base import read_checkpoint_meta + +parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") +parser.add_argument( + "--checkpoint", + "--in", + type=Path, + help="Path to the input checkpoint/safetensors file" +) +parser.add_argument( + "--template", + "--out", + type=Path, + help="Path to the template .json file to match against" +) + +opt = parser.parse_args() +ckpt = read_checkpoint_meta(opt.checkpoint) +while "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + +checkpoint_metadata = {} + +for key, tensor in ckpt.items(): + checkpoint_metadata[key] = list(tensor.shape) + +with open(opt.template,'r') as f: + template = json.load(f) + +if checkpoint_metadata == template: + print('True') + sys.exit(0) +else: + print('False') + sys.exit(-1) + + + +