fix pylint for onnx (#11673)

* fix pylint for onnx

* too long
This commit is contained in:
chenyu
2025-08-14 15:48:02 -07:00
committed by GitHub
parent e9d0027591
commit 48c4033ae1
2 changed files with 11 additions and 7 deletions

View File

@@ -332,7 +332,7 @@ jobs:
python3 -m ruff check extra/onnx.py
python3 -m ruff check examples/mlperf/ --ignore E501
- name: Lint tinygrad with pylint
run: python -m pylint tinygrad/
run: python -m pylint tinygrad/ extra/onnx.py
- name: Run mypy
run: |
python -m mypy --strict-equality --lineprecision-report .

View File

@@ -1,3 +1,4 @@
# pylint: disable=possibly-unused-variable
from typing import Any, Sequence, cast, Literal, NamedTuple, Generator
import dataclasses, functools, io, math, types, warnings, pathlib, sys, os, struct, enum
from io import BufferedReader
@@ -231,9 +232,9 @@ class OnnxPBParser:
if self.file_path is None:
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
self.file_path = pathlib.Path(self.tensor.device[5:])
else: raise Exception("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
ext_path = self.file_path.parent.joinpath(location)
if not ext_path.exists(): raise Exception(f"external location not exists: {ext_path}")
if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")
ext_tensor = Tensor(ext_path)
obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:]
@@ -583,7 +584,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): return x if min is None and max is None else x.clip(min, max) # noqa: A002
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): return x if min is None and max is None else x.clip(min, max) # noqa: A002 # pylint: disable=redefined-builtin
def IsInf(x:Tensor, detect_negative:int=1, detect_positive:int=1): return x.isinf(bool(detect_positive), bool(detect_negative))
# ***** Unary Ops (activation) *****
@@ -731,7 +732,9 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
pads = _onnx_pads_to_tiny_pads(pads)
return X.conv_transpose2d(W, B, group, strides_, dilations_, pads, output_padding_)
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=[], pads:list[int]|int=0, strides:list[int]|int=1):
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0,
strides:list[int]|int=1):
if kernel_shape is None: kernel_shape = []
pads_: int | tuple[int, ...] = tuple(pads) if isinstance(pads, list) else pads
return Tensor.max_unpool2d(xT, xI, tuple(kernel_shape), strides, 1, pads_, outshape if outshape is None else tuple(outshape))
@@ -860,7 +863,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return X.permute(*argsort(perm)) if perm else X
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002
def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002 # pylint: disable=redefined-builtin
val, idx = X.topk(_resolve_const(K), axis, bool(largest), bool(sorted))
return val, idx.cast(dtypes.int64)
@@ -920,7 +923,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
def MeanVarianceNormalization(x:Tensor, axis:list[int]|None=None):
if axis is None: axis = [0,2,3]
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def OneHot(indices:Tensor, depth:float|int|list[int|float], values:Tensor, axis:int=-1):