mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Modify import_with_fx to import with dtype=f16.
This commit is contained in:
@@ -245,8 +245,45 @@ class SharkImporter:
|
||||
)
|
||||
|
||||
|
||||
def get_f16_inputs(inputs, is_f16, f16_input_mask):
    """Return `inputs` with selected tensors converted to float16.

    Args:
        inputs: Sequence of torch tensors (anything with a ``.half()`` method).
        is_f16: When falsy, `inputs` is returned unchanged.
        f16_input_mask: Optional sequence of booleans, one per input. A truthy
            entry means the corresponding input is converted with ``.half()``;
            a falsy entry leaves it untouched. When ``None``, every input is
            converted.

    Returns:
        The original `inputs` object when `is_f16` is falsy; otherwise a tuple
        of tensors with the requested elements cast to float16.

    Raises:
        ValueError: If `f16_input_mask` is given but its length does not match
            `inputs` (the original code raised a bare IndexError here).
    """
    if not is_f16:
        return inputs
    if f16_input_mask is None:
        # No mask: convert everything.
        return tuple(x.half() for x in inputs)
    if len(f16_input_mask) != len(inputs):
        raise ValueError(
            "f16_input_mask length does not match number of inputs"
        )
    return tuple(
        x.half() if to_half else x
        for x, to_half in zip(inputs, f16_input_mask)
    )
|
||||
|
||||
|
||||
def transform_fx(fx_g):
    """Rewrite an FX graph in place so targeted factory ops emit float16.

    Walks every node of ``fx_g.graph`` and, for ``call_function`` nodes whose
    target is ``aten.arange.start``, replaces the node's keyword arguments so
    the op produces a float16 tensor on CPU.

    NOTE(review): like the original, this *replaces* the node's entire kwargs
    dict, discarding any keyword arguments the node already carried (e.g. a
    ``layout`` kwarg) — confirm that is intended for all matched nodes.

    Args:
        fx_g: A ``torch.fx.GraphModule`` (or any object exposing a ``.graph``
            of FX nodes). Mutated in place.

    Returns:
        None. Callers must ``fx_g.recompile()`` afterwards; only
        ``graph.lint()`` is run here.
    """
    import torch

    f16_kwargs = {
        "dtype": torch.float16,
        "device": torch.device(type="cpu"),
        "pin_memory": False,
    }
    # Ops whose outputs should be forced to f16/CPU.
    f16_targets = (torch.ops.aten.arange.start,)

    for node in fx_g.graph.nodes:
        if node.op == "call_function" and node.target in f16_targets:
            # Assign a fresh copy per node: the original shared one dict
            # object across all matched nodes, so a later mutation through
            # one node would have aliased every other node's kwargs.
            node.kwargs = dict(f16_kwargs)

    fx_g.graph.lint()
|
||||
|
||||
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(model, inputs, debug=False):
|
||||
def import_with_fx(
|
||||
model, inputs, is_f16=False, f16_input_mask=None, debug=False
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
@@ -270,6 +307,9 @@ def import_with_fx(model, inputs, debug=False):
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
if is_f16:
|
||||
transform_fx(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
@@ -286,15 +326,23 @@ def import_with_fx(model, inputs, debug=False):
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
if is_f16:
|
||||
fx_g = fx_g.half()
|
||||
fx_g.recompile()
|
||||
|
||||
ts_graph = torch.jit.script(fx_g)
|
||||
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
fx_g,
|
||||
ts_graph,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
if debug:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
|
||||
return mlir_module, func_name
|
||||
# if debug:
|
||||
# (mlir_module, func_name), _, _ = mlir_importer.import_debug()
|
||||
# return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user