Commit Graph

4204 Commits

Author SHA1 Message Date
George Hotz
0e0be99b55 Merge branch 'master' into simpler_postrange 2025-08-28 07:22:39 -07:00
geohotstan
4e8370309c Support onnx If OP (#11648)
* start

* tiny clean up

* whoops, didn't mean to accidentally fix this

* fix .to(device), kinda hacky and this fix makes it slower?

* merge properly

* FINALLY figured out slowness, also hack pylint for now

* add DEBUGONNX print for subgraph

* oops

* WOOOOOOOO SHAPE CACHE 50% SPEED INCREASE

* small fix, but maybe all deterministic Tensor creation in fp should be cached

* cache condition

* sliiiightly cleaner

* better abstraction?

* remove sam from model_benchmark

* remove shape cache speed up for now

* less lines

* isinstance fix

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-08-28 10:17:35 -04:00
George Hotz
6d6f0dada7 support for tuple ranges (#11890)
* support for tuple ranges

* breaks it
2025-08-28 07:02:31 -07:00
chenyu
beb5982165 FUSE_ATTENTION (#11884) 2025-08-27 19:59:17 -04:00
nimlgen
44816218b5 memplan: fix large buffers planning (#11878)
* memplan: fix large buffers planning

* fix

* fix dsp
2025-08-27 23:54:27 +03:00
George Hotz
e9575c81e2 delete 2025-08-27 12:49:58 -07:00
George Hotz
ea1b853a60 delete 2025-08-27 12:49:58 -07:00
nimlgen
4006366752 Revert "memplan: fix large buffers planning (#11876)" (#11877)
This reverts commit 7f90497efc.
2025-08-27 22:36:14 +03:00
nimlgen
7f90497efc memplan: fix large buffers planning (#11876)
* memplan: fix large buffers planning

* fix
2025-08-27 22:04:15 +03:00
George Hotz
73f83e6fe6 Merge branch 'master' into simpler_postrange 2025-08-27 11:43:12 -07:00
Jordan Chalupka
e9789d8a70 Add mxfp4 support (#11873)
* bump ggml url

* map mxfp4 to tensor

* tests
2025-08-27 10:56:56 -07:00
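
The MXFP4 referenced above is the OCP microscaling format: blocks of 32 FP4 (E2M1) values sharing a single E8M0 scale byte. A minimal NumPy decode sketch, assuming ggml-style packing of two nibbles per byte (low nibble first) — an illustration of the format, not the code from this PR:

```python
import numpy as np

# E2M1 lookup: sign bit + 2 exponent bits + 1 mantissa bit -> 16 values
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                 -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=np.float32)

def decode_mxfp4_block(scale_byte: int, packed: bytes) -> np.ndarray:
  """Decode one 32-element block: 1 E8M0 scale byte + 16 packed FP4 bytes."""
  nibbles = np.frombuffer(packed, dtype=np.uint8)
  idx = np.stack([nibbles & 0xF, nibbles >> 4], axis=-1).reshape(-1)
  scale = np.float32(2.0) ** (int(scale_byte) - 127)  # E8M0: biased power of two (0xFF would be NaN, ignored here)
  return E2M1[idx] * scale
```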
George Hotz
99c8c37511 working double tc 2025-08-26 22:32:26 -07:00
George Hotz
195feb1b10 flash attention tc 2025-08-26 18:44:20 -07:00
Sieds Lykles
d39365809a add ctx to z3_renderer arg (#11867)
* add ctx to z3_renderer arg

* update symbolic fuzzer

* rewrite u1,u2,u3

* update fuzz_fast_idiv

* remove imports
2025-08-27 03:38:15 +02:00
George Hotz
68d7218f80 double gemm is failing 2025-08-26 17:27:47 -07:00
George Hotz
78e092d59d reorder 2025-08-26 17:10:06 -07:00
George Hotz
c94adb3594 Merge branch 'master' into simpler_postrange 2025-08-26 13:41:24 -07:00
chenyu
7028cb4167 clean up TestBitcastConstFolding (#11856) 2025-08-26 15:26:47 -04:00
George Hotz
f0f7437385 cleanups 2025-08-26 12:02:14 -07:00
George Hotz
b268755d51 small changes from postopt (#11854) 2025-08-26 11:56:16 -07:00
Sieds Lykles
a3aeef45cc associative variation of where branch-merging (#11851)
* add rule and test

* change comment
2025-08-26 19:27:05 +02:00
b1tg
1dd613cb89 test float_to_bf16 round-to-even behavior (#11849)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
2025-08-26 12:16:10 -04:00
b1tg
409399c609 fix nan in float_to_bf16 (#11843)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
2025-08-26 11:42:25 -04:00
chenyu
f28f613f85 improved float_to_bf16 (#11848)
round instead of truncate
2025-08-26 11:14:06 -04:00
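
The standard trick behind "round instead of truncate" for the three float_to_bf16 commits above is bit-level round-to-nearest-even: add 0x7FFF plus the keep-bit's LSB before shifting, with a NaN guard so rounding can't turn a NaN payload into inf. A minimal sketch of that trick (an illustration, not necessarily the exact code in these PRs):

```python
import struct

def float_to_bf16_bits(f: float) -> int:
  bits = struct.unpack('<I', struct.pack('<f', f))[0]
  if (bits & 0x7fffffff) > 0x7f800000:     # NaN: quiet it, don't round
    return ((bits >> 16) | 0x40) & 0xffff
  lsb = (bits >> 16) & 1                   # ties round to the even result
  return (bits + 0x7fff + lsb) >> 16

assert float_to_bf16_bits(1.0) == 0x3f80
```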
chenyu
337e979a59 call dtypes.as_const in Tensor(list) (#11840) 2025-08-25 22:08:26 -04:00
chenyu
ac3449b0c8 truncate_fp16 cleanup (#11838)
native `@` is default
2025-08-25 19:03:41 -04:00
qazal
a1f6823060 viz: memory layout in client side (#11830)
* viz: memory layout in client side

* update test_viz
2025-08-25 14:49:33 +03:00
Sieds Lykles
a286a1a6f7 Fast idiv try removing factors of two before cast (#11824)
* try removing factors of two

* dont return if None

* add test
2025-08-24 20:04:25 +02:00
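
The idea in this PR's title: before lowering `x // d` to a multiply-and-shift in a narrower dtype, factor shared powers of two out of the divisor, since `x // (odd << k) == (x >> k) // odd` for non-negative x, and the smaller operands are more likely to fit the cast. A sketch of just that identity, with a hypothetical helper name:

```python
def strip_twos(d: int) -> tuple[int, int]:
  """Split d into (k, odd) with d == odd << k, so x // d == (x >> k) // odd."""
  k = 0
  while d % 2 == 0:
    d //= 2
    k += 1
  return k, d

# x // 24 == (x >> 3) // 3 for any non-negative integer x
k, odd = strip_twos(24)
assert all((x // 24) == ((x >> k) // odd) for x in range(1000))
```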
George Hotz
6540bb32a6 move into codegen late [pr] (#11823) 2025-08-24 10:23:25 -07:00
Sieds Lykles
dd69114573 Revert "Better div nesting (#11811)" (#11818)
This reverts commit 952f729b07.
2025-08-24 18:11:24 +02:00
Sieds Lykles
952f729b07 Better div nesting (#11811)
* remove check

* use fold_divmod_congruence instead of simplify

* adjust tests

* shorten line
2025-08-24 04:17:40 +02:00
Sieds Lykles
e652062f92 tweak divmod_folding condition (#11810) 2025-08-24 02:59:02 +02:00
Sieds Lykles
07d4ed7e4c one more symbolic add variation (#11807) 2025-08-24 01:15:04 +02:00
qazal
0d86288bd7 viz: calculate timeline fixed points in client side (#11805)
* viz: calculate timeline fixed points in client side

* 26 bytes / event

* math
2025-08-24 01:44:40 +03:00
George Hotz
a75da49951 use AxisType for UPCAST/UNROLL (#11800)
* use AxisType for UPCAST/UNROLL

* fixes

* fix the bug

* fix hack

* bad test

* flaky test
2025-08-23 14:44:48 -07:00
qazal
2407fecdae viz bytepack format (#11792)
* viz bytepack format

Training a 1B llama yields ~20M profiler events.

With JSON serialization, the browser tries to load 6GB into memory. This OOMs, since each tab is limited to roughly 3-4GB of memory. Using a packed format, we only need ~600MB.

**Design decisions:**

- Timestamps are in microseconds relative to the start time. They're stored as u32, which can express up to ~1 hr of trace events.
- Strings (kernel names, metadata, etc.) are deduped.
- Buffer sizes are stored as u64 nbytes.

More optimizations are possible:

- The string lookup table is a JSON-dumped array; we can compress this.
- We can store less for memory events by moving the layout computation to the client.

**Results**

| Workload | Events | JSON | bytepack |
|----------------|---------|-------------|-------------|
| DP=8 llama 1B train (`command: [1]`) | 24M | 5.8GB | 640MB |
| examples/beautiful_mnist.py | 16K | 3.7MB | 745KB |
| examples/gpt2.py | 55K | 12.54MB | 1.40MB |

`[1]`: `VIZ=1 FAKEDATA=1 OFFLOAD_OPTIM=1 DP=8 BS=8 GRADIENT_ACC_STEPS=2 BLOCK_REORDER=0 LR=3e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=8192 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py`

* python reference decoder

* 27 bytes / event, 1hr hard limit
2025-08-23 23:50:21 +03:00
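
A minimal sketch of the layout described in the commit above, assuming one u32 relative-microsecond timestamp, a u32 index into a deduped string table, and a u64 nbytes field per event — illustrative field choices, not the exact wire format (the commit lands at 27 bytes/event):

```python
import struct

strings: list[str] = []
interned: dict[str, int] = {}

def intern(s: str) -> int:
  # dedupe strings: store each name once, reference it by index
  if s not in interned:
    interned[s] = len(strings)
    strings.append(s)
  return interned[s]

def pack_event(ts_us: int, name: str, nbytes: int) -> bytes:
  # u32 microseconds since trace start (~1 hr max), u32 string index, u64 size
  return struct.pack('<IIQ', ts_us, intern(name), nbytes)

evt = pack_event(1500, "r_256_16n2", 4096)
assert len(evt) == 16 and strings[0] == "r_256_16n2"
```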
qazal
b12d1d866c count bytes per kernel in test_viz (#11801)
Currently at ~100 bytes/kernel with JSON.
2025-08-23 23:35:27 +03:00
Sieds Lykles
6a50ab6b87 adjust idiv min_max (#11802)
* change div min_max

* add tests
2025-08-23 22:25:51 +02:00
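
For context on what a div min_max is: with floor division and a positive constant divisor, the bounds of `x // d` follow directly from the bounds of x, because division by a positive constant is monotonic. A small illustration under those assumed semantics, not the rule this PR actually changed:

```python
def idiv_min_max(xmin: int, xmax: int, d: int) -> tuple[int, int]:
  # floor division by a positive constant is monotonic,
  # so the bounds of x // d are the divided bounds of x
  assert d > 0 and xmin <= xmax
  return xmin // d, xmax // d

assert idiv_min_max(-7, 5, 3) == (-3, 1)
```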
chenyu
9d4cccd0f9 test_dtype_alu cleanups (#11799) 2025-08-23 15:11:17 -04:00
George Hotz
aefabaf774 add AxisType to range (#11798)
* add AxisType to range

* missed them

* fix that test

* fix that test
2025-08-23 11:15:00 -07:00
qazal
b975830424 add profile loader helper in test_viz (#11797) 2025-08-23 19:20:29 +03:00
chenyu
7123df3928 Use Tensor.logaddexp to implement Tensor.softplus (#11796)
instead of a piecewise linear formulation, numerical stability is handled by logaddexp. jax does this and i think it's more elegant than torch's approach
2025-08-23 11:52:29 -04:00
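
The identity behind this change: softplus(x) = log(1 + e^x) = logaddexp(x, 0), and with the usual beta parameter, softplus(x, beta) = logaddexp(beta*x, 0) / beta. A minimal NumPy sketch of that formulation (an illustration of the identity, not tinygrad's exact method body):

```python
import numpy as np

def softplus(x: np.ndarray, beta: float = 1.0) -> np.ndarray:
  # log(1 + exp(beta*x)) / beta, with overflow handled inside logaddexp
  return np.logaddexp(beta * x, 0.0) / beta

x = np.array([-100.0, 0.0, 100.0])
print(softplus(x))  # ~[0.0, log(2), 100.0], no overflow at x=100
```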
chenyu
fb8ee02424 Tensor.logaddexp (#11793) 2025-08-23 09:15:00 -04:00
Sieds Lykles
5a6817d5f8 Fix z3 rendering of floats in indexing (#11740)
* Fix floating point comparison in indexing

* wrap in noop

* update tests

* improve rules for loading and comparing floats

* add test cast to bool
2025-08-23 05:56:19 +02:00
chenyu
e39b25cd36 upcast float exp to at least float32 (#11758)
* upcast float exp to at least float32

* unlucky seed
2025-08-22 20:16:34 -04:00
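
The rationale: exp in half precision overflows already at x ≈ 11.09 (e^x > 65504) and loses precision well before that, so the input is upcast to at least float32, exponentiated, then cast back. A hedged sketch of that pattern in tinygrad terms — the real change lives inside the exp implementation, this is only the shape of the idea:

```python
from tinygrad import Tensor, dtypes

def exp_upcast(t: Tensor) -> Tensor:
  # compute exp in at least float32, then cast back to the input dtype
  if t.dtype in (dtypes.half, dtypes.bfloat16):
    return t.cast(dtypes.float32).exp().cast(t.dtype)
  return t.exp()

print(exp_upcast(Tensor([1.0, 12.0], dtype=dtypes.half)).numpy())
```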
qazal
9ff03680ba viz: store relative timestamps (#11787)
* viz: store relative timestamps

* err

* update test
2025-08-22 19:30:21 +03:00
geohotstan
1e679bd789 fix max_unpool2d inf (#11784)
* start

* add regression test for maxunpool2d
2025-08-22 08:31:24 -04:00
George Hotz
9832599c9e test_vmap + permute isn't a sint (#11783)
* test_vmap + permute isn't a sint

* order
2025-08-21 22:39:35 -07:00
George Hotz
bb8de51e5f remove unused early cleanups + contig w range [pr] (#11780)
* remove unused early cleanups [pr]

* contiguous with range

* woah, this works
2025-08-21 20:04:45 -07:00
chenyu
91a4de4ca7 fix getitem with inf in tensor (#11781) 2025-08-21 21:55:32 -04:00