George Hotz
0e0be99b55
Merge branch 'master' into simpler_postrange
2025-08-28 07:22:39 -07:00
geohotstan
4e8370309c
Support onnx If OP (#11648)
* start
* tiny clean up
* whoops, didn't mean to accidentally fix this
* fix .to(device), kinda hacky and this fix makes it slower?
* merge properly
* FINALLY figured out slowness, also hack pylint for now
* add DEBUGONNX print for subgraph
* oops
* WOOOOOOOO SHAPE CACHE 50% SPEED INCREASE
* small fix, but maybe all deterministic Tensor creation in fp should be cached
* cache condition
* sliiiightly cleaner
* better abstraction?
* remove sam from model_benchmark
* remove shape cache speed up for now
* less lines
* isinstance fix
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
2025-08-28 10:17:35 -04:00
George Hotz
6d6f0dada7
support for tuple ranges (#11890)
* support for tuple ranges
* breaks it
2025-08-28 07:02:31 -07:00
chenyu
beb5982165
FUSE_ATTENTION (#11884)
2025-08-27 19:59:17 -04:00
nimlgen
44816218b5
memplan: fix large buffers planning (#11878)
* memplan: fix large buffers planning
* fix
* fix dsp
2025-08-27 23:54:27 +03:00
George Hotz
e9575c81e2
delete
2025-08-27 12:49:58 -07:00
George Hotz
ea1b853a60
delete
2025-08-27 12:49:58 -07:00
nimlgen
4006366752
Revert "memplan: fix large buffers planning ( #11876 )" ( #11877 )
...
This reverts commit 7f90497efc .
2025-08-27 22:36:14 +03:00
nimlgen
7f90497efc
memplan: fix large buffers planning (#11876)
* memplan: fix large buffers planning
* fix
2025-08-27 22:04:15 +03:00
George Hotz
73f83e6fe6
Merge branch 'master' into simpler_postrange
2025-08-27 11:43:12 -07:00
Jordan Chalupka
e9789d8a70
Add mxfp4 support (#11873)
* bump ggml url
* map mxfp4 to tensor
* tests
2025-08-27 10:56:56 -07:00
George Hotz
99c8c37511
working double tc
2025-08-26 22:32:26 -07:00
George Hotz
195feb1b10
flash attention tc
2025-08-26 18:44:20 -07:00
Sieds Lykles
d39365809a
add ctx to z3_renderer arg (#11867)
* add ctx to z3_renderer arg
* update symbolic fuzzer
* rewrite u1,u2,u3
* update fuzz_fast_idiv
* remove imports
2025-08-27 03:38:15 +02:00
George Hotz
68d7218f80
double gemm is failing
2025-08-26 17:27:47 -07:00
George Hotz
78e092d59d
reorder
2025-08-26 17:10:06 -07:00
George Hotz
c94adb3594
Merge branch 'master' into simpler_postrange
2025-08-26 13:41:24 -07:00
chenyu
7028cb4167
clean up TestBitcastConstFolding (#11856)
2025-08-26 15:26:47 -04:00
George Hotz
f0f7437385
cleanups
2025-08-26 12:02:14 -07:00
George Hotz
b268755d51
small changes from postopt (#11854)
2025-08-26 11:56:16 -07:00
Sieds Lykles
a3aeef45cc
associative variation of where branch-merging (#11851)
* add rule and test
* change comment
2025-08-26 19:27:05 +02:00
b1tg
1dd613cb89
test float_to_bf16 round-to-even behavior (#11849)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
2025-08-26 12:16:10 -04:00
b1tg
409399c609
fix nan in float_to_bf16 (#11843)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
2025-08-26 11:42:25 -04:00
chenyu
f28f613f85
improved float_to_bf16 (#11848)
round instead of truncate
2025-08-26 11:14:06 -04:00
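The three float_to_bf16 commits above (the round-to-even test, the NaN fix, and rounding instead of truncating) all concern the same standard bit trick. A minimal sketch of that technique in plain Python follows; this is illustrative, not the tinygrad code itself. Plain truncation to the top 16 bits rounds toward zero and can even turn a NaN into an infinity, so the conversion guards NaN and adds a rounding bias before shifting:

```python
import struct

def float_to_bf16_bits(x: float) -> int:
  # reinterpret the float32 as its raw uint32 bit pattern
  u = struct.unpack("<I", struct.pack("<f", x))[0]
  if (u & 0x7fffffff) > 0x7f800000:
    # NaN: plain truncation could clear all mantissa bits and yield inf,
    # so force a quiet-NaN mantissa bit that survives the shift
    return ((u >> 16) | 0x40) & 0xffff
  # round to nearest even: bias by 0x7fff plus the bit that becomes the lsb
  u += 0x7fff + ((u >> 16) & 1)
  return (u >> 16) & 0xffff
```

Truncation, by contrast, is just `u >> 16`, which maps some NaN payloads (e.g. 0x7f800001) to 0x7f80, the bit pattern of +inf in bf16.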
chenyu
337e979a59
call dtypes.as_const in Tensor(list) (#11840)
2025-08-25 22:08:26 -04:00
chenyu
ac3449b0c8
truncate_fp16 cleanup (#11838)
native `@` is default
2025-08-25 19:03:41 -04:00
qazal
a1f6823060
viz: memory layout in client side (#11830)
* viz: memory layout in client side
* update test_viz
2025-08-25 14:49:33 +03:00
Sieds Lykles
a286a1a6f7
Fast idiv: try removing factors of two before cast (#11824)
* try removing factors of two
* dont return if None
* add test
2025-08-24 20:04:25 +02:00
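Assuming this commit implements the usual trailing-zero reduction (a sketch, not tinygrad's code, and the function name here is illustrative): for non-negative x, dividing by d = d_odd << k splits into a cheap shift plus a division by the smaller odd factor, which makes the downstream fast-idiv machinery easier.

```python
def idiv_strip_twos(x: int, d: int) -> int:
  # for x >= 0 and d > 0: x // (d_odd << k) == (x >> k) // d_odd,
  # so the expensive division is by a smaller, odd constant
  assert x >= 0 and d > 0
  k = (d & -d).bit_length() - 1  # number of trailing zero bits in d
  return (x >> k) // (d >> k)

# sanity check against plain floor division
for x in range(1000):
  for d in (2, 6, 12, 40, 7):
    assert idiv_strip_twos(x, d) == x // d
```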
George Hotz
6540bb32a6
move into codegen late [pr] (#11823)
2025-08-24 10:23:25 -07:00
Sieds Lykles
dd69114573
Revert "Better div nesting ( #11811 )" ( #11818 )
...
This reverts commit 952f729b07 .
2025-08-24 18:11:24 +02:00
Sieds Lykles
952f729b07
Better div nesting (#11811)
* remove check
* use fold_divmod_congruence instead of simplify
* adjust tests
* shorten line
2025-08-24 04:17:40 +02:00
Sieds Lykles
e652062f92
tweak divmod_folding condition (#11810)
2025-08-24 02:59:02 +02:00
Sieds Lykles
07d4ed7e4c
one more symbolic add variation (#11807)
2025-08-24 01:15:04 +02:00
qazal
0d86288bd7
viz: calculate timeline fixed points in client side (#11805)
* viz: calculate timeline fixed points in client side
* 26 bytes / event
* math
2025-08-24 01:44:40 +03:00
George Hotz
a75da49951
use AxisType for UPCAST/UNROLL (#11800)
* use AxisType for UPCAST/UNROLL
* fixes
* fix the bug
* fix hack
* bad test
* flaky test
2025-08-23 14:44:48 -07:00
qazal
2407fecdae
viz bytepack format (#11792)
* viz bytepack format
Training a 1B llama yields ~20M profiler events.
With JSON serialization, the browser tries to load 6GB into memory. This OOMs since each tab is limited to roughly 3-4GB of memory. Using a packed format, we only need ~600MB.
**Design decisions:**
- Timestamps are in microseconds relative to start time. They're stored in u32, which can express up to ~1 hr of trace events.
- Strings (kernel names, metadata, etc) are deduped.
- Buffer sizes are in u64 nbytes.
More optimization is possible:
- The string lookup is a JSON-dumped array; we can compress this.
- We can store less for memory events by moving the layout to the client.
**Results**
| Workload | Events | JSON | bytepack |
|----------------|---------|-------------|-------------|
| DP=8 llama 1B train (`command: [1]`) | 24M | 5.8GB | 640MB |
| examples/beautiful_mnist.py | 16K | 3.7MB | 745KB |
| examples/gpt2.py | 55K | 12.54MB | 1.40MB |
`[1]`: `VIZ=1 FAKEDATA=1 OFFLOAD_OPTIM=1 DP=8 BS=8 GRADIENT_ACC_STEPS=2 BLOCK_REORDER=0 LR=3e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=8192 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py`
* python reference decoder
* 27 bytes / event, 1hr hard limit
2025-08-23 23:50:21 +03:00
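A quick check of the u32-timestamp decision above, plus a toy packed-event record using Python's struct module. The field layout below is a hypothetical sketch for illustration; it is not the actual bytepack format (which the commit says reached 27 bytes per event):

```python
import struct

# u32 microseconds caps a trace at 2**32 us, hence the "~1 hr" limit
print(2**32 / 1e6 / 60)  # ~71.6 minutes

# hypothetical fixed-width event record: timestamp (u32, microseconds
# relative to start), deduped-string index for the name (u32), and
# buffer size (u64 nbytes) -- 16 bytes per event in this sketch
EVENT = struct.Struct("<IIQ")

def pack_event(ts_us: int, name_idx: int, nbytes: int) -> bytes:
  return EVENT.pack(ts_us, name_idx, nbytes)

ts_us, name_idx, nbytes = EVENT.unpack(pack_event(1_000_000, 3, 4096))
assert (ts_us, name_idx, nbytes) == (1_000_000, 3, 4096)
```

The same fixed-width idea scales linearly: 24M events at a few tens of bytes each lands in the hundreds of megabytes, versus gigabytes for JSON.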
qazal
b12d1d866c
count bytes per kernel in test_viz (#11801)
Currently at ~100 bytes/kernel with JSON.
2025-08-23 23:35:27 +03:00
Sieds Lykles
6a50ab6b87
adjust idiv min_max (#11802)
* change div min_max
* add tests
2025-08-23 22:25:51 +02:00
chenyu
9d4cccd0f9
test_dtype_alu cleanups (#11799)
2025-08-23 15:11:17 -04:00
George Hotz
aefabaf774
add AxisType to range (#11798)
* add AxisType to range
* missed them
* fix that test
* fix that test
2025-08-23 11:15:00 -07:00
qazal
b975830424
add profile loader helper in test_viz (#11797)
2025-08-23 19:20:29 +03:00
chenyu
7123df3928
Use Tensor.logaddexp to implement Tensor.softplus (#11796)
Instead of piecewise linear, numerical stability is handled by logaddexp (see the sketch after this entry). JAX does this and I think it's more elegant than torch's approach.
2025-08-23 11:52:29 -04:00
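The identity behind this change: softplus(x) = log(1 + e^x) = logaddexp(x, 0), where a stable logaddexp factors out the larger argument so the remaining exp sees a non-positive input and cannot overflow. A minimal NumPy-flavored sketch of both (illustrative, not tinygrad's implementation):

```python
import numpy as np

def logaddexp(a, b):
  # log(exp(a) + exp(b)) computed stably: factor out the larger argument
  m = np.maximum(a, b)
  return m + np.log1p(np.exp(-np.abs(a - b)))

def softplus(x):
  # softplus(x) = log(1 + exp(x)) = logaddexp(x, 0)
  return logaddexp(x, np.zeros_like(x))

x = np.array([-100.0, 0.0, 100.0])
print(softplus(x))  # ~[0, log(2), 100] with no overflow
```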
chenyu
fb8ee02424
Tensor.logaddexp (#11793)
2025-08-23 09:15:00 -04:00
Sieds Lykles
5a6817d5f8
Fix z3 rendering of floats in indexing (#11740)
* Fix floating point comparison in indexing
* wrap in noop
* update tests
* improve rules for loading and comparing floats
* add test cast to bool
2025-08-23 05:56:19 +02:00
chenyu
e39b25cd36
upcast float exp to at least float32 (#11758)
* upcast float exp to at least float32
* unlucky seed
2025-08-22 20:16:34 -04:00
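The pattern named in this subject, sketched in NumPy under the assumption that it follows the usual upcast-compute-downcast shape (not tinygrad's actual code): run the exponential in at least float32 even for half inputs, then cast back, since a native fp16 exp loses precision quickly and overflows to inf past x ~ log(65504) ~ 11.09.

```python
import numpy as np

def exp_upcast(t: np.ndarray) -> np.ndarray:
  # compute exp in at least float32, then cast back to the input dtype
  compute_dtype = np.promote_types(t.dtype, np.float32)
  return np.exp(t.astype(compute_dtype)).astype(t.dtype)

x = np.array([-2.0, 0.5, 10.0], dtype=np.float16)
y = exp_upcast(x)
assert y.dtype == np.float16
print(y)  # exp computed in float32, results rounded back to fp16
```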
qazal
9ff03680ba
viz: store relative timestamps (#11787)
* viz: store relative timestamps
* err
* update test
2025-08-22 19:30:21 +03:00
geohotstan
1e679bd789
fix max_unpool2d inf (#11784)
* start
* add regression test for maxunpool2d
2025-08-22 08:31:24 -04:00
George Hotz
9832599c9e
test_vmap + permute isn't a sint (#11783)
* test_vmap + permute isn't a sint
* order
2025-08-21 22:39:35 -07:00
George Hotz
bb8de51e5f
remove unused early cleanups + contig w range [pr] (#11780)
* remove unused early cleanups [pr]
* contiguous with range
* woah, this works
2025-08-21 20:04:45 -07:00
chenyu
91a4de4ca7
fix getitem with inf in tensor (#11781)
2025-08-21 21:55:32 -04:00