From 22feb3a2f1ba6ad6276ccb67498ef08365135a3b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 7 Dec 2024 13:26:26 +0800 Subject: [PATCH] move copy into the JIT for openpilot compile3 (#7937) * move copy into the JIT, test fails * ahh, prune was the issue --- examples/openpilot/compile3.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index 87776e777c..12f1b51470 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -5,7 +5,7 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0" -from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters +from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device from tinygrad.helpers import DEBUG, getenv from tinygrad.tensor import _from_np_dtype @@ -30,18 +30,22 @@ def compile(): if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()} Tensor.manual_seed(100) new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} print("created tensors") - run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True) + run_onnx_jit = TinyJit(lambda **kwargs: + next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True) for i in range(3): GlobalCounters.reset() print(f"run {i}") + inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k}, + **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)): - ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy() + ret = run_onnx_jit(**inputs).numpy() # copy i == 1 so use of JITBEAM is okay if i == 1: test_val = np.copy(ret) print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels") - np.testing.assert_equal(test_val, ret) + np.testing.assert_equal(test_val, ret, "JIT run failed") print("jit run validated") with open(OUTPUT, "wb") as f: @@ -64,10 +68,10 @@ def test(test_val=None): st = time.perf_counter() # Need to cast non-image inputs from numpy, this is only realistic way to run it inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k}, - **{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}} + **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} out = run(**inputs) mt = time.perf_counter() - val = out['outputs'].numpy() + val = out.numpy() et = time.perf_counter() print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms") print(out, val.shape, val.dtype)