clean up training guide

This commit is contained in:
Alex O'Connell
2025-12-23 21:03:07 -05:00
parent 1a8c2e6152
commit 831ef9bfca
3 changed files with 190 additions and 371 deletions

View File

@@ -1,3 +1,190 @@
# Training Home LLM Models
# Training (Docker + Axolotl)
This directory contains resources and instructions for training Home LLM models. Currently, it is recommended to use axolotl via a Docker container for training. There are various examples of model configurations provided in the `config/` folder. Additionally, you can refer to the [Axolotl documentation](https://docs.axolotl.ai/) for more detailed guidance on setting up and running training sessions.
This repo recommends training with **Axolotl** inside a Docker container. The `train/` folder includes:
- Example Axolotl configs in `train/configs/`.
- Chat templates in `train/chat_templates/`.
- A Kubernetes job spec in `train/training-job.yml` that references the Axolotl image.
The instructions below are written to match the paths used by the configs in this repo.
## Hardware recommendations
- **Recommended minimum VRAM:** **24GB total** across all GPUs.
- Multi-GPU is fine: e.g. **2×12GB**, **1×24GB**, **2×16GB**, etc.
- More VRAM lets you increase `sequence_len`, batch size, and/or train larger base models.
## Fine-tuning approaches (full vs LoRA vs QLoRA)
All three approaches start from a **base model** (e.g. `google/gemma-3-270m-it`) and train on this repos synthetic dataset.
### Full fine-tuning
Full fine-tuning updates **all** model weights.
- **Pros:** highest quality ceiling.
- **Cons:** highest VRAM/compute; largest checkpoints.
- **When to use:** you have ample VRAM/compute and want maximum adaptation.
### LoRA (Low-Rank Adaptation)
LoRA keeps base weights frozen and trains small **adapter** matrices.
- **Pros:** much lower VRAM; fast iteration; adapters are small and easy to share.
- **Cons:** slightly lower ceiling than full fine-tuning.
- **When to use:** common default when youre VRAM-constrained.
### QLoRA
QLoRA is LoRA, but the frozen base model is loaded in **4-bit quantized** form.
- **Pros:** lowest VRAM footprint; enables training larger models on modest GPUs.
- **Cons:** can be slower (dequant overhead) and sometimes trickier to tune.
- **When to use:** you want LoRA but need the base model to fit in memory.
## Training Scripts
We now recommend that you use the [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) training script suite to perform training runs.
The Docker images that are recommended are:
- `axolotlai/axolotl-cloud:main-py3.11-cu128-2.8.0` - CUDA 12.8 with PyTorch 2.8.0
- `axolotlai/axolotl-cloud:main-py3.11-cu126-2.8.0` - CUDA 12.6 with PyTorch 2.8.0
Both images provide the `axolotl` CLI used in the examples below. It is recommended to use the "cloud" versions since they have various custom folder mounts already set up that make data management easier.
## Dataset Generation
The synthetic dataset is generated by scripts under `data/`.
- **Generator:** `data/generate_data.py`
- **Outputs:** JSONL files in `data/output/` (for example `home_assistant_train.jsonl`, `home_assistant_test.jsonl`, and `sample.jsonl`).
The example training configs in `train/configs/` are written to read a dataset file mounted at `/workspace/data/datasets/sample.jsonl`. Depending on the variant of the dataset that you are using for training, you may need to edit the config to point to the correct dataset file.
For local Docker training youll typically:
1. Generate a dataset JSONL under `data/output/`.
2. Copy it into a host folder that you mount as `/workspace/data/datasets/`.
## Training Configs
This repo currently includes:
- `train/configs/gemma3-270m.yml`
- `train/configs/functiongemma-270m.yml`
Additional configs can be found in the [Axolotl repo](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples). They will need to be adapted to use this repos dataset paths and potentially have their chat template adjusted.
### Model selection
- `base_model`: Hugging Face model id.
- `model_type`: Transformers model class (example: `Gemma3ForCausalLM`).
### Chat formatting and tool calling
- `chat_template: jinja` and `chat_template_jinja: | ...`:
- Defines how multi-turn chats (and tools) are rendered into training text.
- For tool calling, this formatting matters a lot.
Related files:
- `train/chat_templates/` contains reusable templates you can adapt.
### Dataset configs
- `datasets:` list points at a JSONL file, example:
- `path: /workspace/data/datasets/sample.jsonl`
- `ds_type: json`
- `type: chat_template`
- `message_field_training: train_on_turn`
- `roles_to_train: []`
> NOTE: `roles_to_train` controls which message roles are used for loss calculation. Alternatively, you can simply enable training against all "assistant" role messages by setting `roles_to_train: [assistant]`.
### Sequence length + packing
- `sequence_len`: max context length.
- `sample_packing`:
- `true` packs multiple short samples into one sequence (better throughput).
- `false` keeps samples separate (often easier to debug).
### Batch sizing
Effective batch size is calculated as:
$$\text{effective\_batch} = \text{micro\_batch\_size} \times \text{gradient\_accumulation\_steps} \times \text{num\_gpus}$$
Having the correct effective batch size is important for training stability. Modify `micro_batch_size` until you can fit the model and optimizer states in VRAM, and then set `gradient_accumulation_steps` to keep the effective batch size in the correct range -- 16, 32, or 64 are values shown to work well for this dataset.
### Memory/perf helpers
- `bf16: true`: bfloat16 training on modern NVIDIA GPUs (30-series/Ada or newer)
- `gradient_checkpointing: true`: re-compute gradient activations during backward pass instead of storing in VRAM (lower VRAM, more compute.)
- `flash_attention: true`: faster attention when supported (almost always).
- `optimizer: adamw_bnb_8bit`: 8-bit optimizer states to save VRAM
### Outputs
- `output_dir: /workspace/data/training-runs/<run-name>` stores checkpoints and TensorBoard logs.
## Running training with Docker (local machine)
These commands assume:
- You are running on Linux with an NVIDIA GPU (or WSL2)
- You have Docker installed
- You have the NVIDIA Driver and the NVIDIA Container Toolkit set up and installed
### 1) Create host folders to mount
```bash
mkdir -p ./train-local/datasets
mkdir -p ./train-local/training-runs
mkdir -p ./train-local/huggingface-cache
```
### 2) Generate a dataset JSONL
```bash
python3 data/generate_data.py --sample --language english
cp data/output/sample.jsonl ./train-local/datasets/sample.jsonl
```
### 3) Run preprocess + train
```bash
docker pull axolotlai/axolotl-cloud:main-py3.11-cu128-2.8.0
```
```bash
docker run --rm -it \
--gpus all \
-e AXOLOTL_DO_NOT_TRACK=1 \
-e HF_HOME=/workspace/data/huggingface-cache \
-e HF_TOKEN="$HF_TOKEN" \
-v "$PWD/train-local/datasets:/workspace/data/datasets" \
-v "$PWD/train-local/training-runs:/workspace/data/training-runs" \
-v "$PWD/train-local/huggingface-cache:/workspace/data/huggingface-cache" \
-v "$PWD/train/configs:/workspace/configs" \
axolotlai/axolotl-cloud:main-py3.11-cu128-2.8.0 \
axolotl preprocess /workspace/configs/gemma3-270m.yml --debug
```
```bash
docker run --rm -it \
--gpus all \
-e AXOLOTL_DO_NOT_TRACK=1 \
-e HF_HOME=/workspace/data/huggingface-cache \
-e HF_TOKEN="$HF_TOKEN" \
-v "$PWD/train-local/datasets:/workspace/data/datasets" \
-v "$PWD/train-local/training-runs:/workspace/data/training-runs" \
-v "$PWD/train-local/huggingface-cache:/workspace/data/huggingface-cache" \
-v "$PWD/train/configs:/workspace/configs" \
axolotlai/axolotl-cloud:main-py3.11-cu128-2.8.0 \
axolotl train /workspace/configs/gemma3-270m.yml
```
Artifacts will appear under `./train-local/training-runs/`.
## Running on Kubernetes (e.g. cloud GPU host)
`train/training-job.yml` mounts:
- `/workspace/data/datasets` (dataset JSONL)
- `/workspace/data/training-runs` (outputs)
- `/workspace/configs` (Axolotl YAML)
- `/workspace/data/huggingface-cache` (HF cache)
It runs `axolotl preprocess ...` as an init container, then `axolotl train ...`.
The helper script `train/train.sh` copies `train/configs/<MODEL_NAME>.yml` to a remote server and starts the Kubernetes Job.