Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00.
Commit message: "hotfix: improve self_tokenize".
This commit contains the following changes:
@@ -2,7 +2,8 @@ import os, pathlib, argparse
|
||||
from examples.llama3 import Tokenizer
|
||||
from tabulate import tabulate
|
||||
from tinygrad import fetch
|
||||
from tinygrad.helpers import flatten
|
||||
from tinygrad.helpers import flatten, getenv
|
||||
from sz import NONCORE_DIRS
|
||||
|
||||
# llama 3 tokenizer
|
||||
tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix())
|
||||
@@ -10,19 +11,15 @@ tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/
|
||||
def read_code(base_path):
  """Collect the source text of every .py file under base_path/tinygrad.

  Returns a list of (name, code) tuples, where name is the file path
  relative to the first "tinygrad/" component (e.g. "tensor.py") and
  code is the file's full text.

  Skips: non-Python files, the machine-generated tinygrad/runtime/autogen
  bindings, and (unless the CORE env var is set) any directory listed in
  NONCORE_DIRS.
  """
  ret = []
  for path, _, files in os.walk(os.path.join(base_path, "tinygrad")):
    if not getenv("CORE"):
      # BUGFIX: the original `path.split("./")[1]` raises IndexError when the
      # walk root is absolute (no "./" in path); fall back to the raw path.
      rel = path.split("./")[1] if "./" in path else path
      if any(rel.startswith(x) for x in NONCORE_DIRS): continue
    for name in files:
      if not name.endswith(".py"): continue  # only Python sources are counted
      # autogen bindings are huge and machine-written; not worth tokenizing
      if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue
      fullpath = os.path.join(path, name)
      code = pathlib.Path(fullpath).read_text()
      # name is the path after the package root, e.g. "tensor.py"
      ret.append((fullpath.split("tinygrad/", 1)[1], code))
  return ret
|
||||
|
||||
def write_code_to_file(filename, code_list):
  """Writes the combined code to a specified file."""
  # Flatten the (name, code) pairs into one stream and join with newlines.
  combined = '\n'.join(flatten(code_list))
  with open(filename, 'w') as out:
    out.write(combined)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Analyze and optionally save tinygrad code.")
|
||||
parser.add_argument("--output", help="Output file to write the combined code to.")
|
||||
@@ -32,10 +29,11 @@ if __name__ == "__main__":
|
||||
|
||||
table = []
|
||||
for name,code in ret:
|
||||
table.append([name, len(tokenizer.encode(name+"\x00"+code))])
|
||||
table.append([name, len(tokenizer.encode(code))])
|
||||
print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow"))
|
||||
|
||||
code_str = '\x00'.join(flatten(ret))
|
||||
banner = "#"*40
|
||||
code_str = ''.join([f"{banner}\n# {name}\n{banner}\n\n{code}\n" for name,code in ret])
|
||||
print(f"code has {len(code_str)} chars")
|
||||
newline_count = code_str.count('\n')
|
||||
print(f"code has {newline_count} newlines")
|
||||
@@ -44,5 +42,5 @@ if __name__ == "__main__":
|
||||
print(f"code has {len(encoded)} tokens")
|
||||
|
||||
if args.output:
|
||||
write_code_to_file(args.output, ret)
|
||||
print(f"Combined code written to {args.output}")
|
||||
with open(args.output, 'w') as f: f.write(code_str)
|
||||
print(f"Combined code written to {args.output}")
|
||||
Reference in New Issue
Block a user