mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix onnx Domain bug (#11650)
This commit is contained in:
@@ -401,7 +401,7 @@ class OnnxRunner:
|
||||
def __init__(self, model_path: Tensor | str | pathlib.Path):
|
||||
model = OnnxPBParser(model_path, load_external_data=True).parse()
|
||||
graph = model["graph"]
|
||||
self.is_training = any(n['domain'] in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
|
||||
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
|
||||
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
|
||||
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
|
||||
self.graph_outputs = tuple(o["name"] for o in graph["output"])
|
||||
|
||||
Reference in New Issue
Block a user