move copy into the JIT for openpilot compile3 (#7937)

* move copy into the JIT, test fails

* ahh, prune was the issue
This commit is contained in:
George Hotz
2024-12-07 13:26:26 +08:00
committed by GitHub
parent 0ed731b5ea
commit 22feb3a2f1

View File

@@ -5,7 +5,7 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype
@@ -30,18 +30,22 @@ def compile():
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
print("created tensors")
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
run_onnx_jit = TinyJit(lambda **kwargs:
next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True)
for i in range(3):
GlobalCounters.reset()
print(f"run {i}")
inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
ret = run_onnx_jit(**inputs).numpy()
# copy i == 1 so use of JITBEAM is okay
if i == 1: test_val = np.copy(ret)
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
np.testing.assert_equal(test_val, ret)
np.testing.assert_equal(test_val, ret, "JIT run failed")
print("jit run validated")
with open(OUTPUT, "wb") as f:
@@ -64,10 +68,10 @@ def test(test_val=None):
st = time.perf_counter()
# Need to cast non-image inputs from numpy, this is only realistic way to run it
inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}}
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
out = run(**inputs)
mt = time.perf_counter()
val = out['outputs'].numpy()
val = out.numpy()
et = time.perf_counter()
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
print(out, val.shape, val.dtype)