mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
move copy into the JIT for openpilot compile3 (#7937)
* move copy into the JIT, test fails * ahh, prune was the issue
This commit is contained in:
@@ -5,7 +5,7 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
|
||||
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
|
||||
|
||||
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
|
||||
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
|
||||
@@ -30,18 +30,22 @@ def compile():
|
||||
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
|
||||
Tensor.manual_seed(100)
|
||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
||||
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
|
||||
print("created tensors")
|
||||
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs:
|
||||
next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True)
|
||||
for i in range(3):
|
||||
GlobalCounters.reset()
|
||||
print(f"run {i}")
|
||||
inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
|
||||
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
|
||||
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
|
||||
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
|
||||
ret = run_onnx_jit(**inputs).numpy()
|
||||
# copy i == 1 so use of JITBEAM is okay
|
||||
if i == 1: test_val = np.copy(ret)
|
||||
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
|
||||
np.testing.assert_equal(test_val, ret)
|
||||
np.testing.assert_equal(test_val, ret, "JIT run failed")
|
||||
print("jit run validated")
|
||||
|
||||
with open(OUTPUT, "wb") as f:
|
||||
@@ -64,10 +68,10 @@ def test(test_val=None):
|
||||
st = time.perf_counter()
|
||||
# Need to cast non-image inputs from numpy, this is only realistic way to run it
|
||||
inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
|
||||
**{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}}
|
||||
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
|
||||
out = run(**inputs)
|
||||
mt = time.perf_counter()
|
||||
val = out['outputs'].numpy()
|
||||
val = out.numpy()
|
||||
et = time.perf_counter()
|
||||
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
|
||||
print(out, val.shape, val.dtype)
|
||||
|
||||
Reference in New Issue
Block a user