mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add test for compile3 [pr] (#7783)
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ def compile():
|
||||
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
|
||||
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
|
||||
Tensor.manual_seed(100)
|
||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
||||
print("created tensors")
|
||||
|
||||
Reference in New Issue
Block a user