mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add frontend error checks.
This commit is contained in:
@@ -14,7 +14,11 @@ import os
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner, SharkBenchmarkRunner
|
||||
import time
|
||||
import sys
|
||||
|
||||
# Prints to stderr.
|
||||
def print_err(*a):
|
||||
print(*a, file=sys.stderr)
|
||||
|
||||
class SharkInference:
|
||||
"""Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
|
||||
@@ -40,9 +44,14 @@ class SharkInference:
|
||||
|
||||
self.shark_runner = None
|
||||
|
||||
# Sets the frontend i.e `pytorch` `tensorflow`, `linalg`, `mhlo`, `tosa`.
|
||||
# Sets the frontend i.e `pytorch` or `tensorflow`.
|
||||
def set_frontend(self, frontend: str):
|
||||
self.frontend = frontend
|
||||
if frontend not in [
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa"
|
||||
]:
|
||||
print_err("frontend not supported.")
|
||||
else:
|
||||
self.frontend = frontend
|
||||
|
||||
def compile(self):
|
||||
# Inference do not use AOT.
|
||||
|
||||
@@ -21,6 +21,12 @@ from shark.backward_makefx import MakeFxModule
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
# Prints to stderr.
|
||||
def print_err(*a):
|
||||
print(*a, file=sys.stderr)
|
||||
|
||||
|
||||
class SharkTrainer:
|
||||
@@ -51,7 +57,12 @@ class SharkTrainer:
|
||||
|
||||
# Sets the frontend i.e `pytorch` or `tensorflow`.
|
||||
def set_frontend(self, frontend: str):
|
||||
self.frontend = frontend
|
||||
if frontend not in [
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa"
|
||||
]:
|
||||
print_err("frontend not supported.")
|
||||
else:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None):
|
||||
@@ -73,7 +84,7 @@ class SharkTrainer:
|
||||
self.jit_trace, self.from_aot,
|
||||
self.frontend)
|
||||
else:
|
||||
print("Unknown frontend")
|
||||
print_err("Unknown frontend")
|
||||
return
|
||||
|
||||
# The inputs to the mlir-graph are weights, buffers and inputs respectively.
|
||||
@@ -121,5 +132,5 @@ class SharkTrainer:
|
||||
elif self.frontend in ["tf", "tensorflow"]:
|
||||
return self._train_tf(num_iters)
|
||||
else:
|
||||
print("Unknown frontend")
|
||||
print_err("Unknown frontend")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user