Commit Graph

181 Commits

Author SHA1 Message Date
George Hotz
bd9c015b09 tests from grad uop path [pr] (#8313) 2024-12-18 09:25:05 -08:00
qazal
d05e21cb69 replace lazy srcs with the new uop api [pr] (#8255)
* buf_uop_view function

* srcs shouldn't exist

* fix TestTensorMetadata

---------

Co-authored-by: George Hotz <geohot@gmail.com>
2024-12-15 17:09:54 +08:00
George Hotz
734f2c5344 compute gradient [pr] (#8237)
* compute gradient [pr]

* schedule_step_with_grads

* second deriv works
2024-12-13 20:46:01 -08:00
chenyu
40a4c603b9 remove more test skip for webgpu [pr] (#8192) 2024-12-12 14:06:35 -05:00
chenyu
72ff631f8d remove unreachable tensor dtype assert (#8190)
it would have failed in `to_dtype`. added some tests for it too
2024-12-12 13:04:49 -05:00
George Hotz
c8e7707a7e hotfix: disable flaky move tensor test 2024-12-10 17:11:21 -08:00
Ahmed Harmouche
a8cfdc70ed Run more webgpu tests (#8142) 2024-12-10 23:20:04 +01:00
chenyu
a77ee72d11 clean up reshape size check [pr] (#8067)
removed a resolve, and remove special case for 0 size assert since it's covered by generic size check
2024-12-06 07:51:19 -05:00
chenyu
3d82f8e340 simpler rand_like (#7680) 2024-11-13 12:28:41 -05:00
George Hotz
205befa788 move is_dtype_supported to device [pr] (#7575) 2024-11-07 20:38:03 +08:00
Tobias Fischer
1a9e145388 Tensor Clone Function (#7154)
* implemented clone function

* cleanup linting, single func

* added tests, cleaned up grad cloning

* fixed whitespace
2024-11-01 12:24:43 +08:00
qazal
7149eabb34 assert set equality in TestTensorMetadata [pr] (#7364) 2024-10-29 19:29:29 +08:00
qazal
0ebdb136e8 revert metadata with graph_rewrite (#7353) (#7362)
This reverts commit 540e4179e7.
2024-10-29 19:16:31 +08:00
qazal
540e4179e7 global UOp to Metadata mapping + inverse DEBUG=2 metadata order [pr] (#7353)
* add ctx.buf_metadata [pr]

* revert metadata insertion order

* lint rename
2024-10-29 17:12:00 +08:00
qazal
e46edc22aa use unittest helpers in TestTensorMetadata [pr] (#7329)
* use unittest helpers in TestTensorMetadata [pr]

* fix that

* 5 args
2024-10-28 18:38:30 +08:00
Bhavya Gada
534597e753 fix all test warnings (#7024)
* fix pytorch warning in nn.conv2d for same padding

* fix future warning in torch load

* fix overflow warning in tensor list test: https://github.com/numpy/numpy/issues/23606#issuecomment-1512752172

* fix floating point warnings in dtype tests using docs https://numpy.org/doc/stable/reference/generated/numpy.errstate.html and a neat solution https://stackoverflow.com/questions/53634965/change-np-seterr-behavior-inside-a-function-only

* put err state in one place; comment taken care of by function hover

* enter np errstate context manager on test setup

* put decorator on class
2024-10-18 08:56:40 +08:00
nimlgen
3c56aeee70 add Tensor.from_blob (#6765)
* draft tensor from pointer init

* some docs and types

* comment

* cleaner

* test

* malloc

* qcom cl interop

* jit example

* cleaner

* dealoc

* wording

* docs
2024-09-26 18:33:19 +08:00
David González Martínez
724e408736 add support for retain_graph in backward (#6145)
* add support for retain_graph in backward

* fix: dont accumulate grad on non-leaf tensors

* fix order

* fix: do not delete grad on leafs

* fix linter

* fix: can't exactly match torch behaviour internally

* allow numerical room for test

* refactor
2024-08-18 16:08:31 -07:00
George Hotz
17a043edad tensor inference (#6156)
* tensor inference

* test is even better name
2024-08-18 00:19:28 -07:00
Jun Zhang
54e176fb4f Ignore non-computational backends when overwriting the default (#5770) 2024-08-10 09:23:29 -07:00
qazal
e6d41b0ce7 hotfix: adjust test_backward_pass_diamond_model thresholds (#5981) 2024-08-09 00:20:53 +08:00
David González Martínez
0f09b94c43 add failing test for second order derivatives (#5772)
* add failing test

* fix lint

* fix bad merge

* fix again

* fix test

* more minimal
2024-08-01 02:34:47 -07:00
David González Martínez
d0fd84e617 feat: allow passing gradient to .backward() to compute vjp (#5771)
* feat: allow passing gradient to .backward() to compute vjp

* fix

* refactor

* fix trailing whitespace
2024-07-28 11:13:18 -07:00
chenyu
e41ab66653 use is to compare types (#5476)
new rule in latest ruff
2024-07-14 14:26:41 -04:00
wozeparrot
9150a6be7a tensor metadata (#5271) 2024-07-08 17:45:40 -07:00
chenyu
cc2be9064f fix out of bound python list into numpy array (#5043)
numpy 2.0 does not allow oob python const and recommends writing as `np.array(value).astype(dtype)`
2024-06-18 18:05:21 -04:00
chenyu
2b2488f2e2 revert creating Tensor from a list without numpy (#5041)
the change was incomplete and broke creating Tensor from a list of np array
2024-06-18 17:31:22 -04:00
chenyu
acaf9a490d RECIP(-0.0) should be -inf (#5024)
* RECIP(-0.0) should be -inf

added test_dtype_alu for PYTHON backend

* catcht that

* fix those two
2024-06-17 22:26:58 -04:00
chenyu
03b367c014 handle float16 overflow in PYTHON (#5022)
* handle float16 overflow in PYTHON

use `truncate` when constructing tensor from list to make sure all values are packable (might be slow, but should be correct). add truncate_fp16 to cast overflowed values to inf/-inf.

* all valid fmt supports truncate
2024-06-17 21:12:52 -04:00
chenyu
64cda3c481 raise TypeError calling len() on a 0-d tensor (#4970)
matched numpy and torch
2024-06-14 16:34:27 -04:00
chenyu
67e8df4969 remove numpy from dtype (#4969)
replaced all dtype.np with _to_np_dtype defined in tensor.py.

after this, the only numpy usages are (1) Tensor(np.ndarray), (2) construct .numpy() output, (3) numpy random buffer
2024-06-14 15:38:45 -04:00
chenyu
dae1c8abe2 create Tensor from bytes without numpy (#4964) 2024-06-14 13:37:27 -04:00
chenyu
287d3c3b84 support list, tuple input in dtypes.from_py (#4945)
* support list, tuple input in dtypes.from_py

and used it to infer dtype from python list and tuple in Tensor constructor.

* fix tests
2024-06-13 13:38:06 -04:00
chenyu
7aecea4f56 support creating Tensor from python tuple (#4944)
added a small fuzzer to test data with mixed tuple and list of numbers matched with numpy
2024-06-13 12:18:37 -04:00
chenyu
45083ccb43 canonicalize 0 in shape in View.create (#4815)
set strides to 0, offset to 0, mask to None, and contiguous to True with size 0 view.
2024-06-03 13:37:37 -04:00
chenyu
8942230b1f minor cleanups of test_tensor and extend some cases (#4794) 2024-05-31 10:43:22 -04:00
qazal
637f482588 configure derandomizing CI tests (#4793) 2024-05-31 17:06:58 +03:00
nimlgen
eb9689336e nv mockgpu (#4600)
* mockgpu nv

* works

* comment that out

* fix merge

* setup gpuocelot

* install packages

* not run all of them

* passes

* fix ci

* almost

* should pass

* linter

* linter 2

* try this?

* ugn, not supported

* ci

* remove ticket from description

* better descs
2024-05-15 23:46:08 +03:00
geohotstan
269a58d5fa tolist to return multidimensional list (#4192)
* lol does this work

* some more changes

* a tiny note

* rename a variable

* add test for data const and add TODO comment

* make type correct

make type correct
2024-04-18 07:43:10 +04:00
George Hotz
55ae73e951 Replicate llm.c in tinygrad (#4179)
* write llm.c and add a few new methods to tensor

* training works

* add jit

* tests for new functions

* test tolist

* simple fix for onnx test failures (#4186)

* write llm.c and add a few new methods to tensor

* training works

* add jit

* tests for new functions

* bump line count to 7500

* simplest fix

* safenumpy tolist for now

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>

---------

Co-authored-by: geohotstan <135171913+geohotstan@users.noreply.github.com>
2024-04-16 15:40:48 +04:00
geohotstan
183708b3fd broadcast expand to match torch (#4085)
* initial version

* heh gimme grrrreen

* version 2

* clean ups

* some test confusion

* fix onnx

* rename to _broadcast_tensors

* improved errors and test

* fixed?

* some test fixup

* version 3 lol

* comments

* cleaner

* add failure test for expand to 0 test

* 1 more assertRaises test

* make err msg better

* also rewrite the expand onnx op? :s
2024-04-07 16:23:13 -04:00
wozeparrot
a0ab755317 threefry again (#3785)
* feat: initial xor

* feat: initial threefly

* feat: remove custom random

* fix: really need to install precommit

* feat: lmao forgot that this is rotate not a shift

* clean: put that there

* feat: numpy xor

* feat: quick test for xor

* feat: llvm xor

* feat: slightly working xor in torch

* feat: rand works in jit

* clean: save a line

* feat: match jax

* feat: maybe test against jax

* feat: requires_grad

* fix: fix test_symbolic_ops

* feat: lower alpha

* feat: just pad

* fix: maybe fix training tests?

* fix: fix some llvm stuff

* feat: cursed realize on the way out

* feat: testing jax

* fix: why is the jax install process not simple

* fix: maybe passing test

* fix: symbolic workarounds

* clean: still need that precommit

* fix: aaaa

* fix: more test fixes

* fix: quick fix for wgsl

* feat: need to set requires_grad on the final tensor

* feat: one more tensor

* feat: don't take forever

* feat: seeing y ci is brok

* feat: can't allocate 64GiB lmao

* fix: fix this

* feat: hope this doesn't break smth before i go to bed

* feat: don't destroy ram

* feat: int

* feat: remove jax

* feat: properish workaround?

* feat: skip slow webgpu tests

* feat: no longer fails

* feat: use dtypes

* feat: real number

* fix: torch

* fix: don't test against reference for torch

* feat: to device

* feat: fix advanced indexing

* feat: correct casting

* feat: even rng_counter

* feat: match master

* feat: this was actually bad

* fix: maybe?

* feat: store

* feat: remove realizes

* feat: somehow this is important

* feat: somehow this is also important

* feat: save a line

* fix: don't need that anymore

* feat: restore this

* fix: linter

* feat: remove realizes

* fix: realized is in base now

* fix: add back cast

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: :(

* fix: :(

* fix: not being dumb

* feat: try changing less tests

* feat: shouldn't have to change that

* feat: contiguous bumps it by one

* fix: hmm

* fix: numpy memory moment

* fix: cl_khr_fp16

* fix: torch has different tensor count

* fix: missing contiguous

* hmm: hmm

* fix: some fixes

* fix: typing

* feat: dont do that

* feat: typing fixes

* feat: why is this realize required?

* feat: ngl kinda odd typing

* feat: oh

* feat: remove realizes

* feat: why is this realize required?

* fix: hacky patch for cudacpu

* fix: without this realize pytest crashes?????

* fix: shorter line

* fix: cudacpu fixes

* fix: cudacpu fixes

* feat: real buffer

* feat: don't search when searching lmao

* fix: can't use contiguous things

* fix: no more 100GB arrays

* fix: revert

* fix: skip 7 and 10

* feat: working ish beam

* feat: minimize changes

* feat: seed 0 stable diffusion example changed

* fix: different on ci

* fix: no beam

* feat: make threefry optional

* fix: check value

* fix: unused import

* feat: threefry default

* fix: 5d

* feat: allow non upcast div

* fix: 5d better

* fix: 5d better

* fix: save all dtype

* feat: proper error

* feat: lazyop key

* fix: check float

* feat: try removing this realize now

* feat: disable threefry for uops hip tensor cores

* feat: don't need that

* feat: only check upcast

* fix: disable threefry for some metal tests

* feat: disable for metal tensor uops as well

* feat: disable for most uops

* fix: disable threefry for new uops tests

* feat: multitensor

* fix: typing

* feat: threefry default off

* feat: skip threefry half rand

* feat: restore old

* fix: bad git

* clean: ruff

* feat: bfloat16 fix

* fix: :|

* feat: restore old

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-03-18 16:47:07 -04:00
George Hotz
311cf2b7d3 Revert "threefry_2x32 (#2601)" (#3784)
This reverts commit db3de54bc4.
2024-03-17 10:27:20 -07:00
wozeparrot
db3de54bc4 threefry_2x32 (#2601)
* feat: initial xor

* feat: initial threefly

* feat: remove custom random

* fix: really need to install precommit

* feat: lmao forgot that this is rotate not a shift

* clean: put that there

* feat: numpy xor

* feat: quick test for xor

* feat: llvm xor

* feat: slightly working xor in torch

* feat: rand works in jit

* clean: save a line

* feat: match jax

* feat: maybe test against jax

* feat: requires_grad

* fix: fix test_symbolic_ops

* feat: lower alpha

* feat: just pad

* fix: maybe fix training tests?

* fix: fix some llvm stuff

* feat: cursed realize on the way out

* feat: testing jax

* fix: why is the jax install process not simple

* fix: maybe passing test

* fix: symbolic workarounds

* clean: still need that precommit

* fix: aaaa

* fix: more test fixes

* fix: quick fix for wgsl

* feat: need to set requires_grad on the final tensor

* feat: one more tensor

* feat: don't take forever

* feat: seeing y ci is brok

* feat: can't allocate 64GiB lmao

* fix: fix this

* feat: hope this doesn't break smth before i go to bed

* feat: don't destroy ram

* feat: int

* feat: remove jax

* feat: properish workaround?

* feat: skip slow webgpu tests

* feat: no longer fails

* feat: use dtypes

* feat: real number

* fix: torch

* fix: don't test against reference for torch

* feat: to device

* feat: fix advanced indexing

* feat: correct casting

* feat: even rng_counter

* feat: match master

* feat: this was actually bad

* fix: maybe?

* feat: store

* feat: remove realizes

* feat: somehow this is important

* feat: somehow this is also important

* feat: save a line

* fix: don't need that anymore

* feat: restore this

* fix: linter

* feat: remove realizes

* fix: realized is in base now

* fix: add back cast

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: :(

* fix: :(

* fix: not being dumb

* feat: try changing less tests

* feat: shouldn't have to change that

* feat: contiguous bumps it by one

* fix: hmm

* fix: numpy memory moment

* fix: cl_khr_fp16

* fix: torch has different tensor count

* fix: missing contiguous

* hmm: hmm

* fix: some fixes

* fix: typing

* feat: dont do that

* feat: typing fixes

* feat: why is this realize required?

* feat: ngl kinda odd typing

* feat: oh

* feat: remove realizes

* feat: why is this realize required?

* fix: hacky patch for cudacpu

* fix: without this realize pytest crashes?????

* fix: shorter line

* fix: cudacpu fixes

* fix: cudacpu fixes

* feat: real buffer

* feat: don't search when searching lmao

* fix: can't use contiguous things

* fix: no more 100GB arrays

* fix: revert

* fix: skip 7 and 10

* feat: working ish beam

* feat: minimize changes

* feat: seed 0 stable diffusion example changed

* fix: different on ci

* fix: no beam

* feat: make threefry optional

* fix: check value

* fix: unused import

* feat: threefry default

* fix: 5d

* feat: allow non upcast div

* fix: 5d better

* fix: 5d better

* fix: save all dtype

* feat: proper error

* feat: lazyop key

* fix: check float

* feat: try removing this realize now

* feat: disable threefry for uops hip tensor cores

* feat: don't need that

* feat: only check upcast

* fix: disable threefry for some metal tests

* feat: disable for metal tensor uops as well

* feat: disable for most uops

* fix: disable threefry for new uops tests

* feat: multitensor

* fix: typing

* feat: threefry default off

* feat: skip threefry half rand

* feat: restore old

* fix: bad git

* clean: ruff

* feat: bfloat16 fix

* fix: :|

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-03-17 10:19:33 -07:00
Francis Lata
957ae9b594 Fix Tensor's __repr__ for printing out grad (#3673)
* update check for Tensor's __repr__ with grad

* add test for repr with grad bugfix
2024-03-10 17:04:29 -04:00
Maximilian Wolf
8ae85b2cf5 add inference_mode context manager with decorator support (#3621)
* add inference_mode context manager with decorator support

* change val to mode for train and inference_mode

* fix wrong rename
2024-03-09 08:38:26 -08:00
chenyu
4552248c84 fix Tensor.to preserves grad.data (#3636) 2024-03-06 21:44:49 -05:00
chenyu
8f10bfa2ff ban __bool__ on Tensor (#3632)
* ban __bool__ on Tensor

avoid misuse

* test case

* fix tests

* fix more tests
2024-03-06 17:12:35 -05:00
chenyu
282bbd5acb check the input length into argfix (#3610)
* check the input length into argfix

it's possible to overlook setting keyword for kwargs and argfix silently truncates input

* add test
2024-03-04 19:50:17 -05:00
Marcin Słowik
56d21d77b3 Fix two bugs concerning Tensor.to. (#3593)
1. Tensor.to should return self if device == self.device. This was not the case if provided with non-canonical name of self.device.
2. Tensor.to result was missing graph, even though requires_grad and grad were propagated .

Add corresponding tests.
2024-03-03 08:48:56 -08:00