Modify import_with_fx to import with dtype=f16.

Author: Prashant Kumar
Date:   2023-01-11 14:53:43 +00:00
parent f0dd48ed2a
commit 6f80825814


@@ -245,8 +245,45 @@ class SharkImporter:
         )
 
+def get_f16_inputs(inputs, is_f16, f16_input_mask):
+    if is_f16 == False:
+        return inputs
+    if f16_input_mask == None:
+        return tuple([x.half() for x in inputs])
+    f16_masked_inputs = []
+    for i in range(len(inputs)):
+        if f16_input_mask[i]:
+            f16_masked_inputs.append(inputs[i].half())
+        else:
+            f16_masked_inputs.append(inputs[i])
+    return tuple(f16_masked_inputs)
+
+
+def transform_fx(fx_g):
+    import torch
+
+    kwargs_dict = {
+        "dtype": torch.float16,
+        "device": torch.device(type="cpu"),
+        "pin_memory": False,
+    }
+    for node in fx_g.graph.nodes:
+        if node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.arange.start,
+            ]:
+                node.kwargs = kwargs_dict
+    fx_g.graph.lint()
+
+
 # Applies fx conversion to the model and imports the mlir.
-def import_with_fx(model, inputs, debug=False):
+def import_with_fx(
+    model, inputs, is_f16=False, f16_input_mask=None, debug=False
+):
     import torch
     from torch.fx.experimental.proxy_tensor import make_fx
     from torch._decomp import get_decompositions
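
The two helpers added above split the f16 work: get_f16_inputs casts the example inputs to half precision (optionally only those selected by f16_input_mask), while transform_fx rewrites the kwargs of aten.arange.start nodes in the traced graph so tensors materialized at runtime also come out as float16. A minimal sketch of the masked cast, with made-up tensors that are not part of this commit:

import torch

# Keep the integer token ids in their original dtype and cast
# only the floating-point input to half precision.
inputs = (torch.randn(2, 4), torch.randint(0, 100, (2, 4)))
halved = get_f16_inputs(inputs, is_f16=True, f16_input_mask=[True, False])
assert halved[0].dtype == torch.float16
assert halved[1].dtype == torch.int64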
@@ -270,6 +307,9 @@ def import_with_fx(model, inputs, debug=False):
             ),
     )(*inputs)
 
+    if is_f16:
+        transform_fx(fx_g)
+
     fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
     fx_g.recompile()
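
transform_fx mutates node kwargs in place, so the graph has to be regenerated before use; the pre-existing set_codegen/recompile pair right below the new branch takes care of that. A toy illustration of the rewrite, with a hypothetical traced function, assuming make_fx captures the arange call as an aten.arange.start node:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # torch.arange appears in the traced graph as aten.arange.start.
    return x + torch.arange(1, 5, dtype=torch.float32)

fx_g = make_fx(f)(torch.randn(4))
transform_fx(fx_g)  # the arange node now carries dtype=torch.float16
fx_g.recompile()
print(fx_g(torch.randn(4).half()).dtype)  # expected: torch.float16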
@@ -286,15 +326,23 @@ def import_with_fx(model, inputs, debug=False):
     strip_overloads(fx_g)
 
+    if is_f16:
+        fx_g = fx_g.half()
+        fx_g.recompile()
+
     ts_graph = torch.jit.script(fx_g)
 
+    inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
     mlir_importer = SharkImporter(
-        fx_g,
+        ts_graph,
         inputs,
         frontend="torch",
     )
-    if debug:
-        (mlir_module, func_name), _, _ = mlir_importer.import_debug()
-        return mlir_module, func_name
+    # if debug:
+    #     (mlir_module, func_name), _, _ = mlir_importer.import_debug()
+    #     return mlir_module, func_name
     mlir_module, func_name = mlir_importer.import_mlir()
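
With the new parameters, a half-precision import looks roughly like this. The model and inputs below are placeholders, and it is assumed the function still ends with return mlir_module, func_name as in the surrounding code:

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

model = TinyModel()
inputs = (torch.randn(1, 8),)

# Cast the model and all inputs to f16; pass f16_input_mask=[True, False, ...]
# instead to keep selected inputs (e.g. integer ids) in their original dtype.
mlir_module, func_name = import_with_fx(model, inputs, is_f16=True)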