diff --git a/shark/examples/shark_inference/sharded_bloom_large_models.py b/shark/examples/shark_inference/sharded_bloom_large_models.py index 1635ac13..fa204682 100644 --- a/shark/examples/shark_inference/sharded_bloom_large_models.py +++ b/shark/examples/shark_inference/sharded_bloom_large_models.py @@ -324,7 +324,7 @@ if __name__ == "__main__": mlir_str = bytes(mlir_str, "utf-8") - if config["n_embed"] == 14336: + if "n_embed" in config.keys() and config["n_embed"] == 14336: def get_state_dict(): d = torch.load(