add test for compile3 [pr] (#7783)

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
George Hotz
2024-11-19 19:26:51 +08:00
committed by GitHub
parent 4f6071d919
commit fbb4099b3c
2 changed files with 4 additions and 0 deletions

View File

@@ -27,6 +27,7 @@ def compile():
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
print("created tensors")