Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
move checks into compile3, delete compile2 [pr] (#8127)
* move checks into compile3 [pr]
* test_vs_onnx
* test v torch works
* float16 won't compile on compile3
* actually delete compile2
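The kernel-count and gated read_image checks that previously lived in compile2 now run in compile3 over the captured JIT cache. Below is a minimal sketch of the moved checks, condensed from the diff further down (the helper name check_jit_cache is illustrative; in compile3 the code runs inline in compile()):

from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv

def check_jit_cache(run_onnx_jit):
  # count compiled kernels and gated read_image calls in the captured JIT cache
  kernel_count, gated_read_image_count = 0, 0
  for ei in run_onnx_jit.captured.jit_cache:
    if isinstance(ei.prg, CompiledRunner):
      kernel_count += 1
      gated_read_image_count += ei.prg.p.src.count("?read_image")
  # ALLOWED_KERNEL_COUNT=0 means no limit, ALLOWED_GATED_READ_IMAGE=-1 skips that check
  assert kernel_count <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
  if (allowed:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
    assert gated_read_image_count <= allowed, f"too many gated read_image! {gated_read_image_count=}, {allowed=}"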
examples/openpilot/compile2.py
@@ -1,211 +0,0 @@
#!/usr/bin/env python3
import os, sys, io, pathlib, json, struct
import numpy as np
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))

if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"

OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

import onnx
from typing import Tuple, List, Optional, Dict, cast
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.device import Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule
from tinygrad.ops import Ops
from tinygrad.tensor import _to_np_dtype
Device.DEFAULT = "GPU"

def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem], Dict[str, Tensor]]:
  Tensor.no_grad = True
  Tensor.training = False

  # load the model
  onnx_model = onnx.load(io.BytesIO(onnx_data))
  run_onnx = get_run_onnx(onnx_model)
  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

  # run the model
  inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
  ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
  schedule = create_schedule([ret.lazydata])

  # mark the schedule items that depend on the inputs
  input_lb = [x.lazydata.base.buffer for x in inputs.values()]
  depends = set(input_lb)
  for si in schedule:
    if any(b in depends for b in si.inputs):
      for out in si.outputs: depends.add(out)

  # run all kernels that don't depend on the inputs
  # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
  schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
  print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

  # confirm no non-sink metaop in the (non-independent) schedule except for the ones that load the input buffers
  assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
  return schedule, schedule_independent, inputs

def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
  import onnx
  #import pyopencl as cl
  #from extra.thneed import Thneed
  import numpy as np
  onnx_model = onnx.load(io.BytesIO(onnx_data))

  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
  Tensor.manual_seed(1337)
  new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
  new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()}

  if getenv("ORT"):
    # test with onnxruntime
    import onnxruntime as ort
    onnx_session = ort.InferenceSession(onnx_data)
    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
    new_torch_out = onnx_output[0]
    print("got ort outputs")
  else:
    # test with torch
    from test.models.test_onnx import run_onnx_torch
    new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
    print("got torch outputs")

  # if you don't have a schedule
  if eis is None:
    run_onnx = get_run_onnx(onnx_model)
    new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
    np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
    print("classic self-test passed!")
    return

  # set inputs
  for k,v in inputs.items(): v.lazydata.base.realized.copyin(new_np_inputs[k].data)

  # run code (all buffers have been allocated)
  GlobalCounters.reset()
  output = eis[-1].bufs[0]
  for ei in eis: ei.run()

  new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
  np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
  print("semi-thneed self-test passed!")

if __name__ == "__main__":
  onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes()

  # quick test for ONNX issues
  #thneed_test_onnx(onnx_data, None)
  #exit(0)

  schedule, schedule_independent, inputs = get_schedule(onnx_data)
  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
  print(f"{len(schedule_input)} inputs")

  run_schedule(schedule_independent)
  run_schedule(schedule_input)
  with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
    schedule = memory_planner(schedule)
    for si in schedule:
      for b in si.outputs:
        assert not b.is_allocated(), "output should not be allocated"
    image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
    print(f"**** compiling real kernels {image_count}/{len(schedule)} images ****")
    eis = list(tqdm(lower_schedule(schedule), total=len(schedule)))

  print("kernel count:", len(eis))
  assert len(eis) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"

  # new simple thneed
  def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")

  seen_buffers = set()
  input_buffers = [x.lazydata.buffer for x in inputs.values()]
  jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
  jdat["inputs"] = {k:to_ref(v.lazydata.buffer) for k,v in inputs.items()}
  jdat["outputs"] = [to_ref(eis[-1].bufs[0])]
  weights = []
  for i,ei in enumerate(eis):
    #print("***", i)
    for b in ei.bufs:
      needs_load = b.is_allocated() and b not in input_buffers
      #print(b, needs_load)
      if b in seen_buffers: continue
      seen_buffers.add(b)
      if isinstance(b.dtype, ImageDType):
        base_dtype = dtypes.float16 if b.dtype.fmt == 'e' else dtypes.float32
        row_pitch = (b.dtype.shape[0]*4*base_dtype.itemsize + 63)//64 * 64
        size = row_pitch * b.dtype.shape[1]
        jdat['objects'].append({
          "id": to_ref(b), "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
          "width": b.dtype.shape[0], "height": b.dtype.shape[1], "row_pitch": row_pitch, "float32": b.dtype.base == dtypes.float32,
        })
        if needs_load:
          t = Tensor.empty(b.dtype.shape, dtype=b.dtype)
          t.lazydata.buffer = b
          data = t.cast(dtypes.float32).pad(((0, row_pitch//(4*base_dtype.itemsize)-b.dtype.shape[0]), (0,0), (0,0))).contiguous().numpy()
          # NOTE: this cast must be done in numpy for platforms that don't support half
          if base_dtype == dtypes.float16: data = data.astype(np.float16)
          weights.append(data.tobytes())
          assert len(weights[-1]) == size, "wrong size buffer"
      else:
        jdat['objects'].append({
          "id": to_ref(b), "arg_type": b.dtype.name + "*", "needs_load": needs_load, "size": b.nbytes,
        })
        if needs_load:
          weights.append(b.as_buffer())
          assert len(weights[-1]) == b.nbytes, "wrong size buffer"

  saved_binaries = set()
  binaries = []
  gated_read_image_count = 0
  GlobalCounters.reset()
  with Context(DEBUG=max(DEBUG.value, 2)):
    for ei in eis:
      prg = cast(CompiledRunner, ei.prg)
      assert len(prg.p.vars) == 0
      if prg.p.function_name not in saved_binaries:
        jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
        binaries.append(prg.lib)
        saved_binaries.add(prg.p.function_name)
      gated_read_image_count += prg.p.src.count("?read_image")
      ei.run()
      jdat['kernels'].append({
        "name": prg.p.function_name,
        "work_dim": len(prg.p.global_size),
        "global_work_size": prg.p.global_size,
        "local_work_size": prg.p.local_size,
        "num_args": len(ei.bufs),
        "args": [to_ref(b) for b in ei.bufs],
        "arg_size": [8]*len(ei.bufs),
      })

  if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
    assert gated_read_image_count <= allowed_gated_read_image, \
      f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"

  output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
  print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
  with open(output_fn, "wb") as f:
    j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
    f.write(struct.pack("I", len(j)))
    f.write(j)
    for w in weights: f.write(w)
    for b in binaries: f.write(b)
    print("saved", f.tell(), "bytes")

  FLOAT16 = getenv("FLOAT16", 0)
  if FLOAT16 == 0:
    try:
      test_vs_onnx(onnx_data, eis, inputs)
    except ModuleNotFoundError as e:
      print(f"TEST NOT HAPPENING {e}")
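The deleted script serializes a simple container: a 4-byte length-prefixed JSON header (objects, kernels, binaries, inputs, outputs), followed by the raw weight blobs for every object written with needs_load, then the compiled program binaries. For reference, a minimal Python sketch (not part of the diff) of reading that container back, assuming exactly the layout written above; read_thneed is an illustrative name:

import json, struct

def read_thneed(fn="/tmp/output.thneed"):
  with open(fn, "rb") as f:
    (jlen,) = struct.unpack("I", f.read(4))              # 4-byte length of the JSON header
    jdat = json.loads(f.read(jlen).decode("latin_1"))    # objects, kernels, binaries, inputs, outputs
    # weight blobs follow, one per object saved with needs_load=True, in object order
    weights = [(o["id"], f.read(o["size"])) for o in jdat["objects"] if o["needs_load"]]
    # compiled program binaries come last, in the order their names were first seen
    binaries = [(b["name"], f.read(b["length"])) for b in jdat["binaries"]]
  return jdat, weights, binaries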
examples/openpilot/compile3.py
@@ -8,6 +8,7 @@ if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
 from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
 from tinygrad.helpers import DEBUG, getenv
 from tinygrad.tensor import _from_np_dtype
+from tinygrad.engine.realize import CompiledRunner

 import onnx
 from onnx.helper import tensor_dtype_to_np_dtype
@@ -16,12 +17,11 @@ from extra.onnx import get_run_onnx # TODO: port to main tinygrad
 OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
 OUTPUT = "/tmp/openpilot.pkl"

-def compile():
+def compile(onnx_file):
+  onnx_model = onnx.load(onnx_file)
   Tensor.no_grad = True
   Tensor.training = False

-  onnx_bytes = fetch(OPENPILOT_MODEL)
-  onnx_model = onnx.load(onnx_bytes)
   run_onnx = get_run_onnx(onnx_model)
   print("loaded model")

@@ -48,23 +48,29 @@ def compile():
   np.testing.assert_equal(test_val, ret, "JIT run failed")
   print("jit run validated")

+  # checks from compile2
+  kernel_count = 0
+  gated_read_image_count = 0
+  for ei in run_onnx_jit.captured.jit_cache:
+    if isinstance(ei.prg, CompiledRunner):
+      kernel_count += 1
+      gated_read_image_count += ei.prg.p.src.count("?read_image")
+  print(f"kernel_count: {kernel_count} gated_read_image_count: {gated_read_image_count}")
+  assert kernel_count <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
+  if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
+    assert gated_read_image_count <= allowed_gated_read_image, \
+      f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
+
   with open(OUTPUT, "wb") as f:
     pickle.dump(run_onnx_jit, f)
-  mdl_sz = os.path.getsize(onnx_bytes)
+  mdl_sz = os.path.getsize(onnx_file)
   pkl_sz = os.path.getsize(OUTPUT)
   print(f"mdl size is {mdl_sz/1e6:.2f}M")
   print(f"pkl size is {pkl_sz/1e6:.2f}M")
   print("**** compile done ****")
   return test_val

-def test(test_val=None):
-  with open(OUTPUT, "rb") as f:
-    run = pickle.load(f)
-
-  # same randomness as above
-  Tensor.manual_seed(100)
-  new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
-                sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
+def test_vs_compile(run, new_inputs, test_val=None):
   new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}

   # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
@@ -88,8 +94,39 @@ def test(test_val=None):
   out = run(**inputs)
   changed_val = out.numpy()
   np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
   return val

+def test_vs_onnx(new_inputs, test_val, onnx_file):
+  new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
+  onnx_model = onnx.load(onnx_file)
+
+  if getenv("ORT"):
+    # test with onnxruntime
+    import onnxruntime as ort
+    onnx_session = ort.InferenceSession(onnx_file)
+    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+    new_torch_out = onnx_output[0]
+    print("got ort outputs")
+  else:
+    # test with torch
+    from test.models.test_onnx import run_onnx_torch
+    # NOTE: we have to correct the order here
+    new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
+    print("got torch outputs")
+
+  np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
+  print("test vs onnx passed")
+
 if __name__ == "__main__":
-  test_val = compile() if not getenv("RUN") else None
-  test(test_val)
+  onnx_file = fetch(OPENPILOT_MODEL)
+  test_val = compile(onnx_file) if not getenv("RUN") else None
+
+  with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
+
+  # same randomness as compile
+  Tensor.manual_seed(100)
+  new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
+                sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
+
+  test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
+  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
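For reference, a minimal sketch (not part of the diff) of consuming the pickle that compile3 writes, assuming the pickled TinyJit keeps the captured input metadata used above:

import pickle
from tinygrad import Tensor

with open("/tmp/openpilot.pkl", "rb") as f: run = pickle.load(f)
# rebuild the inputs from the captured names, shapes, and dtypes; same seed as compile
Tensor.manual_seed(100)
inputs = {nm: Tensor.randn(*st.shape, dtype=dtype).mul(8).realize()
          for nm, (st, _, dtype, _) in sorted(zip(run.captured.expected_names,
                                                  run.captured.expected_st_vars_dtype_device))}
out = run(**inputs).numpy()
print(out.shape, out.dtype)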
@@ -1,2 +0,0 @@
#!/bin/bash
NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 examples/openpilot/compile2.py