mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix compile4 (#10797)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -450,6 +450,8 @@ jobs:
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot LLVM compile
|
||||
run: PYTHONPATH="." LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot compile4
|
||||
run: PYTHONPATH="." NOLOCALS=1 GPU=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import sys, onnx
|
||||
from tinygrad import Tensor, fetch, GlobalCounters
|
||||
from tinygrad.uop import UOp
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.engine.grouper import get_kernelize_map
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
@@ -37,12 +37,12 @@ if __name__ == "__main__":
|
||||
independent = UOp.sink(*independent_set.keys())
|
||||
kernelized = get_kernelize_map(independent)
|
||||
independent = independent.substitute(kernelized)
|
||||
schedule, var_vals, becomes_map = create_schedule_with_vars(independent)
|
||||
schedule, var_vals = create_schedule_with_vars(independent)
|
||||
run_schedule(schedule)
|
||||
|
||||
print("**** real ****")
|
||||
GlobalCounters.reset()
|
||||
out.uop = root.substitute(kernelized).substitute(becomes_map)
|
||||
out.uop = root.substitute(kernelized)
|
||||
out.kernelize()
|
||||
|
||||
# realize
|
||||
|
||||
@@ -63,12 +63,14 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
|
||||
if len(data) == 1: return Tensor(data.tolist()[0], dtype=dtype).reshape(shape)
|
||||
return data.cast(dtype).reshape(shape).to(Device.DEFAULT)
|
||||
if has_field(onnx_tensor, "raw_data"):
|
||||
raw_data = onnx_tensor.raw_data
|
||||
if not isinstance(raw_data, Tensor): raw_data = Tensor(raw_data)
|
||||
if onnx_tensor.data_type == TensorProto.FLOAT16:
|
||||
np_buffer = np.frombuffer(onnx_tensor.raw_data.data().tobytes(),
|
||||
np_buffer = np.frombuffer(raw_data.data().tobytes(),
|
||||
dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
|
||||
if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
|
||||
return Tensor(np_buffer, dtype=dtype)
|
||||
ret = onnx_tensor.raw_data.bitcast(dtype).reshape(shape).to(Device.DEFAULT)
|
||||
ret = raw_data.bitcast(dtype).reshape(shape).to(Device.DEFAULT)
|
||||
if shape == (): ret = Tensor(ret.item(), dtype=dtype).reshape(shape)
|
||||
return ret
|
||||
return Tensor(None)
|
||||
|
||||
Reference in New Issue
Block a user