From 3d13c23bfaa09f8408dbd670d893f0ba864b146b Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 12 Jun 2024 05:59:59 +0000 Subject: [PATCH] llama3 `--download_model` (#4922) --- examples/llama3.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/llama3.py b/examples/llama3.py index 08d596bc64..3d13c171a1 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -7,7 +7,7 @@ from tqdm import tqdm from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters -from tinygrad.helpers import Profiling, Timing, DEBUG +from tinygrad.helpers import Profiling, Timing, DEBUG, fetch class Tokenizer: pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" @@ -146,7 +146,9 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N # load weights if model_path.is_dir(): - weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device) + if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json")) + elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors")) + else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device) else: weights = load(str(model_path)) if "model.embed_tokens.weight" in weights: @@ -191,6 +193,7 @@ if __name__ == "__main__": Tensor.no_grad = True parser = argparse.ArgumentParser() + parser.add_argument("--download_model", action="store_true") parser.add_argument("--model", type=Path, required=True) parser.add_argument("--size", choices=["8B", "70B"], default="8B") parser.add_argument("--shard", type=int, default=1) @@ -204,6 +207,16 @@ if __name__ == "__main__": parser.add_argument("--profile", action="store_true", help="Output profile data") args = parser.parse_args() + if args.download_model: + if not args.model.is_dir(): raise ValueError("for --download_model, --model must be a directory") + if not args.model.exists(): args.model.mkdir(parents=True) + fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", args.model / "tokenizer.model") + fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", args.model / "model-00001-of-00004.safetensors") + fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", args.model / "model-00002-of-00004.safetensors") + fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", args.model / "model-00003-of-00004.safetensors") + fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", args.model / "model-00004-of-00004.safetensors") + fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", args.model / "model.safetensors.index.json") + if args.seed is not None: Tensor.manual_seed(args.seed) print(f"seed = {Tensor._seed}")