* use full_shape to determine if index can potentially overflow (overflow check sketched after this PR's commits)
* update comment
* use shapetracker to check max index value
* wip
* lint
* handle mask
* upcast to int64 by shapetracker is a no-op on WGSL
* fix comments
* Handle negative overflow, intermediate-value overflow, int64 support
handle negative overflow
handle symbolic
wip
handle intermediate values
wip
check if typemap supports int64
lint
comment
* add invalid_dtype
lint
* Fix bug in checking mask overflow
wip
wip
* Add more tests, need to resolve partial upcast
test Valid_view_dup
test valid op overflow
refine test cases
clean up
cleanup
wip
refine tests
lint
* Upcast is handled by lower_load_store
upcast as graph_rewrite to backtrack
update test
wip
cleanup
wip
cleanup
do upcast in lower_load_store
lint
* cleanup
* do upcast within lower_load_store and mutate ctx
* do upcast in get_idx and view
revert
lint
* cleanup
* Upcast in vec, const
upcast to const
test case 3
upcast on vector
lint
* simplify idx with symbolic in case of fake overflow
test case 4
test case 4
update test
* test case 4 is only for metal
* try: upcast inside graph_rewrite instead of shapetracker
wip
* checking overflow can just be done directly on all views, with idxs
* cleanup
* REMOVE hard coded uop test for idx upcast
* refactor
cleanup
refactor
* do actual casting when necessary, instead of rewriting all idx
hard code uop test
new upcast
* check dtype for int64 in webgpu
* cleanup
cleanup
* cleanup
* update tests
cleanup
comment
cleanup
cleanup
* comment
* comment
* update comment
update comment
* refactor
* typo
* keep the scope to only upcasting
* white space
* Revert "white space"
This reverts commit 314d7eb184.
* Revert "keep the scope to only upcasting"
This reverts commit 1ef701dd85.
* sym folding is not necessary
lint1
* fold symbolic
lint
* use symbolic_simple when folding shapetracker idx
* full sym folding is required after all...
* Ops.CAST should retain the src min max
* put rewrite to lowerer
wip
* start testing on higher level
wip
test higher level in test_tensor
* find Ops.STORE in list instead of recursively
* check dtype support when upcasting
* remove invalid_dtype
* lint
* fix int64 support checks in upcast
lint
* skipIf skipUnless
* revert fold to find test case
* Revert "revert fold to find test case"
This reverts commit 225bb6e801.
* test sym folding
* handle ptx
* wip
* wip
* delete hard coded uop test
* lint fixes
* wip
* fix checking for None
* lint
* handle ptx
* comment
* dtype for overflow()
* update skipIf skipUnless
* assert in wgsl renderer for int64
wip
* do folded_upcast in to_indexed_op, real_size uses views_to_indexed_ops
* assert in lowerer for dtype support
lint
* Revert "assert in lowerer for dtype support"
This reverts commit 8e9b1b79bf.
* assert dtype in kernel.py
* Revert "assert dtype in kernel.py"
This reverts commit e29b9a9893.
* wip
* assert in render
* remove old assert
* check dtype from renderer, assert in upcast
wip
* smaller arange for sym fold case
* linearize directly
* use expand directly
* lint
* lint
* rename
* no need to check dtype in device.py
* trigger pr
* remove dtype assert in upcast, make wgpu fail in render
* use DType for type hint instead of dtypes
* assert on KeyError in tests for webgpu backend int64
* use a tuple for src
* test real kernel run
wip
* lint error
* restore
* fix real_size
* update test example
* resolve merge stuff
---------
Co-authored-by: Mesozoic Egg <mesozoic.egg@proton.mail>
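The overflow work in this PR boils down to: walk every view in the shapetracker, bound the flattened index expression (including masked regions, negative strides, and intermediate values), and upcast the index dtype to int64 only when int32 can overflow, rejecting that on backends without int64 support. A minimal sketch of that check, assuming a simplified `(shape, strides, offset, mask)` view tuple; `index_bounds` and `index_dtype` are illustrative names, not tinygrad's API:

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def index_bounds(shape, strides, offset, mask=None):
  # bound the flattened index over the valid (unmasked) coordinate range;
  # min()/max() handle negative strides, the source of negative overflow
  lo = hi = offset
  for stride, (a, b) in zip(strides, mask or [(0, s) for s in shape]):
    lo += min(a * stride, (b - 1) * stride)
    hi += max(a * stride, (b - 1) * stride)
  return lo, hi

def index_dtype(views, device_has_int64=True):
  # intermediate per-view indices must fit too, so check every view
  for v in views:
    lo, hi = index_bounds(*v)
    if lo < INT32_MIN or hi > INT32_MAX:
      if not device_has_int64: raise RuntimeError("index overflows int32, no int64 on device")
      return "int64"
  return "int32"

# ~1e10 elements: the flattened index exceeds int32, so indices go to int64
print(index_dtype([((100_000, 100_000), (100_000, 1), 0, None)]))  # int64
```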
* assert to prepare for grad uop [pr]
* fix test_nn
* fix most of test_tensor
* few more tests
* fix multi
* uniform gradient
* acc_dtype
* any for multi
* fix typing
* fix assert, CAST_BEFORE_VIEW is still the issue
* explicit test for CAST_BEFORE_VIEW
---------
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
* add support for retain_graph in backward (semantics sketched after this PR's commits)
* fix: don't accumulate grad on non-leaf tensors
* fix order
* fix: do not delete grad on leaves
* fix linter
* fix: can't exactly match torch behaviour internally
* allow numerical room for test
* refactor
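Taken together, these commits pin down torch-like autograd bookkeeping: gradients accumulate only on leaf tensors, leaf grads are never deleted, and the saved graph is freed unless retain_graph=True. A toy sketch of those semantics (not tinygrad's implementation; `Value` and its fields are made up for illustration):

```python
class Value:
  def __init__(self, data, requires_grad=False):
    self.data, self.requires_grad = data, requires_grad
    self.grad, self._parents = None, None  # _parents is set by ops (non-leaf)

  def add(self, other):
    out = Value(self.data + other.data, self.requires_grad or other.requires_grad)
    out._parents = [(self, 1.0), (other, 1.0)]  # (parent, local gradient)
    return out

  def backward(self, retain_graph=False):
    stack = [(self, 1.0)]
    while stack:
      t, g = stack.pop()
      if t._parents is None:  # leaf: accumulate, never overwrite or delete
        if t.requires_grad: t.grad = g if t.grad is None else t.grad + g
      else:                   # non-leaf: pass grads through, keep .grad empty
        stack.extend((p, g * lg) for p, lg in t._parents)
        if not retain_graph: t._parents = None  # free the saved graph

a, b = Value(2.0, requires_grad=True), Value(3.0, requires_grad=True)
c = a.add(b)                   # non-leaf: never gets a .grad
c.backward(retain_graph=True)  # keep the graph alive for a second pass
c.backward()
assert a.grad == 2.0 and c.grad is None
```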
* handle float16 overflow in PYTHON
use `truncate` when constructing a tensor from a list to make sure all values are packable (might be slow, but should be correct). add truncate_fp16 to cast overflowed values to inf/-inf (sketched after this PR's commits).
* all valid fmts support truncate
replaced all dtype.np with _to_np_dtype defined in tensor.py.
after this, the only numpy usages are (1) Tensor(np.ndarray), (2) constructing .numpy() output, (3) the numpy random buffer
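A sketch of the truncate behavior described above, assuming the commit's intent: values beyond the largest finite half become ±inf, in-range values round-trip through IEEE half precision (struct's 'e' format). `truncate_fp16` here is an illustrative reimplementation, not the exact code:

```python
import math, struct

FP16_MAX = 65504.0  # largest finite float16

def truncate_fp16(x: float) -> float:
  # values beyond the largest finite half are treated as overflow; anything
  # else is "packable" in the sense the commit above means
  if x > FP16_MAX: return math.inf
  if x < -FP16_MAX: return -math.inf
  return struct.unpack("e", struct.pack("e", x))[0]  # round to nearest half

print(truncate_fp16(70000.0))  # inf
print(truncate_fp16(1.0001))   # 1.0: rounded to the nearest representable half
```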
* mockgpu nv
* works
* comment that out
* fix merge
* setup gpuocelot
* install packages
* don't run all of them
* passes
* fix ci
* almost
* should pass
* linter
* linter 2
* try this?
* ugh, not supported
* ci
* remove ticket from description
* better descs
* lol does this work
* some more changes
* a tiny note
* rename a variable
* add test for data const and add TODO comment
* make type correct
* write llm.c and add a few new methods to tensor
* training works
* add jit
* tests for new functions
* test tolist
* simple fix for onnx test failures (#4186)
* bump line count to 7500
* simplest fix
* safenumpy tolist for now (usage example after this PR's commits)
---------
Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---------
Co-authored-by: geohotstan <135171913+geohotstan@users.noreply.github.com>
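One of the tensor methods this stack adds is `tolist`, which materializes a tensor back into nested Python lists; the "test tolist" commits above exercise it, and the "safenumpy" commit routes it through numpy for now. A quick usage example:

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
print(t.tolist())        # [[1, 2], [3, 4]]
print((t * 2).tolist())  # [[2, 4], [6, 8]]
```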
* initial version
* heh gimme grrrreen
* version 2
* clean ups
* some test confusion
* fix onnx
* rename to _broadcast_tensors (broadcasting rules sketched after this PR's commits)
* improved errors and test
* fixed?
* some test fixup
* version 3 lol
* comments
* cleaner
* add failure test for expand to 0
* 1 more assertRaises test
* make err msg better
* also rewrite the expand onnx op? :s
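The `_broadcast_tensors` rename and the expand-to-0 failure tests both revolve around numpy-style broadcasting rules; a minimal sketch of the shape computation they enforce (illustrative helper, not tinygrad's code):

```python
def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
  ndim = max(len(s) for s in shapes)
  padded = [(1,) * (ndim - len(s)) + s for s in shapes]  # left-pad with 1s
  out = []
  for dims in zip(*padded):
    sizes = {d for d in dims if d != 1}
    if len(sizes) > 1: raise ValueError(f"cannot broadcast {dims}")  # e.g. 3 vs 0
    out.append(sizes.pop() if sizes else 1)
  return tuple(out)

print(broadcast_shape((3, 1), (1, 4)))   # (3, 4)
print(broadcast_shape((2, 1, 5), (5,)))  # (2, 1, 5)
```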
* feat: initial xor
* feat: initial threefry (reference sketch after this PR's commits)
* feat: remove custom random
* fix: really need to install precommit
* feat: lmao forgot that this is rotate not a shift
* clean: put that there
* feat: numpy xor
* feat: quick test for xor
* feat: llvm xor
* feat: slightly working xor in torch
* feat: rand works in jit
* clean: save a line
* feat: match jax
* feat: maybe test against jax
* feat: requires_grad
* fix: fix test_symbolic_ops
* feat: lower alpha
* feat: just pad
* fix: maybe fix training tests?
* fix: fix some llvm stuff
* feat: cursed realize on the way out
* feat: testing jax
* fix: why is the jax install process not simple
* fix: maybe passing test
* fix: symbolic workarounds
* clean: still need that precommit
* fix: aaaa
* fix: more test fixes
* fix: quick fix for wgsl
* feat: need to set requires_grad on the final tensor
* feat: one more tensor
* feat: don't take forever
* feat: seeing why ci is broken
* feat: can't allocate 64GiB lmao
* fix: fix this
* feat: hope this doesn't break smth before i go to bed
* feat: don't destroy ram
* feat: int
* feat: remove jax
* feat: properish workaround?
* feat: skip slow webgpu tests
* feat: no longer fails
* feat: use dtypes
* feat: real number
* fix: torch
* fix: don't test against reference for torch
* feat: to device
* feat: fix advanced indexing
* feat: correct casting
* feat: even rng_counter
* feat: match master
* feat: this was actually bad
* fix: maybe?
* feat: store
* feat: remove realizes
* feat: somehow this is important
* feat: somehow this is also important
* feat: save a line
* fix: don't need that anymore
* feat: restore this
* fix: linter
* feat: remove realizes
* fix: realized is in base now
* fix: add back cast
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: :(
* fix: :(
* fix: not being dumb
* feat: try changing less tests
* feat: shouldn't have to change that
* feat: contiguous bumps it by one
* fix: hmm
* fix: numpy memory moment
* fix: cl_khr_fp16
* fix: torch has different tensor count
* fix: missing contiguous
* hmm: hmm
* fix: some fixes
* fix: typing
* feat: dont do that
* feat: typing fixes
* feat: why is this realize required?
* feat: ngl kinda odd typing
* feat: oh
* feat: remove realizes
* feat: why is this realize required?
* fix: hacky patch for cudacpu
* fix: without this realize pytest crashes?????
* fix: shorter line
* fix: cudacpu fixes
* fix: cudacpu fixes
* feat: real buffer
* feat: don't search when searching lmao
* fix: can't use contiguous things
* fix: no more 100GB arrays
* fix: revert
* fix: skip 7 and 10
* feat: working ish beam
* feat: minimize changes
* feat: seed 0 stable diffusion example changed
* fix: different on ci
* fix: no beam
* feat: make threefry optional
* fix: check value
* fix: unused import
* feat: threefry default
* fix: 5d
* feat: allow non upcast div
* fix: 5d better
* fix: 5d better
* fix: save all dtype
* feat: proper error
* feat: lazyop key
* fix: check float
* feat: try removing this realize now
* feat: disable threefry for uops hip tensor cores
* feat: don't need that
* feat: only check upcast
* fix: disable threefry for some metal tests
* feat: disable for metal tensor uops as well
* feat: disable for most uops
* fix: disable threefry for new uops tests
* feat: multitensor
* fix: typing
* feat: threefry default off
* feat: skip threefry half rand
* feat: restore old
* fix: bad git
* clean: ruff
* feat: bfloat16 fix
* fix: :|
* feat: restore old
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
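For reference, the counter-based RNG this PR lands is threefry (from Salmon et al.'s Random123 family); the "rotate not a shift" commit refers to the rotate-and-xor round below. A pure-Python threefry2x32 sketch to show the structure, not tinygrad's kernelized version:

```python
M32 = 0xFFFFFFFF
ROT = [13, 15, 26, 6, 17, 29, 16, 24]  # threefry2x32 rotation schedule

def rotl32(x, r): return ((x << r) | (x >> (32 - r))) & M32

def threefry2x32(key, ctr, rounds=20):
  k0, k1 = key
  ks = [k0, k1, k0 ^ k1 ^ 0x1BD11BDA]  # Skein key-schedule parity constant
  x0, x1 = (ctr[0] + ks[0]) & M32, (ctr[1] + ks[1]) & M32
  for i in range(rounds):
    x0 = (x0 + x1) & M32
    x1 = rotl32(x1, ROT[i % 8]) ^ x0  # rotate, not shift, then mix
    if i % 4 == 3:  # key injection every four rounds
      j = i // 4 + 1
      x0 = (x0 + ks[j % 3]) & M32
      x1 = (x1 + ks[(j + 1) % 3] + j) & M32
  return x0, x1

# counter-based: the same (key, counter) always yields the same uint32 pair,
# which is what makes rand reproducible inside the jit
print(threefry2x32((0, 0), (0, 0)))
```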
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>