[Shark][Training] Refresh SharkTrainer to latest APIs.

This commit is contained in:
stanley
2023-01-17 23:55:09 +00:00
parent 9d3af37104
commit c4a9365aa1
3 changed files with 65 additions and 20 deletions

View File

@@ -1,7 +1,7 @@
import torch
from torch.nn.utils import _stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_runner import SharkTrainer
from shark.shark_trainer import SharkTrainer
class MiniLMSequenceClassification(torch.nn.Module):
@@ -42,6 +42,7 @@ def forward(params, buffers, args):
return params, buffers
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
shark_module = SharkTrainer(mod, inp)
shark_module.compile(forward)
print(shark_module.forward())
print(shark_module.train())

View File

@@ -312,9 +312,51 @@ def transform_fx(fx_g):
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
    """Rewrite the graph's output node to drop ``None`` results.

    FX graphs built for training may return ``None`` entries alongside
    real values.  This strips the ``None`` entries from a list-valued
    output and repackages the survivors: a single remaining value is
    returned bare, while several values are returned as one tuple.

    Args:
        fx_g: a ``torch.fx.GraphModule`` whose graph is mutated in place.

    Returns:
        The same ``fx_g``, linted and recompiled.
    """
    for node in fx_g.graph.nodes:
        if node.op != "output":
            continue
        # Output nodes always carry exactly one argument.
        node_arg = node.args[0]
        if isinstance(node_arg, list):
            # Don't return NoneType elements.  ``is not None`` is the
            # idiomatic form of the original isinstance(..., type(None)).
            out_nodes = [out for out in node_arg if out is not None]
            if len(out_nodes) == 1:
                # A single surviving value is returned directly rather
                # than wrapped in a one-element tuple.
                node.args = out_nodes
            else:
                node.args = (tuple(out_nodes),)
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
def flatten_training_input(inputs):
flattened_input = []
for i in inputs:
if isinstance(i, dict):
for value in i.values():
flattened_input.append(value.detach())
elif isinstance(i, tuple):
for value in i:
flattened_input.append(value)
else:
flattened_input.append(i)
return tuple(flattened_input)
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(
model, inputs, is_f16=False, f16_input_mask=None, debug=False
model,
inputs,
is_f16=False,
f16_input_mask=None,
debug=False,
training=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
@@ -360,9 +402,12 @@ def import_with_fx(
transform_fx(fx_g)
fx_g.recompile()
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)
ts_graph = torch.jit.script(fx_g)
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
mlir_importer = SharkImporter(
ts_graph,
inputs,

View File

@@ -15,6 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx
import numpy as np
from tqdm import tqdm
import sys
@@ -67,23 +68,21 @@ class SharkTrainer:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None):
def compile(self, training_fn=None, extra_args=[]):
if self.frontend in ["torch", "pytorch"]:
aot_module = MakeFxModule(
self.model, tuple(self.input), custom_inference_fn=training_fn
packed_inputs = (
dict(self.model.named_parameters()),
dict(self.model.named_buffers()),
tuple(self.input),
)
mlir_module, func_name = import_with_fx(
training_fn, packed_inputs, False, [], training=True
)
aot_module.generate_graph()
# Returns the backward graph.
training_graph = aot_module.training_graph
weights = self.get_torch_params()
self.shark_runner = SharkRunner(
training_graph,
weights + self.input,
self.dynamic,
mlir_module,
self.device,
self.jit_trace,
self.from_aot,
self.frontend,
"tm_tensor",
extra_args=extra_args,
)
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
self.shark_runner = SharkRunner(
@@ -112,8 +111,8 @@ class SharkTrainer:
params = [x.numpy() for x in params]
print(f"Training started for {num_iters} iterations:")
for i in tqdm(range(num_iters)):
params = self.shark_runner.forward(
params + self.input, self.frontend
params = self.shark_runner.run(
"forward", params + self.input, self.frontend
)
return params