mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
ONNX real float16 (#10694)
* squash commits * temp fix for const tensor * actually realizing float16 can only happen in raw_data * .float -> cast(float) to rerun CI --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import sys, onnx
|
||||
from tinygrad import Tensor, fetch, GlobalCounters
|
||||
from tinygrad import Tensor, fetch, GlobalCounters, dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.kernelize.kernelize import get_kernelize_map
|
||||
@@ -17,7 +17,7 @@ if __name__ == "__main__":
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
inputs = run_onnx.get_empty_input_data("npy")
|
||||
inputs = run_onnx.get_empty_input_data("npy", dtypes.float32)
|
||||
out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
|
||||
root = out.uop
|
||||
targets = [x.uop for x in inputs.values()]
|
||||
|
||||
Reference in New Issue
Block a user