jit sampling functionn in test_randomness.test_multinomial (#5034)

* jit sampling functionn in test_randomness.test_multinomial

`THREEFRY=1 python3 -m pytest test/test_randomness.py::TestRandomness::test_multinomial --durations 1` 7 sec -> 1.2 sec

* skip that
This commit is contained in:
chenyu
2024-06-18 14:21:05 -04:00
committed by GitHub
parent f31ef11537
commit dc942bf1f6
2 changed files with 12 additions and 6 deletions

View File

@@ -236,7 +236,7 @@ jobs:
run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Test THREEFRY
run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py
run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py --durations=20
- name: Run process replay tests
if: env.RUN_PROCESS_REPLAY == '1'
run: cp test/external/replay_codegen.py ./replay_codegen.py && git fetch origin master && git checkout origin/master && PYTHONPATH=. python3 replay_codegen.py

View File

@@ -3,8 +3,8 @@ from functools import partial
import numpy as np
import torch
from tinygrad import nn, dtypes, Tensor, Device
from tinygrad.helpers import THREEFRY, getenv
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
from tinygrad.helpers import THREEFRY, getenv, CI
from test.helpers import is_dtype_supported
from hypothesis import given, settings, strategies as strat
@@ -174,9 +174,15 @@ class TestRandomness(unittest.TestCase):
# no-replacement isn't supported, unless taking only one sample
w = [0.1, 0.9]
self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
tiny_samples = [Tensor(w).multinomial(1, replacement=False).numpy().item() for _ in range(1000)]
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
@TinyJit
def sample_one(): return Tensor(w).multinomial(1, replacement=False).realize()
# TODO: fix mockgpu issue
if not (CI and Device.DEFAULT == "AMD"):
tiny_samples = [sample_one().item() for _ in range(1000)]
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
def test_multinomial_counterexample(self):
tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)