Capture input information from mlir_graph and generate random inputs.

This commit is contained in:
Prashant Kumar
2022-06-29 20:42:59 +05:30
parent 2adea76b8c
commit 83855e7b08
3 changed files with 57 additions and 7 deletions

View File

@@ -16,19 +16,22 @@ print("Running shark on cpu backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
# Generate the random inputs and feed into the graph.
x = shark_module.generate_random_inputs()
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print(shark_module.forward(x))
print("Running shark on cuda backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print(shark_module.forward(x))
print("Running shark on vulkan backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print(shark_module.forward(x))

View File

@@ -10,6 +10,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.shark_runner import SharkRunner
import numpy as np
dtype_to_np_dtype = {
"f32": np.float32,
"f64": np.float64,
"i32": np.int32,
"i64": np.int64,
"i1": np.bool_,
}
class SharkInference:
@@ -70,3 +80,44 @@ class SharkInference:
# inputs are considered to be tuple of np.array.
def forward(self, inputs: tuple):
return self.shark_runner.run(inputs)
# Captures the static input information from the mlir_module.
# TODO(pashu123): Generate the input information for dynamic shapes.
def _input_info(self):
# func_key to get the line which contains the function.
func_key = "func.func @" + self.function_name
func_header = None
for line in str(self.mlir_module).splitlines():
if func_key in line:
func_header = line
break
if func_header is None:
print(f"Function: {self.function_name} not found")
import re
inputs = re.findall("\(.*?\)", func_header)[0].split(",")
shapes = []
dtype = []
for inp in inputs:
shape_dtype = re.findall(r"<[^>]*>", inp)[0].split("x")
shape_dtype[0], shape_dtype[-1] = (
shape_dtype[0][1:],
shape_dtype[-1][:-1],
)
shapes.append(tuple([int(x) for x in shape_dtype[:-1]]))
dtype.append(shape_dtype[-1])
return shapes, dtype
# Generates random input to be feed into the graph.
def generate_random_inputs(self, low=0, high=1):
shapes, dtype = self._input_info()
inputs = []
for i, j in zip(shapes, dtype):
inputs.append(
np.random.uniform(low, high, size=i).astype(
dtype_to_np_dtype[j]
)
)
return tuple(inputs)

View File

@@ -98,7 +98,3 @@ class SharkRunner:
return export_iree_module_to_vmfb(
self.model, self.device, dir, self.mlir_dialect
)
# TODO: Get the input information from the mlir_module.
def input_info(self):
pass