AMD-SHARK-Studio/shark/backward_makefx.py

# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from functorch.compile import (
    aot_module,
    memory_efficient_fusion,
)
from functorch import make_fx
from torch.nn.utils import _stateless
from torch import fx
import copy


class MakeFxModule:
    def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
        self.model = model
        self.inputs = inputs
        self.custom_inference_fn = custom_inference_fn
        self.backward_graph = None

    # Drops NoneType elements from the graph's return; it doesn't replace them.
    def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
        for node in fx_g.graph.nodes:
            if node.op == "output":
                # output nodes always have one argument
                node_arg = node.args[0]
                out_nodes = []
                if isinstance(node_arg, list):
                    # Don't return NoneType elements.
                    for out_node in node_arg:
                        if not isinstance(out_node, type(None)):
                            out_nodes.append(out_node)
                    # If there is a single tensor/element to be returned don't
                    # create a tuple for it.
                    if len(out_nodes) == 1:
                        node.args = out_nodes
                    else:
                        node.args = (tuple(out_nodes),)
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g

    def generate_graph(self):
        # Trace the custom inference/training function with the model's
        # parameters, buffers, and example inputs to obtain an FX graph.
        fx_g = make_fx(self.custom_inference_fn)(
            dict(self.model.named_parameters()),
            dict(self.model.named_buffers()),
            self.inputs,
        )
        # Normalize the graph's return value and regenerate its Python code
        # with the default codegen so it can be compiled with TorchScript.
        fx_g = self.change_fx_graph_return_to_tuple(fx_g)
        fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
        fx_g.recompile()
        ts_g = torch.jit.script(fx_g)
        print(ts_g.graph)
        self.backward_graph = ts_g
        return ts_g
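

# The block below is an illustrative usage sketch, not part of the original
# module. The callable passed as `custom_inference_fn` is traced by make_fx
# with (params, buffers, inputs), so it should accept exactly those three
# arguments. The toy model, the `train_step` function, and the learning rate
# are hypothetical, and the sketch assumes a PyTorch/functorch version where
# `_stateless.functional_call` (imported above) is available.
if __name__ == "__main__":
    net = torch.nn.Sequential(torch.nn.Linear(4, 2))
    example_input = torch.randn(1, 4)

    def train_step(params, buffers, inputs):
        # Run the module functionally against the traced params/buffers and
        # backpropagate a scalar loss so the backward ops land in the graph.
        loss = _stateless.functional_call(
            net, {**params, **buffers}, (inputs,), {}
        ).sum()
        loss.backward()
        # Return updated parameters; keeping the gradients live prevents the
        # backward computation from being pruned out of the trace.
        return [p - 0.01 * p.grad for p in params.values()]

    module = MakeFxModule(net, example_input, custom_inference_fn=train_step)
    ts_graph = module.generate_graph()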