mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
onnx consts are const [pr] (#8548)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Callable, Any, Sequence
|
||||
import importlib, functools
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same, get_single_element
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
|
||||
@@ -54,6 +54,7 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
|
||||
dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
|
||||
if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
|
||||
list(onnx_tensor.uint64_data):
|
||||
if shape == (): return Tensor(get_single_element(data), dtype=dtype).reshape(shape).realize()
|
||||
return Tensor(data, dtype=dtype).reshape(shape).realize()
|
||||
assert onnx_tensor.HasField("raw_data")
|
||||
return Tensor(np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape), dtype=dtype)
|
||||
|
||||
@@ -34,7 +34,7 @@ class HWInterface:
|
||||
@staticmethod
|
||||
def readlink(path): return os.readlink(path)
|
||||
@staticmethod
|
||||
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags))
|
||||
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
|
||||
|
||||
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
|
||||
Reference in New Issue
Block a user