mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Centralize the parser location. Also add the --device flag.
Centralized the shark_args parser in shark/parser.py. Also added the --device flag, which specifies the device on which inference or training is run.
This commit is contained in:
@@ -19,7 +19,7 @@ git clone https://github.com/NodLabs/dSHARK.git
|
||||
|
||||
### Run a demo script
|
||||
```shell
|
||||
python -m shark.examples.resnet50_script
|
||||
python -m shark.examples.resnet50_script --device="cpu"  # or "gpu" or "vulkan"
|
||||
```
|
||||
|
||||
### Shark Inference API
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_runner import SharkInference
|
||||
import timeit
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
def _prepare_sentence_tokens(sentence: str):
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -25,9 +21,11 @@ class MiniLMSequenceClassification(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
test_input = _prepare_sentence_tokens("this project is very interesting")
|
||||
test_input = torch.randint(2, (1,128))
|
||||
|
||||
shark_module = SharkInference(
|
||||
MiniLMSequenceClassification(), (test_input,), device="cpu", jit_trace=True
|
||||
MiniLMSequenceClassification(), (test_input,), jit_trace=True
|
||||
)
|
||||
results = shark_module.forward((test_input,))
|
||||
|
||||
print(timeit.timeit(lambda: shark_module.forward((test_input,)), number=1))
|
||||
|
||||
|
||||
49
shark/parser.py
Normal file
49
shark/parser.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
def dir_path(path):
    """Validate that ``path`` names an existing directory.

    Written as an ``argparse`` ``type=`` callable: the value is returned
    unchanged on success and ``argparse.ArgumentTypeError`` is raised on
    failure.

    Args:
        path: Filesystem path to check.

    Returns:
        ``path`` itself, unmodified, when ``os.path.isdir(path)`` is true.

    Raises:
        argparse.ArgumentTypeError: If ``path`` is not an existing directory.
    """
    # Guard clause: reject anything that is not a directory up front so the
    # success path is a plain return.
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(
            f"readable_dir:{path} is not a valid path")
    return path
|
||||
|
||||
|
||||
# Module-level parser for the SHARK command-line flags. It is built and
# evaluated once, at import time, and the result is exposed as `shark_args`.
parser = argparse.ArgumentParser(description='SHARK runner.')
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="Device on which shark_runner runs. options are cpu, gpu, and vulkan")
parser.add_argument(
    "--repro_dir",
    help=
    "Directory to which module files will be saved for reproduction or debugging.",
    type=dir_path,
    default="/tmp/")
parser.add_argument(
    "--save_mlir",
    default=False,
    action="store_true",
    help="Saves input MLIR module to the --repro_dir directory (default: /tmp/).")
parser.add_argument(
    "--save_vmfb",
    default=False,
    action="store_true",
    help="Saves iree .vmfb module to the --repro_dir directory (default: /tmp/).")

# Parsing happens as a side effect of importing this module, so the command
# line being parsed belongs to whatever program did the import (a user script,
# a test runner, ...). parse_known_args() picks out the SHARK flags and leaves
# everything else alone instead of exiting with "unrecognized arguments".
shark_args, _unknown_args = parser.parse_known_args()
|
||||
@@ -14,17 +14,9 @@
|
||||
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
|
||||
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
|
||||
import argparse
|
||||
import os
|
||||
from shark.functorch_utils import AOTModule
|
||||
|
||||
|
||||
def dir_path(path):
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
else:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"readable_dir:{path} is not a valid path")
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
@@ -39,34 +31,17 @@ class SharkRunner:
|
||||
tracing_required: bool,
|
||||
from_aot: bool,
|
||||
):
|
||||
self.parser = argparse.ArgumentParser(description='SHARK runner.')
|
||||
self.parser.add_argument(
|
||||
"--repro_dir",
|
||||
help=
|
||||
"Directory to which module files will be saved for reproduction or debugging.",
|
||||
type=dir_path,
|
||||
default="/tmp/")
|
||||
self.parser.add_argument(
|
||||
"--save_mlir",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves input MLIR module to /tmp/ directory.")
|
||||
self.parser.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves iree .vmfb module to /tmp/ directory.")
|
||||
self.parser.parse_args(namespace=self)
|
||||
self.torch_module = model
|
||||
self.input = input
|
||||
self.torch_mlir_module = get_torch_mlir_module(model, input, dynamic,
|
||||
tracing_required,
|
||||
from_aot)
|
||||
if self.save_mlir:
|
||||
export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
|
||||
if self.save_vmfb:
|
||||
|
||||
if shark_args.save_mlir:
|
||||
export_module_to_mlir_file(self.torch_mlir_module, shark_args.repro_dir)
|
||||
if shark_args.save_vmfb:
|
||||
export_iree_module_to_vmfb(self.torch_mlir_module, device,
|
||||
self.repro_dir)
|
||||
shark_args.repro_dir)
|
||||
(
|
||||
self.iree_compilation_module,
|
||||
self.iree_config,
|
||||
@@ -85,7 +60,7 @@ class SharkInference:
|
||||
model,
|
||||
input: tuple,
|
||||
dynamic: bool = False,
|
||||
device: str = "cpu",
|
||||
device: str = None,
|
||||
jit_trace: bool = False,
|
||||
from_aot: bool = False,
|
||||
custom_inference_fn=None,
|
||||
@@ -94,6 +69,8 @@ class SharkInference:
|
||||
self.input = input
|
||||
self.from_aot = from_aot
|
||||
|
||||
self.device = device if device is not None else shark_args.device
|
||||
|
||||
if from_aot:
|
||||
aot_module = AOTModule(model,
|
||||
input,
|
||||
@@ -102,8 +79,8 @@ class SharkInference:
|
||||
self.model = aot_module.forward_graph
|
||||
self.input = aot_module.forward_inputs
|
||||
|
||||
self.shark_runner = SharkRunner(self.model, self.input, dynamic, device,
|
||||
jit_trace, from_aot)
|
||||
self.shark_runner = SharkRunner(self.model, self.input, dynamic,
|
||||
self.device, jit_trace, from_aot)
|
||||
|
||||
def forward(self, inputs):
|
||||
# TODO Capture weights and inputs in case of AOT, Also rework the
|
||||
@@ -122,7 +99,7 @@ class SharkTrainer:
|
||||
input: tuple,
|
||||
label: tuple,
|
||||
dynamic: bool = False,
|
||||
device: str = "cpu",
|
||||
device: str = None,
|
||||
jit_trace: bool = False,
|
||||
from_aot: bool = True,
|
||||
):
|
||||
@@ -130,6 +107,7 @@ class SharkTrainer:
|
||||
self.model = model
|
||||
self.input = input
|
||||
self.label = label
|
||||
self.device = device if device is not None else shark_args.device
|
||||
aot_module = AOTModule(model, input, label)
|
||||
aot_module.generate_training_graph()
|
||||
self.forward_graph = aot_module.forward_graph
|
||||
@@ -141,7 +119,7 @@ class SharkTrainer:
|
||||
self.forward_graph,
|
||||
self.forward_inputs,
|
||||
dynamic,
|
||||
device,
|
||||
self.device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
)
|
||||
@@ -149,7 +127,7 @@ class SharkTrainer:
|
||||
self.backward_graph,
|
||||
self.backward_inputs,
|
||||
dynamic,
|
||||
device,
|
||||
self.device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user