onnx consts are const [pr] (#8548)

This commit is contained in:
George Hotz
2025-01-09 16:09:22 -08:00
committed by GitHub
parent 88661cd96f
commit 5720871903
2 changed files with 3 additions and 2 deletions

View File

@@ -2,7 +2,7 @@ from typing import Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.helpers import getenv, DEBUG, all_same, get_single_element
from tinygrad.dtype import DType, ConstType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
@@ -54,6 +54,7 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
list(onnx_tensor.uint64_data):
if shape == (): return Tensor(get_single_element(data), dtype=dtype).reshape(shape).realize()
return Tensor(data, dtype=dtype).reshape(shape).realize()
assert onnx_tensor.HasField("raw_data")
return Tensor(np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape), dtype=dtype)

View File

@@ -34,7 +34,7 @@ class HWInterface:
@staticmethod
def readlink(path): return os.readlink(path)
@staticmethod
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags))
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import