mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Capture input information from mlir_graph and generate random inputs.
This commit is contained in:
@@ -16,19 +16,22 @@ print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
|
||||
)
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((arg0, arg1)))
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((arg0, arg1)))
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((arg0, arg1)))
|
||||
print(shark_module.forward(x))
|
||||
|
||||
@@ -10,6 +10,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from shark.shark_runner import SharkRunner
|
||||
import numpy as np
|
||||
|
||||
|
||||
dtype_to_np_dtype = {
|
||||
"f32": np.float32,
|
||||
"f64": np.float64,
|
||||
"i32": np.int32,
|
||||
"i64": np.int64,
|
||||
"i1": np.bool_,
|
||||
}
|
||||
|
||||
|
||||
class SharkInference:
|
||||
@@ -70,3 +80,44 @@ class SharkInference:
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def forward(self, inputs: tuple):
|
||||
return self.shark_runner.run(inputs)
|
||||
|
||||
# Captures the static input information from the mlir_module.
|
||||
# TODO(pashu123): Generate the input information for dynamic shapes.
|
||||
def _input_info(self):
|
||||
# func_key to get the line which contains the function.
|
||||
func_key = "func.func @" + self.function_name
|
||||
func_header = None
|
||||
for line in str(self.mlir_module).splitlines():
|
||||
if func_key in line:
|
||||
func_header = line
|
||||
break
|
||||
if func_header is None:
|
||||
print(f"Function: {self.function_name} not found")
|
||||
|
||||
import re
|
||||
|
||||
inputs = re.findall("\(.*?\)", func_header)[0].split(",")
|
||||
shapes = []
|
||||
dtype = []
|
||||
for inp in inputs:
|
||||
shape_dtype = re.findall(r"<[^>]*>", inp)[0].split("x")
|
||||
shape_dtype[0], shape_dtype[-1] = (
|
||||
shape_dtype[0][1:],
|
||||
shape_dtype[-1][:-1],
|
||||
)
|
||||
shapes.append(tuple([int(x) for x in shape_dtype[:-1]]))
|
||||
dtype.append(shape_dtype[-1])
|
||||
|
||||
return shapes, dtype
|
||||
|
||||
# Generates random input to be feed into the graph.
|
||||
def generate_random_inputs(self, low=0, high=1):
|
||||
shapes, dtype = self._input_info()
|
||||
inputs = []
|
||||
for i, j in zip(shapes, dtype):
|
||||
inputs.append(
|
||||
np.random.uniform(low, high, size=i).astype(
|
||||
dtype_to_np_dtype[j]
|
||||
)
|
||||
)
|
||||
return tuple(inputs)
|
||||
|
||||
@@ -98,7 +98,3 @@ class SharkRunner:
|
||||
return export_iree_module_to_vmfb(
|
||||
self.model, self.device, dir, self.mlir_dialect
|
||||
)
|
||||
|
||||
# TODO: Get the input information from the mlir_module.
|
||||
def input_info(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user