From 666d151f8a737bfe7a290597ed0092a864563e60 Mon Sep 17 00:00:00 2001
From: Diogo
Date: Wed, 7 Jun 2023 22:44:30 -0400
Subject: [PATCH] Onnx slice fixups (#952)

* resolved some slice test errors and added some more debugging logs
* use same device in cumsum
* increased float priority
* onnx debug output match input
---
 extra/onnx.py                               | 47 +++++++++++++--------
 extra/onnx_ops.py                           |  6 +--
 test/external/external_test_onnx_backend.py |  9 ----
 tinygrad/helpers.py                         |  2 +-
 tinygrad/tensor.py                          |  4 +-
 5 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/extra/onnx.py b/extra/onnx.py
index 86fe64d0fb..f34cc1d118 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -3,8 +3,8 @@ from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
 import importlib
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import prod
-from tinygrad.helpers import getenv, DEBUG
+from tinygrad.helpers import prod, getenv, DEBUG, dtypes
+from typing import List
 from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto
 try:
   from onnx.helper import tensor_dtype_to_np_dtype
@@ -93,22 +93,32 @@ def get_run_onnx(onnx_model: ModelProto):
       shape = shape_to_tuple(inp.type.tensor_type.shape)
       if len(shape) >= 1 and shape[0] == 0: shape = tuple([1]+list(shape[1:]))  # 1 batch size
       if inp.name in inputs:
-        input_shape = inputs[inp.name].shape
-        if input_shape == (0,): raise NotImplementedError("empty tensors aren't supported in tinygrad")
-        assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
         if isinstance(inputs[inp.name], Tensor):
           input_tensors[inp.name] = inputs[inp.name]
         else:
           input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
+        input_shape = input_tensors[inp.name].shape
+        if input_shape == (0,): raise NotImplementedError("empty tensors aren't supported in tinygrad")
+        assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
         for _,v in input_tensors.items(): v.realize()
       else:
         raise Exception(f"no data for {inp.name} with shape {shape}")

+    def fetch_tensor(x: str):
+      if x in tensors: return tensors[x]
+      if x in intermediate_tensors: return intermediate_tensors[x]
+      if x != str(): return input_tensors[x]
+      return None
+
     for num,n in enumerate(onnx_model.graph.node):
-      inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else (input_tensors[x] if x != str() else None)) for x in n.input]
+      inp: List[Tensor] = []
+      if debug: print("inputs:")
+      for x in n.input:
+        t = fetch_tensor(x)
+        if debug: print(f"\t{x} - {t}")
+        inp.append(t)
       opt = attribute_dict[num]
       if debug: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
-      # free ones
       if n.op_type == "Relu": ret = inp[0].relu()
       elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
@@ -128,7 +138,7 @@ def get_run_onnx(onnx_model: ModelProto):
         elif 'value_int' in opt: ret = Tensor(np.array(opt['value_int'], dtype=np.int64), requires_grad=False)
         elif 'value_floats' in opt: ret = Tensor(np.array(opt['value_floats'], dtype=np.float32), requires_grad=False)
         elif 'value_ints' in opt: ret = Tensor(np.array(opt['value_ints'], dtype=np.int64), requires_grad=False)
-        else: raise NotImplementedError(f'Constant not implemented')
+        else: raise NotImplementedError(f'Constant not implemented for {opt}')
       elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
       elif n.op_type == "Resize":
        # TODO: this is handcoded for YOLOv8
@@ -139,14 +149,14 @@ def get_run_onnx(onnx_model: ModelProto):
         ret = ret.reshape([x*y for x,y in zip(inp[0].shape, [int(x) for x in scales])])
       elif n.op_type == "Gather":
         # TODO: is this correct? seems to work for simple gather ops
-        axis = opt['axis']
+        axis = opt['axis'] if 'axis' in opt else 0
         shape = list(inp[0].shape)
         indices = [shape[axis]+int(x) if x<0 else int(x) for x in safe_numpy(inp[1])]
         args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
         ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
         ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
       elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
-        if (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
+        if all([isinstance(x, Tensor) for x in inp]) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
           inp[1] = inp[1].reshape(inp[0].shape) # TODO: is this right?
         if 'broadcast' in opt:
           inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
@@ -167,12 +177,12 @@ def get_run_onnx(onnx_model: ModelProto):
       elif n.op_type == "Slice":
         assert onnx_model_version >= 10, f'only onnx version >= 10 supported for slice'
         arg = [(0,x) for x in inp[0].shape]
-        starts, ends, axes = inp[1:4]
-        assert axes.shape == (1,)
-        axis, starts, ends = int(safe_numpy(axes)[0]), int(safe_numpy(starts)[0]), int(safe_numpy(ends)[0])
-        ends = min(ends, inp[0].shape[axis])
-        starts = starts + inp[0].shape[axis] if starts < 0 else starts
-        arg[axis] = (starts, ends)
+        starts, ends = inp[1:3]
+        axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3])
+        steps = safe_numpy(inp[4])[0] if len(inp) > 4 else 1
+        starts, ends = safe_numpy(starts.cast(dtypes.int32)).tolist(), safe_numpy(ends.cast(dtypes.int32)).tolist() # TODO: when indexing is added use that
+        for i,axis in enumerate(axes.tolist()):
+          arg[axis] = (starts[i], ends[i])
         ret = inp[0].slice(arg=arg)
       elif n.op_type == "Shrink":
         bias = opt['bias'] if 'bias' in opt else 0
@@ -192,7 +202,10 @@ def get_run_onnx(onnx_model: ModelProto):
       if not isinstance(ret, tuple): ret = (ret, )
       assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
       if debug: print([x.shape if isinstance(x, Tensor) else None for x in ret])
-      for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
+      if debug: print("outputs:")
+      for i in range(len(n.output)):
+        if debug: print(f"\t{n.output[i]} - {ret[i]}")
+        intermediate_tensors[n.output[i]] = ret[i]
       #print(ret[0].numpy().mean())
       if num == ONNXLIMIT:
         output_tensor_names = n.output
diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index ee0d617e7e..2491abe3a7 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -75,7 +75,7 @@ def _padding(X, pads=None, auto_pad="NOTSET", axes=None, constant_value=0.):
   return zero_padded + constant_padder

 def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
-  assert mode == "constant"
+  assert mode == "constant", f"WARNING: Pad mode {mode} not implemented"
   constant_value = value if constant_value is None else constant_value.numpy()
   seq_pads = list(pads) if isinstance(pads, tuple) else pads.numpy().astype(np.int32).tolist()
   seq_axes = axes.numpy().astype(np.int32).tolist() if axes is not None else None
@@ -110,8 +110,8 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None):
   mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
   return data * mask * (1/(1.0 - ratio)), mask

-def Shape(data, end=None, start=0): return list(data.shape)[start:end]
-def Size(data): return prod(data.shape)
+def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
+def Size(data): return prod(data if isinstance(data, list) else data.shape)

 # TODO: this doesn't match Tensor.flatten behavior
 def Flatten(input, axis=1):
diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py
index 8fc42cc619..9164ab51b3 100644
--- a/test/external/external_test_onnx_backend.py
+++ b/test/external/external_test_onnx_backend.py
@@ -92,18 +92,9 @@ backend_test.exclude('test_asin_*')
 backend_test.exclude('test_asinh_*')
 backend_test.exclude('test_atan_*')
 backend_test.exclude('test_atanh_*')
-# backend_test.include('test_cos_*')
-# backend_test.include('test_cosh_*')
-# backend_test.exclude('test_sin_*')
-# backend_test.include('test_sinh_*')
-# backend_test.include('test_tanh_*')

 # no boolean ops (2d, 3d, 4d)
-# backend_test.exclude('test_and*')
-# backend_test.exclude('test_xor*')
-# backend_test.exclude('test_or*')
 backend_test.exclude('test_bitshift_*')
-# backend_test.include('test_not_*')

 # no scatter gather
 backend_test.exclude('test_gather_*')
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 450282b3ab..6201b82d4d 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -76,7 +76,7 @@ class dtypes:
   def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
   bool: Final[DType] = DType(0, 1, "bool", bool)
   float16: Final[DType] = DType(0, 2, "half", np.float16)
-  float32: Final[DType] = DType(1, 4, "float", np.float32)
+  float32: Final[DType] = DType(4, 4, "float", np.float32)
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "int64", np.int64)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 62c6dd3ef8..d1cf6f87ef 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -143,7 +143,7 @@ class Tensor:
   def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

   @staticmethod
-  def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step).cumsum() + (start - step)
+  def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step, **kwargs).cumsum() + (start - step)

   @staticmethod
   def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -470,7 +470,7 @@ class Tensor:

   def cumsum(self, axis=0):
     x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
-    return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis]), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
+    return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))

   # ***** mlops (unary) *****
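
Note (not part of the patch): the new Slice branch builds a per-axis (start, end) range list and hands it to Tensor.slice, defaulting axes to every dimension when the optional third input is missing. Below is a minimal standalone sketch of that arg construction using plain Python lists; the helper name onnx_slice_args is hypothetical and, like the patched branch, it does no clamping or negative-index handling.

from typing import List, Optional, Sequence, Tuple

def onnx_slice_args(shape: Sequence[int], starts: Sequence[int], ends: Sequence[int],
                    axes: Optional[Sequence[int]] = None) -> List[Tuple[int, int]]:
  # start with the full range on every axis, like `arg = [(0,x) for x in inp[0].shape]`
  arg = [(0, s) for s in shape]
  # when the optional `axes` input is absent, ONNX Slice applies to all axes in order,
  # mirroring the patch's Tensor.arange(inp[0].ndim) fallback
  if axes is None: axes = list(range(len(shape)))
  for i, axis in enumerate(axes):
    arg[axis] = (starts[i], ends[i])
  return arg

# example: keep both rows, take columns 1..3 of a (2, 5) tensor
print(onnx_slice_args((2, 5), starts=[1], ends=[4], axes=[1]))  # [(0, 2), (1, 4)]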
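
Note (also not part of the patch): Tensor.cumsum computes a prefix sum by left-padding the target axis and convolving with a ones kernel; the tensor.py change only builds that kernel with the input's dtype and device so both conv2d operands live on the same backend. A numpy sketch of the same trick, for illustration under that assumption (cumsum_via_conv is a made-up name):

import numpy as np

def cumsum_via_conv(x: np.ndarray) -> np.ndarray:
  # prefix sum along the last axis via convolution with a ones kernel,
  # analogous to conv2d(..., padding=(n-1, 0, 0, 0)) in Tensor.cumsum
  n = x.shape[-1]
  kernel = np.ones(n, dtype=x.dtype)  # same dtype as the input, like the patched code
  padded = np.concatenate([np.zeros(x.shape[:-1] + (n - 1,), dtype=x.dtype), x], axis=-1)
  rows = padded.reshape(-1, 2 * n - 1)
  out = np.stack([np.convolve(r, kernel, mode="valid") for r in rows])  # output i sums x[0..i]
  return out.reshape(x.shape)

assert np.allclose(cumsum_via_conv(np.arange(1., 6.)), np.cumsum(np.arange(1., 6.)))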