Add frontend error checks.

This commit is contained in:
Prashant Kumar
2022-05-27 10:13:03 +00:00
parent c69baa3b1e
commit cee02f6a61
2 changed files with 25 additions and 5 deletions

View File

@@ -14,7 +14,11 @@ import os
from shark.parser import shark_args
from shark.shark_runner import SharkRunner, SharkBenchmarkRunner
import time
import sys
# Prints to stderr.
def print_err(*a):
print(*a, file=sys.stderr)
class SharkInference:
"""Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
@@ -40,9 +44,14 @@ class SharkInference:
self.shark_runner = None
# Sets the frontend i.e `pytorch` `tensorflow`, `linalg`, `mhlo`, `tosa`.
# Sets the frontend i.e `pytorch` or `tensorflow`.
def set_frontend(self, frontend: str):
self.frontend = frontend
if frontend not in [
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa"
]:
print_err("frontend not supported.")
else:
self.frontend = frontend
def compile(self):
# Inference do not use AOT.

View File

@@ -21,6 +21,12 @@ from shark.backward_makefx import MakeFxModule
import numpy as np
from tqdm import tqdm
import time
import sys
# Prints to stderr.
def print_err(*a):
print(*a, file=sys.stderr)
class SharkTrainer:
@@ -51,7 +57,12 @@ class SharkTrainer:
# Sets the frontend i.e `pytorch` or `tensorflow`.
def set_frontend(self, frontend: str):
self.frontend = frontend
if frontend not in [
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa"
]:
print_err("frontend not supported.")
else:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None):
@@ -73,7 +84,7 @@ class SharkTrainer:
self.jit_trace, self.from_aot,
self.frontend)
else:
print("Unknown frontend")
print_err("Unknown frontend")
return
# The inputs to the mlir-graph are weights, buffers and inputs respectively.
@@ -121,5 +132,5 @@ class SharkTrainer:
elif self.frontend in ["tf", "tensorflow"]:
return self._train_tf(num_iters)
else:
print("Unknown frontend")
print_err("Unknown frontend")
return