mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Openpilot compile: fix for openpilot use (#8338)
* compile3 changes

* merge conflict

* merge conflict

* give dm npy for now

* Revert "give dm npy for now"

This reverts commit bfd980da7d2c2bab5b073127442c361922032ba1.

* updates

* Always float32 floats

* Update compile3.py

* Update compile3.py

---------

Co-authored-by: ZwX1616 <zwx1616@gmail.com>
This commit is contained in:
@@ -15,7 +15,7 @@ from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from extra.onnx import get_run_onnx # TODO: port to main tinygrad
|
||||
|
||||
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
|
||||
OUTPUT = "/tmp/openpilot.pkl"
|
||||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
|
||||
|
||||
def compile(onnx_file):
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
@@ -27,7 +27,8 @@ def compile(onnx_file):
|
||||
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
|
||||
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
|
||||
# Float inputs and outputs to tinyjits for openpilot are always float32
|
||||
input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
|
||||
Tensor.manual_seed(100)
|
||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
||||
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
|
||||
|
||||
Reference in New Issue
Block a user