Centralize the parser location. Also add the --device flag.

Centralized the shark_args parser. Also added the --device flag that
specifies the device on which the inference or training is to be done.
This commit is contained in:
Prashant Kumar
2022-05-02 09:48:32 +00:00
parent 2475b05ada
commit db5be15310
4 changed files with 71 additions and 46 deletions

View File

@@ -19,7 +19,7 @@ git clone https://github.com/NodLabs/dSHARK.git
### Run a demo script
```shell
python -m shark.examples.resnet50_script
python -m shark.examples.resnet50_script --device="cpu/gpu/vulkan"
```
### Shark Inference API

View File

@@ -1,15 +1,11 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_runner import SharkInference
import timeit
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def _prepare_sentence_tokens(sentence: str):
return torch.tensor([tokenizer.encode(sentence)])
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -25,9 +21,11 @@ class MiniLMSequenceClassification(torch.nn.Module):
return self.model.forward(tokens)[0]
test_input = _prepare_sentence_tokens("this project is very interesting")
test_input = torch.randint(2, (1,128))
shark_module = SharkInference(
MiniLMSequenceClassification(), (test_input,), device="cpu", jit_trace=True
MiniLMSequenceClassification(), (test_input,), jit_trace=True
)
results = shark_module.forward((test_input,))
print(timeit.timeit(lambda: shark_module.forward((test_input,)), number=1))

49
shark/parser.py Normal file
View File

@@ -0,0 +1,49 @@
# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os


def dir_path(path):
    """Argparse ``type`` validator for directory arguments.

    Returns *path* unchanged if it names an existing directory, otherwise
    raises ``argparse.ArgumentTypeError`` so argparse reports a clean
    usage error instead of a traceback.
    """
    if os.path.isdir(path):
        return path
    else:
        raise argparse.ArgumentTypeError(
            f"readable_dir:{path} is not a valid path")


# Single shared parser for all SHARK command-line flags; every module
# imports the parsed ``shark_args`` namespace from here.
parser = argparse.ArgumentParser(description='SHARK runner.')
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="Device on which shark_runner runs. options are cpu, gpu, and vulkan")
parser.add_argument(
    "--repro_dir",
    help=
    "Directory to which module files will be saved for reproduction or debugging.",
    type=dir_path,
    default="/tmp/")
parser.add_argument(
    "--save_mlir",
    default=False,
    action="store_true",
    help="Saves input MLIR module to the --repro_dir directory.")
parser.add_argument(
    "--save_vmfb",
    default=False,
    action="store_true",
    help="Saves iree .vmfb module to the --repro_dir directory.")

# Use parse_known_args() rather than parse_args(): this module is parsed at
# import time, and parse_args() would raise SystemExit whenever the host
# process (pytest, a larger application, etc.) carries command-line
# arguments this parser does not recognize. Unknown arguments are ignored.
shark_args, _ = parser.parse_known_args()

View File

@@ -14,17 +14,9 @@
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import argparse
import os
from shark.functorch_utils import AOTModule
def dir_path(path):
if os.path.isdir(path):
return path
else:
raise argparse.ArgumentTypeError(
f"readable_dir:{path} is not a valid path")
from shark.parser import shark_args
class SharkRunner:
@@ -39,34 +31,17 @@ class SharkRunner:
tracing_required: bool,
from_aot: bool,
):
self.parser = argparse.ArgumentParser(description='SHARK runner.')
self.parser.add_argument(
"--repro_dir",
help=
"Directory to which module files will be saved for reproduction or debugging.",
type=dir_path,
default="/tmp/")
self.parser.add_argument(
"--save_mlir",
default=False,
action="store_true",
help="Saves input MLIR module to /tmp/ directory.")
self.parser.add_argument(
"--save_vmfb",
default=False,
action="store_true",
help="Saves iree .vmfb module to /tmp/ directory.")
self.parser.parse_args(namespace=self)
self.torch_module = model
self.input = input
self.torch_mlir_module = get_torch_mlir_module(model, input, dynamic,
tracing_required,
from_aot)
if self.save_mlir:
export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
if self.save_vmfb:
if shark_args.save_mlir:
export_module_to_mlir_file(self.torch_mlir_module, shark_args.repro_dir)
if shark_args.save_vmfb:
export_iree_module_to_vmfb(self.torch_mlir_module, device,
self.repro_dir)
shark_args.repro_dir)
(
self.iree_compilation_module,
self.iree_config,
@@ -85,7 +60,7 @@ class SharkInference:
model,
input: tuple,
dynamic: bool = False,
device: str = "cpu",
device: str = None,
jit_trace: bool = False,
from_aot: bool = False,
custom_inference_fn=None,
@@ -94,6 +69,8 @@ class SharkInference:
self.input = input
self.from_aot = from_aot
self.device = device if device is not None else shark_args.device
if from_aot:
aot_module = AOTModule(model,
input,
@@ -102,8 +79,8 @@ class SharkInference:
self.model = aot_module.forward_graph
self.input = aot_module.forward_inputs
self.shark_runner = SharkRunner(self.model, self.input, dynamic, device,
jit_trace, from_aot)
self.shark_runner = SharkRunner(self.model, self.input, dynamic,
self.device, jit_trace, from_aot)
def forward(self, inputs):
# TODO Capture weights and inputs in case of AOT, Also rework the
@@ -122,7 +99,7 @@ class SharkTrainer:
input: tuple,
label: tuple,
dynamic: bool = False,
device: str = "cpu",
device: str = None,
jit_trace: bool = False,
from_aot: bool = True,
):
@@ -130,6 +107,7 @@ class SharkTrainer:
self.model = model
self.input = input
self.label = label
self.device = device if device is not None else shark_args.device
aot_module = AOTModule(model, input, label)
aot_module.generate_training_graph()
self.forward_graph = aot_module.forward_graph
@@ -141,7 +119,7 @@ class SharkTrainer:
self.forward_graph,
self.forward_inputs,
dynamic,
device,
self.device,
jit_trace,
from_aot,
)
@@ -149,7 +127,7 @@ class SharkTrainer:
self.backward_graph,
self.backward_inputs,
dynamic,
device,
self.device,
jit_trace,
from_aot,
)