mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
add a test for 1B llm (#11124)
* add a test for 1B llm * fix mbs * add apps to release
This commit is contained in:
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@@ -547,6 +547,20 @@ jobs:
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testllm:
|
||||
name: Test LLM
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: apps_llm
|
||||
- name: Test 1B LLM
|
||||
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm | grep -i rooster
|
||||
|
||||
testmodels:
|
||||
name: Models (llvm+cpu+gpu)
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
2
setup.py
2
setup.py
@@ -27,7 +27,7 @@ setup(name='tinygrad',
|
||||
packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.runtime.autogen.am', 'tinygrad.codegen', 'tinygrad.nn',
|
||||
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.kernelize',
|
||||
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.opt',
|
||||
'tinygrad.runtime.support.nv'],
|
||||
'tinygrad.runtime.support.nv', 'tinygrad.apps'],
|
||||
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'assets/**/*', 'js/*']},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
||||
@@ -172,7 +172,10 @@ if __name__ == "__main__":
|
||||
ids: list[int] = [bos_id]
|
||||
while 1:
|
||||
start_pos = len(ids) - 1
|
||||
ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
|
||||
try:
|
||||
ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
|
||||
except EOFError:
|
||||
break
|
||||
for next_id in model.generate(ids, start_pos):
|
||||
sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
Reference in New Issue
Block a user