Fix onnx Domain bug (#11650)

This commit is contained in:
geohotstan
2025-08-13 23:20:50 +08:00
committed by GitHub
parent 67df617fe1
commit 925555b62a
2 changed files with 1 additions and 2 deletions

View File

@@ -401,7 +401,7 @@ class OnnxRunner:
def __init__(self, model_path: Tensor | str | pathlib.Path):
model = OnnxPBParser(model_path, load_external_data=True).parse()
graph = model["graph"]
self.is_training = any(n['domain'] in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
self.graph_outputs = tuple(o["name"] for o in graph["output"])