Guard quantization imports

This commit is contained in:
Ean Garvey
2023-08-16 13:26:51 -05:00
committed by GitHub
parent 7d77d6cfb2
commit c22416cbb5

View File

@@ -509,22 +509,6 @@ def import_with_fx(
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
# from brevitas_examples.llm.llm_quant.export import (
# block_quant_layer_level_manager,
# )
# from brevitas_examples.llm.llm_quant.export import (
# brevitas_layer_export_mode,
# )
# from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
# LinearWeightBlockQuantHandlerFwd,
# )
# from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
# from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
# matmul_rhs_group_quant_placeholder,
# )
# from brevitas.backport.fx.experimental.proxy_tensor import (
# make_fx as brevitas_make_fx,
# )
golden_values = None
if debug:
@@ -597,32 +581,48 @@ def import_with_fx(
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
# if precision in ["int4", "int8"]:
# export_context_manager = brevitas_layer_export_mode
# export_class = block_quant_layer_level_manager(
# export_handlers=[LinearWeightBlockQuantHandlerFwd]
# )
# with export_context_manager(model, export_class):
# fx_g = brevitas_make_fx(
# model,
# decomposition_table=get_decompositions(decomps_list),
# )(*inputs)
#
# transform_fx(fx_g, quantized=True)
# replace_call_fn_target(
# fx_g,
# src=matmul_rhs_group_quant_placeholder,
# target=torch.ops.brevitas.matmul_rhs_group_quant,
# )
#
# fx_g.recompile()
# removed_none_indexes = _remove_nones(fx_g)
# was_unwrapped = _unwrap_single_tuple_return(fx_g)
# else:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
)
with export_context_manager(model, export_class):
fx_g = brevitas_make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
transform_fx(fx_g, quantized=True)
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
else:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()