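# Tests for tinygrad's BEAM kernel search: correctness of beam-searched kernels
# on awkward shapes (a large prime dimension, symbolic Variable dims) and on
# larger ASTs.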
import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing, CI
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d

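# In CI, build the tensor from a numpy array instead of Tensor.rand, presumably
# to keep the random-fill kernel itself out of the beam-searched workload.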
def rand(*shape):
  if CI: return Tensor(np.random.rand(*shape))
  return Tensor.rand(*shape)

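# BEAM is tinygrad's beam-search width for kernel optimization; setUp pins it to 2
# and tearDown restores whatever value the environment had set.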
class TestBeamSearch(unittest.TestCase):
  def setUp(self):
    self.old_beam = BEAM.value
    BEAM.value = 2
  def tearDown(self):
    BEAM.value = self.old_beam

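  # beam search over an AST whose view contains a bound symbolic Variable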
  def test_variable_ast_beam(self):
    a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
    a = (a+1).realize()

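  # 367 is prime, so the matmul dimensions have no even split; this presumably
  # forces the search through padded/odd-shaped kernel candidates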
  def test_big_prime_number(self):
    a = rand(367, 367)
    b = rand(367, 367)
    c = (a@b).realize()
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)

  def test_big_prime_number_max(self):
    a = -rand(367, 367)
    b = rand(367, 367)
    # if the reduce is incorrectly padded with 0, the max would be 0 instead of a
    # negative number (a is negated, so every real product is <= 0)
    c = (a*b).max(1)
    np.testing.assert_allclose(c.numpy(), (a.numpy() * b.numpy()).max(1), atol=1e-4, rtol=1e-4)

  def test_big_prime_number_sum(self):
    a = rand(367, 367)
    b = rand(367, 367)
    # if the reduce is incorrectly padded with 0, the sum would be inf
    # (dividing by a 0 pad value gives inf)
    c = (a/b).sum(1).realize()
    np.testing.assert_allclose(c.numpy(), (a.numpy() / b.numpy()).sum(1), atol=1e-4, rtol=1e-4)

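  # same prime-sized matmul, but with the reduce dimension bound to a symbolic Variable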
  def test_variable_big_prime_number(self):
    v = Variable("v", 1, 400).bind(367)
    a = rand(367, 367)
    b = rand(367, 367)
    c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)

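  # shrink by a symbolic bound, then reshape back to the concrete prime size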
  def test_variable_shrink_prime_number(self):
    v = Variable("v", 1, 400).bind(367)
    a = rand(400, 367)
    b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
    np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)

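  # assign after realize: likely guards against the search's trial kernel runs
  # mutating the underlying raw buffer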
  def test_no_mutate_rawbuffers(self):
    a = rand(3, 3).realize()
    desired = a.numpy() + 1
    a.assign(a+1)
    actual = a.numpy()
    np.testing.assert_allclose(actual, desired)

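  # time a Conv2d forward pass compiled under beam search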
  @unittest.skipIf(CI, "flaky. CL_OUT_OF_RESOURCES")
  def test_conv_beam(self):
    c = Conv2d(3, 16, (3,3))
    x = rand(1,3,32,32)
    with Timing():
      c(x).realize()

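  # a deep chain of elementwise ops should still schedule and realize under beam search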
  def test_large_ast(self):
    a = Tensor.rand(3, 3)
    for _ in range(5):
      for _ in range(4):
        a = (a + a) * a
    a.realize()

if __name__ == '__main__':
  unittest.main()