Commit Graph

1021 Commits

Author SHA1 Message Date
nimlgen
052722a7bc torch hook: address comments (#9295)
* torch hook: address comments

* failed test
2025-02-28 11:51:52 +03:00
Priyank Patel
8ae215dd3d torch backend fix manual seed warning (#9292) 2025-02-28 13:45:32 +08:00
George Hotz
ac40316692 hotfix: group cpu functions in torch backend 2025-02-28 10:39:00 +08:00
George Hotz
b32595dbbc torch examples (#9290)
* torch, fix examples/mnist

* fix vae torch example

* where out
2025-02-28 10:16:06 +08:00
hooved
3b9950241e tinychat in browser, Part 1: llama (#9273)
* load llama3-1B to WEBGPU device

* include compile script for loading llama3 to WEBGPU

* parametrize max_context in build_transformer fxn

* jit_model with two different args sets

* compile for webgpu, split weights

* load model weight parts in browser

* export all tensors from initialized transformer

* run transformer inference in browser

* enable tiktoken with llama bpe in browser

* count total tokens on client with tiktoken.js

* full client-side chat streaming, eliminate server

* revert change that enabled jitting with 2 argsets

* llama without Variable or cache_kv, for webgpu

* have client use mask tokens / whole context

* cleanup staged weights

* add tiktoken.js build script, README

* export CLANG for Q6_k to float32 decompression

* fix and test exported CLANG code for Q6_k to fp32

* revert changes to jit and export_model

* isolate clang export

* test Q6_K to float32 decompression in browser

* gguf_load now also returns t_infos and data_start

* prepare llama-1B Q6_K gguf chunks for browser

* cache and decompress quantized llama in browser

* enable separate deployment of large files

* fix kv cache and symbolic with llama wgpu

* eliminate browser lag during decompression

* hash metadata and weight chunks

* delete obsolete indexeddb cache to free disk

* add progress bar, track model download/decompress

* refactor progress callback

* skip buffer hash verification for speed

* Display progress for entire loading scope

* Report page load errors to user

* actually display errors

* skip prompt tokens already seen by model

* skip prefilling with last assistant message tokens

* on page load tell user if webgpu not enabled

* push deployed URL root to window.history

* make note of bug sources with TODO items

* isolate bug in CLANG with BEAM=2

* remove clang_bug.py from diff

* decompress q6k to f32 on webgpu instead of clang

* remove unused code

* inter-weight decomp with larger wgpu kernels

* parallelize decompression submissions

* refactor dequantize scheduling

* add progress bar back

* fix bug

* temp fix for loading GGUF Q6_K to fp16 not fp32

* fix rendering of exported CLANG

* remove weight casts, sketch js functions for clang

* get symbolic vars from jit_cache for model export

* include symbolic vars in exported CLANG

* render js for clang transformer

* toggle clang/webgpu deployment; refactor decomp

* compile and render clang Q6_K->fp16 and int8 quant

* fix rendered clang for abs(fp16), to work in wasm

* simplify clang js wrapping

* run compiled clang in worker

* prepare llama weights in workers, q6k to int8/fp16

* tinychat on clang in browser, f32/int8 weights

* move wasm inference to (now flexible) worker

* don't load redundant embeddings

* modest wasm perf gain with compile flags

* set default backend, enable backend choice/backup

* render symbolic vars in exported WEBGPU

* quantize webgpu llama to int8/f32

* improve UX arising from rendered WEBGPU

* clean up webgpu launch

* new weights split: smaller chunks, tinygrad quant.

* switch webgpu inference to int8 quant

* remove unneeded clang decompression

* eliminate unneeded kv cache transfer to wasm

* use 1 worker for simplified clang decompression

* display launch errors

* refactor: stream load weight chunks to WebGPU

* show loading chunk completion

* quantize embeddings to int8

* test float16 as input for quantization

* webgpu: use f16 source, int8 embed, eliminate q6k

* simplify split weights prep: all from state_dict

* revert change to nn.state.gguf_load

* remove unneeded decompression from webgpu client

* remove unneeded code

* decrease dl chunks from 47 to 16 MiB

* improve stability of webgpu loading on mobile

* autodetect mobile, improve load stability

* refactor: progress closure

* refactor: one unified progress bar

* remove unneeded code

* revert changes to tinygrad core library

* enforce ios18.3 nerfed max buf size

* BEAM=3 webgpu

* cache integrity, mobile save throttling

* improve mobile UX - no autozoom on prompt box

* clang: int8 from f16, remove q6k

* reduce concurrent dls on mobile to 2 for stability

* refactor: wasm backend with stream loading

* prevent race between wasm load and indexedb save

* split wasm kernels into separate modules

* js wrapper for multiple wasm module inference

* revert multi-module wasm to single module

* make mobile wasm load more stable/fast

* refactor: copy weights into wasm without crashes

* fix bug in download queue; increase mobile dls

* refactor exported clang wrapper, split weights

* remove unnecessary code

* greatly improve int8 quant quality with rounding

* eliminate mobile throttling

* increase webgpu context to 4096 tokens

* export webgpu js functions

* enable separate hosted weights for mobile/pc

* enable prompt-thread switching during generation

* stop generation when max_context is reached

* show progress bar for prefill

* tell user if webgpu fails, while wasm loads

* make loading messages more concise

* update font

* revert changes to tinychat python app launch

* cleanup quantization, add scale_dtype param

* cleanup kv cache code

* cleanup compile code

* link tok_embeddings with output in webgpu export

* refactor: export_model webgpu: symbolic vars

* refactor: export_model weight loading

* forgot to commit export_model.py

* change CLANG to CPU

* deal with pylint incorrectly failing tests

* simplify f-strings for older CI python version

* fix pre-python3.12 parser errors

* [Int32Array] not Int32Array

* cleanup webgpu compile after refactor export_model

* refactor WASM export into export_model

* merge WebGPU/WASM compile scripts

* simplify max_contexts for local deployment

* fix parser issues and whitespace

* deduplicate variable defs for non-wasm clang export

* cleanup code

* cleanup compile scripts

* simplify wasm inference wrapping

* simplify webgpu symbolic vars export

* refactor: unify export of symbolic variables

* simplify WASM export

* simplify clang/wasm export

* update README and build scripts

* separate files for browser/python apps

* restore original python tinychat app files

* browser and python tinychats share assets

* minor cleanup

* isolate diffs to llama files

* minor cleanup

* set default scale_dtype

* set default scale_dtype for NF4 quantization

* make quantization of tok_embeds optional

* match output with tok_embeds if not quantizing

* minor change
2025-02-27 15:57:37 -05:00
chenyu
184030168d fix aten.reflection_pad2d (#9289)
tested the torch doc example
2025-02-27 15:53:46 -05:00
chenyu
0de6585df0 fix aten.normal_ arg (#9288)
should be mean and std.
2025-02-27 15:36:25 -05:00
chenyu
8ee2b460ee Tensor.var_mean (#9287) 2025-02-27 15:15:31 -05:00
nimlgen
43e60914f3 init torch hooking (#9284)
* smth

* mv

* prof wk

* revert and move

* fix

* nvprof

* fix and no print much
2025-02-27 19:36:55 +03:00
George Hotz
387ea41e99 increase speed of torch mnist: use gradient api (#9282) 2025-02-27 11:57:41 +08:00
Priyank Patel
a0764f0dc0 (bounty) Make mnist training run with torch backend (#9233)
* yml changes

* torch backend remove meta decomps and add test

* torch backend bump timeout for tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-02-27 11:32:25 +08:00
George Hotz
9088125a6a a lil more torch (#9280) 2025-02-27 11:12:20 +08:00
George Hotz
b6a14911c8 start torch.compile support (#9279) 2025-02-27 10:29:51 +08:00
Francis Lata
86b737a120 leakyrelu to leaky_relu (#9270) 2025-02-26 13:22:08 -05:00
chenyu
aaf0a8069f xor -> bitwise_xor (#9264) 2025-02-26 10:21:14 -05:00
George Hotz
2158dc4849 full fix for as_strided in torch backend (#9257)
* fixes from chargpt for torch backend

* shrink support

* add stride support

* comment cleanup

* a few more

* work

* import the stream hack

* llvm multi auto
2025-02-26 22:34:05 +08:00
George Hotz
7780393460 rig up torch's testing framework [pr] (#9254)
* rig up torch's testing framework [pr]

* support more movement ops

* dec on expand

* fix tests

* work

* fix tests

* a few more

* decomps + opt hook

* installed pytest
2025-02-26 18:46:22 +08:00
George Hotz
b603af373e run some tests from torch [pr] (#9252)
* run some tests from torch [pr]

* yml

* wrap_out

* clean up for the new people

* a lil more
2025-02-26 15:42:22 +08:00
geohotstan
f0b24d230c add test_onnx_ops.py (#8569)
* boom

* fix webgpu

* use exact variable names in test so that AI can read easier

* add tag for specific test name like test a specific dtype

* fix ruff

* astype everything

* dtype in array creation

* just arange

* is 67% considered fixed?

* move test up

* small cleanups

* share function

* add qgemm as well

* add qgemm too

* make sure qgemm comes out as int

* take out qgemm for now

* fixed test

* add correct qgemm

* addressing feedback here too, early naive fix for now

* simplify bias and c to be minimalistic enough to test correctness

* refactored qlinearops

* maybe these asserts aren't the best..

* fix test

* updated tests to cover new ops

* try to add to CI

* move test_onnx_ops into testextra/

* more attention tests

* qlinear_add atol=1

* attention still not fullllllly correct

* it is what it is

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-02-24 16:15:22 -05:00
George Hotz
fc32ff80d6 torch and numpy dtype interop [pr] (#9224)
* torch and numpy dtype interop [pr]

* less lines

* order
2025-02-24 18:26:49 +08:00
George Hotz
fd731e740a hotfix: add note on backend2.py 2025-02-24 11:23:03 +08:00
albanD
f2dd9c1562 simplify c++ code (#9221) 2025-02-24 11:04:41 +08:00
George Hotz
97bc723538 torch backend works for ResNet-18 (#9200)
* torch backend progress, a few more functions

* resnet works

* pillow

* tv
2025-02-22 22:16:23 +08:00
George Hotz
4e6665bda5 different way to write torch backend (#9197)
* different way to write torch backend

* both backends

* more work

* simpler code

* more work

* test both

* imply unwrap/wrap

* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works

* ready to start making test_ops work in torch backend

* backward pass, TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works

* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_simple_conv2d works

* matmul backward is broken with as_strided
2025-02-22 14:42:26 +08:00
geohotstan
6587c7879b simple fixes to onnx (#9195)
* uncontroversial changes

* cleaner _prepare_quantize
2025-02-21 13:10:06 -05:00
George Hotz
e87be0131e torch backend start (#9191)
* start torch backend

* progress

* ugh, you need cpp crap

* 1+1 works

* 1+1 works

* becoming a real backend

* ready to merge?
2025-02-21 16:57:28 +08:00
chenyu
2e7c2780a9 CLANG -> CPU (#9189) 2025-02-20 18:03:09 -05:00
chenyu
1692087db5 _one_hot_along_dim input needs to be int (#9179)
* _one_hot_along_dim input needs to be int

indexing and onehot compare with arange, and non-int dtype is likely a bug
2025-02-20 09:00:43 -05:00
George Hotz
bf36967883 cuda hooking (#9180)
* cuda hooking

* progress

* more hook cuda

* fix params

* compile + cuMemHostAlloc hook

* work

* revert that
2025-02-20 19:20:01 +08:00
chenyu
975c318dbc bert use int32 for input ids (#9173)
original data was int32 for these. float might have caused precision issues
2025-02-19 08:17:27 -05:00
George Hotz
7eea9b639d hotfix: add replay_pkl debugging env 2025-02-17 17:34:58 +08:00
chenyu
1fda98d14f fix import time_linearizer [pr] (#9118)
only test that used it was skipped in CI due to being slow
2025-02-15 21:33:28 -05:00
Josh Moore
1f9d2442b9 Add Tensor.scatter_reduce (#8947)
* pytorch scatter -> scatter_reduce

* WIP scatter_reduce implementation

* _pre_scatter return type hint

* split out src, mask to satisfy linter

* Add src cast back in

* dict of lambdas instead of ifs

* sum and prod reduction ops with include_self

* add reduce arg error message

* add amax and amin reduction ops

* Fix include_self for higher dims

* Simplify

* Simplify amax and amin too

* Pull include_self logic out into _inv_mask function

* reduce arg cannot be None for scatter_reduce

* Fix self-mask issue

* Add mean reduce op

* Add tests

* any() not needed here

* remove comment

* End support for Tensor src with reduce arg in tinygrad scatter

* Process index, dim inside actual functions

* Add scatter_reduce to onnx

* Add excluded onnx ScatterElements reduction tests back in

* Save 2 lines on the mask helpers

* Update docs

* Add include_self=False tests

* cleanup

* Remove unneeded helper function

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-02-13 09:08:54 -05:00
George Hotz
74742c018f hotfix: setup_mock_nv_osx 2025-02-13 12:26:15 +08:00
chenyu
f4f56d7c15 move time_linearizer to extra.optimization.helpers [pr] (#9048)
no longer used in tinygrad
2025-02-12 15:49:58 -05:00
divinity76
bec4f59ce8 workaround f16 cast ambiguity (#8935)
for unknown reasons, without this, when trying to execute "Llama 3.2 1B", I get the error below. Fwiw I do not know the performance impact for this change. I can't even get exo running, but this change allows me to /get further/ (before running into a separate issue with vram allocation? story for another day i suppose)

error: 
```
Failed to fetch completions: Error processing prompt (see logs with DEBUG>=2): Nvrtc Error 6, NVRTC_ERROR_COMPILATION <null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                 ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                               ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                                              ^

4 errors detected in the compilation of "<null>".
```
2025-02-11 09:38:56 +08:00
nimlgen
dfc9d6827f am_smi: print power state (#9013) 2025-02-10 23:07:39 +03:00
nimlgen
f91409f038 am: fix proclogs (#9004) 2025-02-10 16:38:58 +03:00
nimlgen
c6c2373bc0 replace libpciaccess autogen with just pci regs (#8983)
* replace libpciaccess autogen with just pci regs

* add pci.py
2025-02-09 18:40:45 +03:00
George Hotz
6ffee2fca9 reduce speed example [pr] (#8978)
* reduce speed example

* fast like a nascar
2025-02-09 14:13:59 +08:00
George Hotz
a3c78d47b3 speed docs + upgrades [pr] (#8964)
* add some docs about speed [pr]

* better torch gemm

* enable locals on llvm/clang

* disable locals for beam speed on LLVM/CLANG

* 0x20 alignment in llvm allows ymm use
2025-02-08 17:28:52 +08:00
nimlgen
11d50324d8 am: tiny cleanups (#8958)
* am: start cleanups

* am
2025-02-07 23:44:43 +03:00
Ahmed Harmouche
133cacadde Autogen webgpu dawn, removing wgpu-py dependency (f16 support part 1) (#8646)
* Switch to dawn, all tests passing locally

* Use dawn-python

* Skip failing test

* Skip midcast and fix timestamp on metal ci

* Autogen webgpu

* Try fetch dawn lib again

* /usr/lib

* Without lib prefix

* Test autogen diff

* Delete webgpu support, move everything to ops_webgpu

* mypy fix

* Simplify, refactor

* Line savings

* No ResultContainer

* Type annotation for result

* Some more simplifications

* Why was this explicit sync used at all?

* Refactor: delete functions that are only used once

* Create shader module inline

* Clear unit tests cache, maybe that solves it

* That wasn't it

* Try deleting cache to pass failing weight compare

* weights_only=False for pytorch 2.6

* Simplify ctype array creation

* Remove nanosecond precision timestamps

* Simplify error handling

* Refactor, add back type annotations

* Deleted custom submit function, refactor

* read_buffer simplify

* Fix use after free, refactor

* Simplify supported_features

* Runtime docs

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-02-07 15:16:59 +08:00
nimlgen
ee1a0fb8ec am_smi: print device name (#8939) 2025-02-07 03:01:25 +03:00
chenyu
a092b6395d Tuple -> tuple, List -> list [pr] (#8936) 2025-02-06 14:21:19 -05:00
nimlgen
86feb98dcd am: add support for 7600 (#8910)
* am: start to add support for 7600

* test_tiny passes

* mmhub 3 0 2

* cleaner
2025-02-06 14:04:07 +03:00
geohotstan
057c70b05f add onnx_helpers to extra and add ort validate to benchmark_onnx (#8890)
* start

* log severity

* only change this

* change abstraction so it's more usable for huggingface

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-02-04 16:36:01 -05:00
George Hotz
56fa5c1191 dsp simulator (#8869)
* dsp simulator

* progress

* fix

* close on test tiny

* working

* less waste

* line savings

* Device DSP compiler

* mock DSP at the bottom

* DSP tests

* docker caching

* test update

* need load

* skip that test for CI DSP

* last touch

* ugh
2025-02-04 09:45:04 +08:00
geohotstan
d1aa9f30bc copy onnx_ops into onnx (#8876)
* just copy it over

* make OnnxOps a global var

* some small style stuff

* rerun CI but also some small clean up

* some comments
2025-02-03 12:15:07 -05:00
George Hotz
f484db0e63 dsp cleanups [pr] (#8866) 2025-02-03 15:18:53 +08:00