From 5aaa8a0cc1b867822d9b85e302ca5965a715c341 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 31 Oct 2023 11:35:03 -0700 Subject: [PATCH] fix shape --- extra/onnx_ops.py | 3 ++- openpilot/compile2.py | 21 +++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 33fbabf91c..8a0ebec007 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -3,6 +3,7 @@ from tinygrad.helpers import prod, dtypes from extra.onnx import safe_numpy from onnx.helper import tensor_dtype_to_np_dtype from onnx.onnx_pb import TensorProto +import os import numpy as np import functools from typing import Union, Tuple, Optional, List, Any @@ -103,7 +104,7 @@ def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([] def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)]) def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype) -def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64) +def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really? def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape) def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1) def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))]) diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 7aa2076d0c..1ee84f91eb 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -109,10 +109,6 @@ def thneed_test_onnx(onnx_data, output_fn): # non thneed run_onnx = get_run_onnx(onnx_model) new_tinygrad_out = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).numpy() - for i,(x,y) in enumerate(zip(new_torch_out.flatten().tolist(), new_tinygrad_out.flatten().tolist())): - if abs(x-y) > 100: - print(i, x, y) - np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2) print("classic self-test passed!") else: @@ -142,19 +138,20 @@ if __name__ == "__main__": schedule, schedule_independent, inputs = get_schedule(onnx_data) schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps) print(f"{len(schedule_input)} inputs") - schedule = fix_schedule_for_images(schedule) - image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule) - print(f"**** running real kernels {image_count}/{len(schedule)} images ****") - - if GRAPH: - for si in schedule_input: log_schedule_item(si) - for si in schedule: log_schedule_item(si) run_schedule(schedule_independent, disable_logging=True) run_schedule(schedule_input) with Context(DEBUG=2, BEAM=getenv("LATEBEAM")): + schedule = fix_schedule_for_images(schedule) + image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule) + print(f"**** running real kernels {image_count}/{len(schedule)} images ****") + + if GRAPH: + for si in schedule_input: log_schedule_item(si) + for si in schedule: log_schedule_item(si) + GlobalCounters.reset() - run_schedule(schedule) + run_schedule(schedule[:]) output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed" schedule_to_thneed(schedule, output_fn)