Remove dependency on asizeof

Author: Brandon Rising
Date: 2024-09-05 14:44:42 -04:00
Committed by: Brandon
Parent: 5219ac12a6
Commit: 6667c39c73


@@ -7,7 +7,6 @@ from typing import Optional
 import accelerate
 import torch
-from pympler import asizeof
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -199,15 +198,15 @@ class FluxCheckpointModel(ModelLoader):
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
# For the first iteration we are just requesting the current size of the state dict
# This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
# This should be refined in the future if not removed entirely when we support more data types
sd_size = asizeof.asizeof(sd)
cache_updated = False
for k in sd.keys():
v = sd[k]
if v.dtype != torch.bfloat16:
if not cache_updated:
# For the first iteration we are just requesting the current size of the state dict
# This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
# This should be refined in the future if not removed entirely when we support more data types
sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
self._ram_cache.make_room(sd_size)
cache_updated = True
futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
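
The replacement estimate sums raw tensor storage instead of walking the whole Python object graph the way pympler's asizeof does, which is faster and accurate enough here because the state dict is dominated by tensor payloads. A minimal sketch of the idea (the helper name state_dict_nbytes and the float8_e4m3fn example dtype are illustrative, not part of the commit):

    import torch

    def state_dict_nbytes(sd: dict[str, torch.Tensor]) -> int:
        # Raw tensor payload in bytes: element count times bytes per element.
        # Ignores Python object overhead, which is negligible next to the
        # tensor data and much slower to traverse.
        return sum(t.nelement() * t.element_size() for t in sd.values())

    # float8 stores 1 byte per element and bfloat16 stores 2, so the cast
    # doubles the footprint; reserving the *current* size of the state dict
    # covers exactly the growth from the conversion.
    sd = {"w": torch.zeros(1024, 1024, dtype=torch.float8_e4m3fn)}
    before = state_dict_nbytes(sd)  # 1,048,576 bytes
    sd = {k: v.to(torch.bfloat16) for k, v in sd.items()}
    after = state_dict_nbytes(sd)   # 2,097,152 bytes
    assert after == 2 * before

Note that this counts logical element storage, not actual allocations: two state-dict entries that are views of the same tensor would be counted twice, which is acceptable for a conservative room reservation.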