Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
Guard quantization imports
@@ -509,22 +509,6 @@ def import_with_fx(
     from torch.fx.experimental.proxy_tensor import make_fx
     from torch._decomp import get_decompositions
     from typing import List
-    # from brevitas_examples.llm.llm_quant.export import (
-    #     block_quant_layer_level_manager,
-    # )
-    # from brevitas_examples.llm.llm_quant.export import (
-    #     brevitas_layer_export_mode,
-    # )
-    # from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
-    #     LinearWeightBlockQuantHandlerFwd,
-    # )
-    # from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
-    # from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
-    #     matmul_rhs_group_quant_placeholder,
-    # )
-    # from brevitas.backport.fx.experimental.proxy_tensor import (
-    #     make_fx as brevitas_make_fx,
-    # )
 
     golden_values = None
     if debug:
@@ -597,32 +581,48 @@ def import_with_fx(
         torch.ops.aten.masked_fill.Tensor,
         torch.ops.aten.masked_fill.Scalar,
     ]
-    # if precision in ["int4", "int8"]:
-    #     export_context_manager = brevitas_layer_export_mode
-    #     export_class = block_quant_layer_level_manager(
-    #         export_handlers=[LinearWeightBlockQuantHandlerFwd]
-    #     )
-    #     with export_context_manager(model, export_class):
-    #         fx_g = brevitas_make_fx(
-    #             model,
-    #             decomposition_table=get_decompositions(decomps_list),
-    #         )(*inputs)
-    #
-    #     transform_fx(fx_g, quantized=True)
-    #     replace_call_fn_target(
-    #         fx_g,
-    #         src=matmul_rhs_group_quant_placeholder,
-    #         target=torch.ops.brevitas.matmul_rhs_group_quant,
-    #     )
-    #
-    #     fx_g.recompile()
-    #     removed_none_indexes = _remove_nones(fx_g)
-    #     was_unwrapped = _unwrap_single_tuple_return(fx_g)
-    # else:
-    fx_g = make_fx(
-        model,
-        decomposition_table=get_decompositions(decomps_list),
-    )(*inputs)
+    if precision in ["int4", "int8"]:
+        from brevitas_examples.llm.llm_quant.export import (
+            block_quant_layer_level_manager,
+        )
+        from brevitas_examples.llm.llm_quant.export import (
+            brevitas_layer_export_mode,
+        )
+        from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
+            LinearWeightBlockQuantHandlerFwd,
+        )
+        from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
+        from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
+            matmul_rhs_group_quant_placeholder,
+        )
+        from brevitas.backport.fx.experimental.proxy_tensor import (
+            make_fx as brevitas_make_fx,
+        )
+        export_context_manager = brevitas_layer_export_mode
+        export_class = block_quant_layer_level_manager(
+            export_handlers=[LinearWeightBlockQuantHandlerFwd]
+        )
+        with export_context_manager(model, export_class):
+            fx_g = brevitas_make_fx(
+                model,
+                decomposition_table=get_decompositions(decomps_list),
+            )(*inputs)
+
+        transform_fx(fx_g, quantized=True)
+        replace_call_fn_target(
+            fx_g,
+            src=matmul_rhs_group_quant_placeholder,
+            target=torch.ops.brevitas.matmul_rhs_group_quant,
+        )
+
+        fx_g.recompile()
+        removed_none_indexes = _remove_nones(fx_g)
+        was_unwrapped = _unwrap_single_tuple_return(fx_g)
+    else:
+        fx_g = make_fx(
+            model,
+            decomposition_table=get_decompositions(decomps_list),
+        )(*inputs)
 
     fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
     fx_g.recompile()
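
The change is the standard optional-dependency pattern: with the brevitas imports at module scope, importing SHARK failed outright when brevitas was not installed, even for users who never request int4/int8 precision. Moving the imports inside the precision check defers them until the quantized path actually runs. A minimal sketch of the pattern follows; trace_model and the error message are illustrative, not SHARK's actual API, and the commit itself only relocates the imports rather than adding the try/except shown here:

from torch.fx.experimental.proxy_tensor import make_fx


def trace_model(model, inputs, precision="fp32"):
    if precision in ["int4", "int8"]:
        # Optional dependency, imported only on the quantized path: a
        # missing brevitas install no longer breaks fp32/fp16 users at
        # module-import time. (Hypothetical guard, for illustration.)
        try:
            from brevitas_examples.llm.llm_quant.export import (
                brevitas_layer_export_mode,  # noqa: F401 -- used by the export path
            )
        except ImportError as exc:
            raise RuntimeError(
                "precision='int4'/'int8' requires brevitas; install it "
                "or fall back to precision='fp32'"
            ) from exc
        # ... quantized brevitas export path, as in the diff above ...
    else:
        # Default path depends only on stock PyTorch.
        return make_fx(model)(*inputs)

Guarding at the point of use rather than wrapping module-level imports keeps the failure local to the one feature that needs the dependency.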