mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-10 14:18:00 -05:00
fixes for multi-gpu training
This commit is contained in:
26
train.py
26
train.py
@@ -85,11 +85,10 @@ accelerate launch train.py \
|
||||
--base_model stabilityai/stablelm-zephyr-3b \
|
||||
--bf16 \
|
||||
--train_dataset data/home_assistant_train.jsonl \
|
||||
--test_dataset data/home_assistant_test.jsonl \
|
||||
--learning_rate 1e-5 --batch_size 64 --epochs 1 \
|
||||
--micro_batch_size 8 --gradient_checkpointing \
|
||||
--micro_batch_size 2 --gradient_checkpointing --group_by_length \
|
||||
--ctx_size 2048 \
|
||||
--save_steps 50 --save_total_limit 20 --eval_steps 100 --logging_steps 2
|
||||
--save_steps 50 --save_total_limit 5 --eval_steps 100 --logging_steps 2
|
||||
"""
|
||||
|
||||
"""
|
||||
@@ -186,19 +185,21 @@ else:
|
||||
model_kwargs["use_cache"] = False
|
||||
|
||||
def find_max_vram(min_buffer_mib=800):
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
suggestion = min(suggestion, total_mem - min_buffer_mib)
|
||||
max_memory = {}
|
||||
for i in range(torch.cuda.device_count()):
|
||||
total_mem = (torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
suggestion = min(suggestion, total_mem - min_buffer_mib)
|
||||
|
||||
print(f"Model will target using {suggestion}MiB of VRAM")
|
||||
max_memory = {0: f'{suggestion}MiB'}
|
||||
print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
|
||||
max_memory[i] = f'{suggestion}MiB'
|
||||
|
||||
return max_memory if len(max_memory) > 0 else None
|
||||
return max_memory
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
training_run_args.base_model,
|
||||
trust_remote_code=True,
|
||||
device_map="auto",
|
||||
# device_map="auto",
|
||||
max_memory=find_max_vram(),
|
||||
**model_kwargs
|
||||
)
|
||||
@@ -535,7 +536,10 @@ try:
|
||||
if tensorboard_process:
|
||||
input("Training is finished. Press enter to quit tensorboard after the viewing results.")
|
||||
tensorboard_process.kill()
|
||||
except Exception as e:
|
||||
except Exception as ex:
|
||||
# On multi-GPU runs, re-raise immediately instead of falling through to the
# interactive "try and save it" recovery path below.
# torch.cuda.device_count() already returns an int; wrapping it in len()
# raised TypeError here and hid the original training exception.
if torch.cuda.device_count() > 1:
    raise ex
|
||||
|
||||
print("Something bad happened! Try and save it?")
|
||||
import code, traceback
|
||||
traceback.print_exc()
|
||||
|
||||
Reference in New Issue
Block a user