mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[Shark][Training] Refresh SharkTrainer to latest APIs.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_runner import SharkTrainer
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
@@ -42,6 +42,7 @@ def forward(params, buffers, args):
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
shark_module.compile(forward)
|
||||
|
||||
print(shark_module.forward())
|
||||
print(shark_module.train())
|
||||
|
||||
@@ -312,9 +312,51 @@ def transform_fx(fx_g):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
out_nodes = []
|
||||
if isinstance(node_arg, list):
|
||||
# Don't return NoneType elements.
|
||||
for out_node in node_arg:
|
||||
if not isinstance(out_node, type(None)):
|
||||
out_nodes.append(out_node)
|
||||
# If there is a single tensor/element to be returned don't
|
||||
# a tuple for it.
|
||||
if len(out_nodes) == 1:
|
||||
node.args = out_nodes
|
||||
else:
|
||||
node.args = (tuple(out_nodes),)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
|
||||
def flatten_training_input(inputs):
|
||||
flattened_input = []
|
||||
for i in inputs:
|
||||
if isinstance(i, dict):
|
||||
for value in i.values():
|
||||
flattened_input.append(value.detach())
|
||||
elif isinstance(i, tuple):
|
||||
for value in i:
|
||||
flattened_input.append(value)
|
||||
else:
|
||||
flattened_input.append(i)
|
||||
return tuple(flattened_input)
|
||||
|
||||
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(
|
||||
model, inputs, is_f16=False, f16_input_mask=None, debug=False
|
||||
model,
|
||||
inputs,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
debug=False,
|
||||
training=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -360,9 +402,12 @@ def import_with_fx(
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if training:
|
||||
change_fx_graph_return_to_tuple(fx_g)
|
||||
inputs = flatten_training_input(inputs)
|
||||
|
||||
ts_graph = torch.jit.script(fx_g)
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
ts_graph,
|
||||
inputs,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
from shark.shark_importer import import_with_fx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
@@ -67,23 +68,21 @@ class SharkTrainer:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None):
|
||||
def compile(self, training_fn=None, extra_args=[]):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
aot_module = MakeFxModule(
|
||||
self.model, tuple(self.input), custom_inference_fn=training_fn
|
||||
packed_inputs = (
|
||||
dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()),
|
||||
tuple(self.input),
|
||||
)
|
||||
mlir_module, func_name = import_with_fx(
|
||||
training_fn, packed_inputs, False, [], training=True
|
||||
)
|
||||
aot_module.generate_graph()
|
||||
# Returns the backward graph.
|
||||
training_graph = aot_module.training_graph
|
||||
weights = self.get_torch_params()
|
||||
self.shark_runner = SharkRunner(
|
||||
training_graph,
|
||||
weights + self.input,
|
||||
self.dynamic,
|
||||
mlir_module,
|
||||
self.device,
|
||||
self.jit_trace,
|
||||
self.from_aot,
|
||||
self.frontend,
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
@@ -112,8 +111,8 @@ class SharkTrainer:
|
||||
params = [x.numpy() for x in params]
|
||||
print(f"Training started for {num_iters} iterations:")
|
||||
for i in tqdm(range(num_iters)):
|
||||
params = self.shark_runner.forward(
|
||||
params + self.input, self.frontend
|
||||
params = self.shark_runner.run(
|
||||
"forward", params + self.input, self.frontend
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
Reference in New Issue
Block a user