mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 12:08:02 -05:00
20 lines
699 B
Python
import torch
|
|
|
|
# Optional dependency probe: `bitsandbytes` supplies `Params4bit`. `bnb_available` records whether that import
# succeeded, so code in this module can avoid touching `Params4bit` when the package is missing.
try:
    from bitsandbytes.nn.modules import Params4bit
except ImportError:
    # bitsandbytes is not importable; `Params4bit` stays undefined in this module.
    bnb_available: bool = False
else:
    # Import succeeded, so `Params4bit` is in scope.
    bnb_available: bool = True
|
|
|
|
|
|
def get_param_shape(param: torch.Tensor) -> torch.Size:
    """Return the shape of a parameter, handling `bitsandbytes.nn.Params4bit` correctly.

    For a `Params4bit` (or a subclass of it) whose `quant_state` is set, the shape stored in
    `param.quant_state.shape` is returned. For any other tensor, or a `Params4bit` whose `quant_state` is `None`,
    this falls back to `param.shape`.

    Args:
        param: The tensor / parameter whose shape should be reported.

    Returns:
        The shape of `param`.
    """
    # Accessing the `.shape` attribute of `bitsandbytes.nn.Params4bit` will return an incorrect result. Instead, we must
    # access the `.quant_state.shape` attribute.
    # `bnb_available` is tested first so `Params4bit` is never referenced when the bitsandbytes import failed.
    # `isinstance` (rather than an exact `type(...) is` check) so that subclasses of `Params4bit` are handled as well.
    if bnb_available and isinstance(param, Params4bit):  # type: ignore
        quant_state = param.quant_state
        if quant_state is not None:
            return quant_state.shape
        # NOTE(review): presumably a `Params4bit` without a `quant_state` has not been quantized yet, so its
        # `.shape` is still accurate — confirm against bitsandbytes.
    return param.shape
|