Commit Graph

3444 Commits

Author SHA1 Message Date
chenyu
b9d27636aa cleanup test_ops.py (#3192)
- removed exact duplicated tests
- only kept one function if torch_fxn is the same as tinygrad_fxn
- used tensor method instead of class method style
- replaced unneeded `lamdba f: f(x)` with just `f`
- re-enabled commented tests that work now
- removed some forward_only now 0 shape tensor can backward
2024-01-20 20:08:56 -05:00
chenyu
3f56d1a5e8 add operator.lt and operator.eq to test_dtype_alu (#3191)
* add operator.lt and operator.eq to test_dtype_alu

those should pass now as we have broadcasted before passing to lt and eq.
also updated the test skipping criteria to reuse test_dtype.is_dtype_supported

* llvm lt nan is incorrect

* enable truediv too

* Revert "enable truediv too"

This reverts commit df703235fb.

* just that
2024-01-20 14:54:02 -05:00
chenyu
c4b5661146 fuzz length for multitensor reduce test case (#3190)
so that the uneven case is not just with 0 length and can have other positve values
2024-01-20 00:44:38 -05:00
chenyu
fdb1c2b1d9 move reduce over 0 len axis logic to lazy.py (#3188)
* move reduce over 0 len axis logic to lazy.py

this fixed uneven shard reduce case if the uneven one has length 0

* fix interpreted backends

* fix backwards for 0 shape tensors too
2024-01-20 00:13:03 -05:00
chenyu
485332935e ring copy example (#3185)
* ring copy example

* use ones for init
2024-01-19 23:34:30 -05:00
George Hotz
254a7372fe buffer copy refactor (#3187) 2024-01-19 20:21:24 -08:00
chenyu
fb4bd2a57d reenable padto to search action (#3183) 2024-01-19 14:17:53 -05:00
chenyu
cb4cfc078a parameterize multitensor tests for reduce (#3181)
uneven shards reduce is incorrect now
2024-01-19 14:03:01 -05:00
nimlgen
5097d5b808 fix padto when with late reduce (#3180)
* fix padto test

* no long comment
2024-01-19 14:01:44 -05:00
George Hotz
729a01bf3e complex PRs will not be merged 2024-01-19 10:58:47 -08:00
nimlgen
f87ecbb0f3 fuzzer validates outputs + (partially) oob accesses (#3178)
* fuzzer validates outputs + (partially) oob accesses

* +random

* oob check only for compiled

* type cmp fixes

* fix zeroing

* no prints

* add seed
2024-01-19 13:34:51 -05:00
chenyu
b2571d586c hypothesis.st -> hypothesis.strat (#3179)
leave `st` for shapetracker
2024-01-19 11:55:26 -05:00
chenyu
c4faedebf3 add test cases for negative entry max allreduce (#3177) 2024-01-18 22:26:51 -05:00
chenyu
ab1b7c4d09 fix allreduce for max (#3175)
* test cases to show allreduce for max is incorrect

* oh fixed
2024-01-18 20:25:35 -05:00
George Hotz
c51c90bcd4 more sync in transfer (#3174) 2024-01-18 17:17:03 -08:00
chenyu
28dcbf0e00 test case sharded batchnorm has different ast on devices (#3172) 2024-01-18 18:12:15 -05:00
chenyu
a60d50487d disable padto, seems to have bug in gpt2 (#3173) 2024-01-18 18:09:30 -05:00
George Hotz
c80884884e event driven hip (#3160)
* event driven hip

* simpler, src makes copy

* pass mypy
2024-01-18 14:35:18 -08:00
George Hotz
d2aab65958 remove unused expr node (#3170)
* remove unused expr node

* still works

* simple expr_idxs

* fixup typing
2024-01-18 14:18:43 -08:00
chenyu
097b1390ec touchup test_indexing (#3169) 2024-01-18 14:32:43 -05:00
George Hotz
a04e4d0442 inline clang renderer (#3168) 2024-01-18 11:17:34 -08:00
geohotstan
efbe4788d1 indexing: Final cleanup (#3156)
* init

* feat: add _to_const_val to getitem

* doc: changed docs

* docs: updated more docs

* merge: improved/fancy

* better error msg, minor cleanups

* feat: added index_put to test_indexing

* clean: test_indexing

* revert: gather changes lol

* refactor: use dict for tracking tensor indexing, also asserts for type

* oooooooooops

* ugh

* will revert this commit xD

* fix: removed asserts

* improvement: made in-line if statement clearer

* improved err message and improved slice_int tests

* fix: recover accidentally deleted line

* finishing touches

* reword some docs and del torch device tests in test_indexing

* del some redundant tests

* revert: gather asserts, do it in seperate pr

* fix some data_ptr stuff

* done

* done done
2024-01-18 14:08:03 -05:00
chenyu
e139ae550d smaller limit_dims_to_max (#3167)
same questionable logic, but less lines now
2024-01-18 13:02:20 -05:00
nimlgen
992067399e clean up exceptions in __del__ everywhere (#3165) 2024-01-18 08:34:09 -08:00
Max-We
0338903429 Update kits19.py (#3166) 2024-01-18 08:33:50 -08:00
George Hotz
67bc2ddfd8 JIT cleanups (#3164)
* move GraphException

* factor out apply_graph_to_jit

* that return was wrong
2024-01-17 23:39:57 -08:00
George Hotz
f0c178b7e9 move get_contraction to helpers (#3162)
* move get_contraction to helpers

* move simplify

* lines

* to_movement_ops is not generic
2024-01-17 19:13:11 -08:00
chenyu
e52a609240 make WINO a context var, and LATEWINO in hlb_cifar (#3161) 2024-01-17 20:21:26 -05:00
George Hotz
ee83505fcc fix test extra issue (#3159) 2024-01-17 11:58:08 -08:00
George Hotz
9cc2577a08 use hip events (#3157)
* use hip events

* cleanup
2024-01-17 10:39:57 -08:00
chenyu
1b508e0f71 fix fuzz_linearizer toCPU to as_buffer (#3158) 2024-01-17 13:18:46 -05:00
George Hotz
743b36f0ce hotfix: copy size is in bytes 2024-01-17 16:44:15 +00:00
George Hotz
2e6162b281 graph cleanup (#3155)
* simpler graph

* unused functions
2024-01-16 20:57:31 -08:00
George Hotz
a72b1b6d65 sharding for llama (#3151)
* shard llama

* sharding works

* simpler

* simpler

* consume option

* disable that test

* save a line

---------

Co-authored-by: George Hotz <george@tinygrad.org>
2024-01-16 19:28:00 -08:00
chenyu
14c010958b support for non-uniform sharding (#3154)
* support for non-uniform sharding

* bugfix and more tests

---------

Co-authored-by: George Hotz <geohot@gmail.com>
2024-01-16 20:33:32 -05:00
nimlgen
81ae4ea179 compile cache for several devices (#3148)
* compile cache for several devices

* ops_gpu uses hash to not care about sql

* hip rdna test with device

* linter happy

* no device passed where possible

* arch is optional to compile_{hip|cuda}
2024-01-16 11:45:26 -08:00
chenyu
589c16756f hlb_cifar multi gpu training (#3150)
* cifar train with multi gpu

* GPUS=1 is noop
2024-01-16 14:38:45 -05:00
George Hotz
cc0de99751 hotfix: multilazybuffer can have only one lazybuffer 2024-01-16 10:06:45 -08:00
George Hotz
228f30b96a multitensor jit (#3149)
* initial multitensor jit support and tests

* Added graphs to multitensor jit and updated tests

* update unbind api

* fix set device, add TinyJit to resnet

* update_stats includes device

---------

Co-authored-by: ramenguy99 <ramenguy99@gmail.com>
2024-01-16 09:09:15 -08:00
chenyu
b9d470577c gelu -> quick_gelu in hlb_cifar (#3147)
89 -> 86 seconds, same eval acc
2024-01-16 02:03:37 -05:00
chenyu
ec5a212b0a modernize hlb_cifar (#3146)
* modernize hlb_cifar

do more things in Tensor space instead of numpy, clean up dtypes and use more Tensor methods.

* eigens are float64
2024-01-16 01:35:11 -05:00
chenyu
2088937206 run full hlb_cifar training in tinybox ci (#3145)
* run full hlb_cifar training in tinybox ci

single gpu ~89 seconds

* time that
2024-01-15 23:59:20 -05:00
chenyu
22920a7e55 add LATEBEAM to hlb_cifar (#3142)
still too slow to search on tinybox though
2024-01-15 23:26:03 -05:00
chenyu
766bd0bbe8 make _deepwalk a generator and not passing nodes around (#3141) 2024-01-15 20:26:00 -05:00
George Hotz
120c8b1841 update llvm api + add cache key (#3140)
* update llvm api + add cache key

* use_xcode is a different function

* types
2024-01-15 17:25:32 -08:00
George Hotz
cec0a7bc37 use shard api to eval resnet fast (#3136)
* use shard api to eval resnet fast

* to supports shard

* test to in multitensor
2024-01-15 16:49:38 -08:00
George Hotz
ca0beeef38 Christopherm99 ptx (#3139)
* get basic ptx impl working

* test ops passing

* mypy

* dont hardcode target

* more walrus

* ptx in ci

* bool cast and f16 load/store

* weird numpy bug and f16 cast tolerance

* cast half to bool

* fix 1 byte load/store

* disable half for ptx

* fix args and enable xid

* fix non-ptr args

* allow bitcast

* mypy

* cleanups

* midcast use allclose

* add xor

* Revert "disable half for ptx"

This reverts commit 73391c05fd.

* enable float16

* mypy

* no more crashing in ci

* fix ci

* minor cleanups

* use new fn for ptx compiler

* no diskcache in ptx compile

* use rn instead of rz

* save some lines

* new DEFINE_GLOBAL syntax

* line length

* new llvm

* cmpeq

* minor fix

* cast in mulacc

* update test_recursive_add to check line count

* mypy

* remove llvmir.py

* fix bool const

* wip

* cleanups

* working

* llvm in separate pr

* cleanups

* more cleanups

* fix ci

* use in_features directly in nn.Linear.__init__ bound check (#3050)

* use in_features directly in nn.Linear.__init__ bound check

get rid of the unnecessary check of isinstance int

* that is always int

* long lines

* Device._buffers -> Device._devices (#3052)

backend devices used to be called buffers

* make Embedding device aware for multigpu (#3051)

* make Embedding device aware for multigpu

* split line instead of igore because that's cheating

* add test incomplete

* add test complete

* remove comment

* fix white space

* remove nn.Embedding

* remove unused reciprocal (#3053)

* remove unused reciprocal

* comment

* unit tests for Device.canonicalize (#3055)

* add multigpu test for RMSNorm (#3056)

* need all gather

* add two multigpu test scenarios for RMSNorm

* No extra vars call (#3054)

* remove unused reciprocal

* comment

* remove unneeded call to vars

* free speedup

* explicit lazybuffer caching (#3058)

* hotfix: remove useless slow assert from ShapeTracker

* Speed tweaks (#3059)

* base doesn't have to be a function

* no double fetch

* pop, don't check

* make the gc happy

* avoid hasattr

* cache canonicalize

* remove assert, faster base

* don't redefine that every time

* fix gpt2 attention with start_pos = 0 (#3061)

* fix gpt2 attention with start_pos size 1

test cases taken from ll_transformer branch

* fix interpreted

* Tensor.cat with 0 shape tensors (#3062)

* Tensor.cat with 0 shape tensors

supported both 0 in cat axis (for a subset of input), or 0 in non-cat axis (all needs to be 0)

* no shp

* test scaled dot product attention (#3063)

* add test

* add initial test for scaled dot product attention

* test pass for scaled dot product attention

* cached size (#3060)

* cached size

* simplify simplify

* 0 doesn't have base

* fix test

* cleaner cache

* hmm, metal is flaky on this...might be real(ish) but useless as test

* short circuit reshape/expand properly

* better reshape bypass

* hotfix: use is for enum compare

* hotfix: use is for enum compare, a few more

* speedtweaks3: apply shouldn't use the tensor constructor (#3065)

* speedtweaks3: apply shouldn't use the tensor constructor

* replace 0 size with CONST, not 0 in shape

* update gh actions (#3033)

* update checkout actions

* update upload artifact

* update setup python

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>

* unbind view or shapetracker also returns var_val (#3067)

* unbind view or shapetracker also returns var_val

4% faster for llama compile time

* one line less

* unbound_views

* hotfix: examples/transformer.py

* jit autorealizes output (#3069)

* early gate the graph (#3070)

* simpler idxs_to_idx (#3071)

* filter_strides -> canonicalize_strides (#3072)

* fix onehot and jit in examples/transformer (#3073)

trained to 0.999 in < 6 seconds on M1 Max consistently

* better test demonstration (#3077)

* a better test demonstration

* fix white space

* Tensor.expand resolves the new_shape before shortcut return (#3078)

similar to how reshape is done. also updated shrink shortcut criteria to read similar to pad

* minor cleanups of lazy.py (#3080)

* wmma: clean up device specific tensor core code (#3081)

* mem_estimate is always int, not symbolic (#3083)

* mem_estimate is always int, not symbolic

op_estimate can be symbolic, but mem_estimate is always int, thus we don't need to sym_infer it.
fixed some long lines too. update_stats is a very big function

* operator does not need underscores

* cat works (#3086)

* hotfix disable flaky mac runner wino cifar (#3087)

* remove the third merging state in view._merge_dims (#3085)

no logic depends on state == 0 or state == 2

* minor cleanup of View.reshape (#3088)

* minor cleanup of View.reshape

removed some redundant logic

* new_strides

* revert that

* use BEAM=2 instead of BEAM=4 in cuda ci gpt2 (#3089)

BEAM=2 is faster and less search time. investigating why BEAM2+BEAM4 is slower than BEAM2 alone

* use device from LinearizerOptions in kernel search (#3090)

* use device from LinearizerOptions in kernel search

removed all Device.DEFAULT in search.py

* pass device string for parallel pickle

* device for interpreted backends in LinearizerOptions

* update jit type annotation post lazy rewrite (#3091)

* add mutigpu support for llama attention (#3064)

* add llama attention test for multigpu

* test fails

* kv cache trying to shrink on sharded axis

* mask None works for scale dot product

* kv cache seems to be working but scale dot product breaks

* scaled dot product works, but the last linear layer failed

* running into the reshape case where it could be wrong for multigpu

* making sure it was the reshape

* adding contiguous doesn't solve

* need to shard more properly

* remove reshape test

* minor adjustment to scale dot product attention test

* weights are sharded wrong

* continue fix new weight sharding

* clean up

* fix attention when start_pos is 0

* remove print

* add TODOs for the best mutigpu interface

* bugfix do not reset shapetracker of 0 size lazybuffer (#3096)

it might be coming from an expand, and resetting results incorrect stride. caught by interpreted backend

* One hot in tensor.py (#3093)

* onehot in Tensor.py

* one_hot tests

* works for all shapes, not just 1

* pylint

* not a static method

* moved around, num_classes mandatory

* pylint

* pylint

* space & moving

* formatting

* moved tests

* fix broadcasted logic if there's 0 in shapes (#3097)

* fix broadcasted logic if there's 0 in shapes

should always expand into 0, not the other way around. fixed matmul with 0 in input shapes.
for forwards for now though, backward is more involved and would need to change 0 size shortcuts

* fix tests

* replace with tensor op (#3099)

* fix gpt2 with empty prompt (#3100)

logits would be empty so need to replace that with ones before sampling, also cannot reshape with -1 when there's 0 in other axes

* Revert "fix gpt2 with empty prompt" (#3101)

* fix gpt2 with empty prompt take 2 (#3102)

logits would be empty so need to replace that with ones before sampling, also cannot reshape with -1 when there's 0 in other axes

* wmma: enable METAL half tensor cores and clean up cstyle (#3095)

* wmma: enable METAL half tensor cores and clean up cstyle

* revert simple_matmul rand changes and break line in tensor

* added metal fp16->fp32 tensor core

* add half @ half to mac benchmark (#3103)

* flag to profile mixtral - 1.7 tok/s now (#3104)

* update NumNode.__hash__ to be hash(self.b) (#3105)

with this, `a:=NumNode(x) == b` implies `hash(a) == hash(b)`

* catch runtime error in search._time_program (#3106)

return inf if search encountered runtime errors.

* no exceptions in __del__ when module creation is failed in hip/cuda (#3107)

* failed test case due to cast resets shapetracker (#3109)

cast implicitly resets shapetracker and makes it contiguous (for disk tensor), which fails for Interpreted backend if inputs contain non-contiguous st.

* cleanup ops_disk type annotation and redundant str cast (#3110)

* minor cleanup of test_disk_tensor (#3112)

* add Tensor.var (#3114)

also updated MeanVarianceNormalization and made test_ops test tensors of var and std smaller

* move sample inside jit for beautiful_mnist (#3115)

also removed .realize() for jit functions since jit does it automatically now. a little more beautiful

* minor cleanups of onnx_ops (#3116)

* fix conversation: llama generates token not prob now (#3120)

* add device options for tests in multigpu (#3121)

* make DType a dataclass (#3111)

* remove np from DType

* convert to dataclass

* remove dunder hash, eq, ne overrides from ImageDType

* is dataclass required for PtrDType?

* fix GPU tests

* reduce lines

* revert changes to np

* minor cleanup

* hotfix: ptrdtype compare was broken

* move fromcpu out of lazy.py (#3122)

* move fromcpu out of lazy.py

* fix abstractions2

* remove numpy from device (#3123)

* remove numpy from device

* fix tests

* np item

* cleanups

* simplify with as_buffer

* no toCPU

* tinygradic

* cast to scalar

* remove numpy from ops_torch (#3124)

updated mnist test to cast label to int8 and avoid hacking cast issue of torch uint8

* Fix backward fn for `<` and `==` (#3037)

* fix no grad fn for < and ==

* remove 2 line breaks

* Remove deprecated autograd variable

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>

* separate try except blocks in onnx2torch in model benchmark (#3126)

exceptions can be raised from either model conversion or individual backend failed. openpilot on torch mps works, but does not work with torch cpu.
seperate the expcetion block so that the benchmark can inlcude torch mps for openpilot.

* update env_vars.md (#3127)

mostly removed deprecated ones. not clear how to maintain this especially for extra/examples

* update test_ptr_ne (#3130)

* remove np from metal graph (#3129)

* dtype fmt (#3132)

* dtype fmt

* three ways to access

* fix off-by-one error in st_equal (#3131)

* fix off by one error

* whitespace

* no numpy (#3134)

* fast resnet eval (#3135)

* fast resnet eval

* fix HIP multidevice graph

* neater expression for devices

* lines

* add decorator test

* remove LLVMOPT

* move ptx

* Update ops_cuda.py

---------

Co-authored-by: Christopher Milan <chrismilan@ucla.edu>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: Yixiang Gao <yixiangg310573@gmail.com>
Co-authored-by: jxdv <virgoj@protonmail.com>
Co-authored-by: Francis Lam <flam@alum.mit.edu>
Co-authored-by: SnakeOnex <sheeproman@gmail.com>
Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Co-authored-by: Jyotirmaya Mahanta <jyotirmaya.mahanta@gmail.com>
Co-authored-by: Guy Leroy <g.m.leroy@outlook.com>
Co-authored-by: Paul Gustafson <paul.gustafson@theambrusgroup.com>
2024-01-15 16:44:20 -08:00
chenyu
1ee11411f1 s/lazydata/lazyop/ in print_tree (#3138)
lazyop only now
2024-01-15 19:38:27 -05:00
George Hotz
a5d634a541 simplify dtype (#3137) 2024-01-15 16:27:43 -08:00
George Hotz
e4528543fa remove LLVMOPT 2024-01-15 16:01:09 -08:00