Commit Graph

1174 Commits

Author SHA1 Message Date
George Hotz
9b1c3cd9ca hlb_cifar: support EVAL_STEPS=1000, print when dataset is shuffled 2023-10-18 01:11:08 +00:00
George Hotz
c36d306606 KOPT is over, BEAM is upstream (#2071)
* create cache for q learning

* make linter happy

* global beam

* where it belongs

* bugfix

* ditch the kopt, use the beam

* faster lin and DEBUG=2 okay

* remove kopt, move search to features
2023-10-16 09:46:03 -07:00
Ahmed Harmouche
0d3410d93f Stable diffusion: Make guidance modifiable (#2077) 2023-10-15 14:36:43 -07:00
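For reference, the guidance this commit makes modifiable is the standard classifier-free guidance blend used in stable diffusion samplers. A minimal generic sketch, not the repository's exact code (the function and variable names below are illustrative):

    import numpy as np

    def apply_guidance(uncond: np.ndarray, cond: np.ndarray, guidance_scale: float) -> np.ndarray:
        # Push the conditional prediction away from the unconditional one by guidance_scale.
        return uncond + guidance_scale * (cond - uncond)

    # toy usage with placeholder latents; 7.5 is a commonly used default scale
    print(apply_guidance(np.zeros((1, 4)), np.ones((1, 4)), 7.5))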
George Hotz
49bcfec383 0s in the action space (#2070)
* 0s in the action space

* simpler

* skip duplicate actions
2023-10-14 11:22:48 -07:00
mmmkkaaayy
91168a28c4 whisper: make file transcription work, add basic CI test (#2042) 2023-10-13 17:13:35 -07:00
George Hotz
6f1810af2d with unroll, the action space goes from 161 -> 127 (#2060)
* with unroll, the action space goes from 161 -> 127

* more reliable instrumentation

* beam search is so op

* beam bugfix
2023-10-12 20:52:23 -07:00
Yixiang Gao
3187962476 CIFAR HALF mode (#2041)
* load weights in fp16

* add dtype option in nn

* fix test

* no need for dtype in nn

* add option to load weights in FP16, but NaN

* change loss scaler

* cast to float32 for norm layer

* add a todo for the forward pass padding

* fix transform
2023-10-12 10:19:51 -07:00
George Hotz
c5edb3c374 train value net, improve API, add BCE (#2047)
* api cleanups, BCE losses

* valuenet

* fixup examples

* learning okay

* add valuenet runner

* net improvements

* net improvements

* 40% win rate
2023-10-12 07:56:38 -07:00
George Hotz
41bfeb2c1e start work on auto opt (#2034)
* start work on auto opt

* lin failure

* not beating hcopt

* greedy

* timing is fast

* codegen.search

* greedy search in handcode_opt

* track running gflops

* clean up those files

* no failure
2023-10-11 12:54:53 -07:00
chenyu
e2b83f1b42 Variable.bind newer (#2017)
* Variable.bind attempt 2

* ShapeTracker.unbind

* fix llama

* fix types

* test case

* View.vars cleanup

* include mask in symbolic source

* mask can be sint

* st.unbind in bufferops

* assert ast contain free Variable only

* cleanup

* conservative unbinding reduce op arg

* move reduceop unbind

* fix llama JIT arg behavior
2023-10-10 10:03:01 -07:00
Ahmed Harmouche
e27fedfc7b Fix stable diffusion output error on WebGPU (#2032)
* Fix stable diffusion on WebGPU

* Remove hack, numpy cast only on webgpu

* No-copy numpy cast
2023-10-10 06:40:51 -07:00
chenyu
25555c836f llama default to JIT only if device supports JIT (#2028) 2023-10-09 17:26:02 -07:00
George Hotz
16ca8410f8 op logger + replay (#2021)
* logops

* fix dtype printing

* needs inf

* ops dataset

* minor improvements

* 12k kernels

* opt can compile

* graph flops
2023-10-08 15:10:18 -07:00
mmmkkaaayy
af6e2f31ca whisper: cast model output token to int32 (#2013)
Co-authored-by: mmmkkaaayy <mmmkkaaayy@users.noreply.github.com>
2023-10-08 05:56:22 -07:00
George Hotz
44ed94ef5c use the device abstraction in handcode_resnet50_opt 2023-10-07 13:22:20 -07:00
George Hotz
121f7aa8c5 Schedule item (#2012)
* ScheduleItem

* put var_vals in the schedule

* fix tests, wow that proliferated quickly

* not ready to be in the schedule
2023-10-07 08:59:25 -07:00
George Hotz
f54959e5cd move print tree into graph (#2003)
* move print tree into graph

* add winograd profiling test

* change pre-commit to run ruff first
2023-10-07 04:39:21 -07:00
Ahmed Harmouche
2114dc13d1 Allow multi-input model export (#1995)
* Allow multi-input model export

* Add model export unit test

* Fix efficientnet compilation

* Only run model export test on JIT supported devices

* Skip export model test if not EXPORT_SUPPORTED_DEVICE
2023-10-07 04:13:34 -07:00
chenyu
05be57f57f Fix llama with empty prompt (#1997)
* fix llama with one token prompt

* llama is all_jitted
2023-10-06 06:48:07 -07:00
chenyu
da2b3e55f4 simpler llama - don't shrink twice (#1981) 2023-10-05 14:31:46 -07:00
chenyu
c99fa58dd2 simplify gpt2 example (#1973)
* simplify gpt2 example

* kernel_jitted_count and jit tests

* Revert "kernel_jitted_count and jit tests"

This reverts commit 31a3c26dd0.

* all_jitted test in test_real_world
2023-10-05 07:09:29 -07:00
nimlgen
2ea1dd3e87 no process() in Linearizer (#1966)
* no process() in Linearizer

* more process() clean up
2023-10-04 07:18:42 -07:00
Daniel Riege
579cabf668 Fix examples/train_efficientnet (#1947)
* added missing colon

* bug fixes for cifar10 dataset loading
needed a reshape to work with conv layers, and resolved the fetched tensor to numpy since further code expects a numpy array
2023-10-02 02:23:38 -07:00
George Hotz
90326dbdc3 resnet50 hand coded optimization (#1945)
* resnet50 hand coded opt

* hand optimize one kernel

* opt in both places to fix test
2023-09-29 09:34:51 -07:00
George Hotz
4ff35e2b97 better resnet eval (#1943) 2023-09-29 05:40:25 -07:00
George Hotz
48c8d130ae simpler GPT2 (#1941)
* don't realize in gpt2

* simpler gpt2
2023-09-29 04:41:09 -07:00
Yixiang Gao
094d3d71be with Tensor.train() (#1935)
* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
2023-09-28 18:02:31 -07:00
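The context manager added here is used roughly as follows; a minimal sketch assuming the tinygrad APIs of that era (the Linear/SGD setup is illustrative only):

    from tinygrad.tensor import Tensor
    from tinygrad.nn import Linear
    from tinygrad.nn.optim import SGD

    model = Linear(8, 2)
    opt = SGD([model.weight, model.bias], lr=0.01)

    with Tensor.train():  # enables training-mode behavior (e.g. dropout) inside the block
        loss = model(Tensor.randn(4, 8)).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()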
George Hotz
adab724caa schedule2, keep the tests working with small changes (#1932)
* lazy cleanups

* ast functions take in LazyOps

* op instead of self.op

* _base for mops

* fix contiguous

* start schedule

* test_schedule

* fix openpilot

* more tests

* bugfix and test skip

* work

* make sure things get freed

* fix zerosized tensors

* fix failing test

* fix ceil and friends

* fix openpilot

* disable training

* disable test collectives
2023-09-28 09:14:43 -07:00
Dat D. Nguyen
ae9529e678 chore: remove redundant noise in stable diffusion example (#1910) 2023-09-24 21:33:45 +08:00
Gijs Koning
b8ff20ffe4 Gpt2 (#1896)
* small helps

* got something working

* faster?

* faster yes

* cleanup

* cleanup

* cleanup

* Fix non jit

* Fix fp16 and some cleanup

* Fix fp16 and some cleanup

* cleanup

* similar to master

* cleanup
2023-09-22 20:14:47 +08:00
Yixiang Gao
cb5d6576cb cifar step time 65ms while stay above 94% (#1888)
* change reduceop heuristics

* add model ema and jit hack

* add ema eval

* have to create a duplicate eval function for jit

* remove manual seed

* 94% achievable with normal eval

* ema is outputting the same results as normal

* fix ema bug

* ema achieves 94% with fix seed

* multigpu tested

* constant fold decay, fix jit, adjust message for multigpu

* pull SpeedyResNet out of train_cifar()
2023-09-21 11:19:32 +08:00
nimlgen
4c31dfafb3 add seed to gpt-2 (#1869) 2023-09-15 17:34:14 -04:00
segf00lt
9e8c1dbf34 patch to remove hack from stable_diffusion.py (#1814)
* patch to remove hack from stable_diffusion.py

* sorry linter

* realize after assign?

* float16 broken in llvmlite use float64 for now

* int32

* idiot forgot to change test array dtype
2023-09-08 09:26:50 -07:00
chenyu
ebcda8a714 Move var_vals from ShapeTracker to LazyBuffer (#1819) 2023-09-08 09:25:10 -07:00
George Hotz
722823dee1 stable diffusion: force fp16 free 2023-09-06 15:11:05 -07:00
Yixiang Gao
22cf15e9d0 convert function into tinygrad (#1803) 2023-09-06 14:41:26 -07:00
Pavol Rusnak
52a92bf95d use class Foo: instead of class Foo(): (#1797)
* use class Foo: instead of class Foo():

* add ruff linter, copy settings from .flake8 to ruff.toml
2023-09-06 12:20:25 -07:00
badcc
fd25792c8b Ensure freqs as type float32 in freqs_cis (#1798) 2023-09-06 10:24:15 -07:00
George Hotz
f67638b27a delete broken DDPG example 2023-09-06 08:01:12 -07:00
Francis Lam
0379b64ac4 add seed option to stable_diffusion (#1784)
useful for testing correctness of model runs
2023-09-05 19:45:15 -07:00
George Hotz
fb1cc6bf4b llama jit is default, print tok/sec (#1774)
* llama jit is default, print tok/sec

* jit not default in CI
2023-09-05 10:12:16 -07:00
Yixiang Gao
66a6bbd029 codellama (#1702)
* add codellama with pre-downloaded weights

* add rope_theta, fix param

* fix test

* add 7B-Python

* add 7B-Instruct

* replace single quotes with double

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2023-09-02 08:45:12 -07:00
chenyu
a2745819f6 faster gpt2 jit path and gpt2 in test_real_world (#1738) 2023-09-02 08:39:12 -07:00
geohotstan
94b1257f5e Changed DEVICE to Device.DEFAULT in deep_determinist_policy_gradient (#1715)
* added device in optim and deep

* oops forgot to del print code

* use Device.DEFAULT instead

* removed device
2023-08-31 07:08:51 -07:00
nimlgen
b5cf274da3 remove memory peak for quantized llama (#1720) 2023-08-30 16:32:30 -04:00
chenyu
e4eb5d55c7 critical realize for unjitted llama (#1718) 2023-08-30 14:52:32 -04:00
George Hotz
cd7ceed914 gpt2: print total instead of sync time 2023-08-30 10:59:42 -07:00
Karan Handa
a8aa13dc91 [ready] Replacing os with pathlib (#1708)
* replace os.path with pathlib

* safe convert dirnames to pathlib

* replace all os.path.join

* fix cuda error

* change main chunk

* Reviewer fixes

* fix vgg

* Fixed everything

* Final fixes

* ensure consistency

* Change all parent.parent... to parents
2023-08-30 10:41:08 -07:00
chenyu
ac183568be llama JIT python runtime speedup (#1633)
* no JIT call in TransformerBlock

* idea

* move 2 reshapes to jitted function

shrink inside jitted too, 6.3ms

remove back reshapes, 5.5ms

isinstance -> __class__ 4.99ms

* think

revert ops_gpu.py

revert symbolic.py too

PYOPENCL_COMPILER_OUTPUT=1

* cleanup

* fix cache shape for conversational model

only reshape if start_pos > 0

* small cleanup

* include var_vals.keys() to st.key

* add comments

* llama small update

* everything jitted again, similar structure to gpt2

* fix typing

* add TODO for in place update cache
2023-08-30 07:51:05 -07:00
Umut Zengin
1682e9a38a Fix: Stable Diffusion index (#1713) 2023-08-30 00:21:10 -04:00