mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
* feat: initial xor
* feat: initial threefly
* feat: remove custom random
* fix: really need to install precommit
* feat: lmao forgot that this is rotate not a shift
* clean: put that there
* feat: numpy xor
* feat: quick test for xor
* feat: llvm xor
* feat: slightly working xor in torch
* feat: rand works in jit
* clean: save a line
* feat: match jax
* feat: maybe test against jax
* feat: requires_grad
* fix: fix test_symbolic_ops
* feat: lower alpha
* feat: just pad
* fix: maybe fix training tests?
* fix: fix some llvm stuff
* feat: cursed realize on the way out
* feat: testing jax
* fix: why is the jax install process not simple
* fix: maybe passing test
* fix: symbolic workarounds
* clean: still need that precommit
* fix: aaaa
* fix: more test fixes
* fix: quick fix for wgsl
* feat: need to set requires_grad on the final tensor
* feat: one more tensor
* feat: don't take forever
* feat: seeing y ci is brok
* feat: can't allocate 64GiB lmao
* fix: fix this
* feat: hope this doesn't break smth before i go to bed
* feat: don't destroy ram
* feat: int
* feat: remove jax
* feat: properish workaround?
* feat: skip slow webgpu tests
* feat: no longer fails
* feat: use dtypes
* feat: real number
* fix: torch
* fix: don't test against reference for torch
* feat: to device
* feat: fix advanced indexing
* feat: correct casting
* feat: even rng_counter
* feat: match master
* feat: this was actually bad
* fix: maybe?
* feat: store
* feat: remove realizes
* feat: somehow this is important
* feat: somehow this is also important
* feat: save a line
* fix: don't need that anymore
* feat: restore this
* fix: linter
* feat: remove realizes
* fix: realized is in base now
* fix: add back cast
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: bump deadline
* fix: :(
* fix: :(
* fix: not being dumb
* feat: try changing less tests
* feat: shouldn't have to change that
* feat: contiguous bumps it by one
* fix: hmm
* fix: numpy memory moment
* fix: cl_khr_fp16
* fix: torch has different tensor count
* fix: missing contiguous
* hmm: hmm
* fix: some fixes
* fix: typing
* feat: dont do that
* feat: typing fixes
* feat: why is this realize required?
* feat: ngl kinda odd typing
* feat: oh
* feat: remove realizes
* feat: why is this realize required?
* fix: hacky patch for cudacpu
* fix: without this realize pytest crashes?????
* fix: shorter line
* fix: cudacpu fixes
* fix: cudacpu fixes
* feat: real buffer
* feat: don't search when searching lmao
* fix: can't use contiguous things
* fix: no more 100GB arrays
* fix: revert
* fix: skip 7 and 10
* feat: working ish beam
* feat: minimize changes
* feat: seed 0 stable diffusion example changed
* fix: different on ci
* fix: no beam
* feat: make threefry optional
* fix: check value
* fix: unused import
* feat: threefry default
* fix: 5d
* feat: allow non upcast div
* fix: 5d better
* fix: 5d better
* fix: save all dtype
* feat: proper error
* feat: lazyop key
* fix: check float
* feat: try removing this realize now
* feat: disable threefry for uops hip tensor cores
* feat: don't need that
* feat: only check upcast
* fix: disable threefry for some metal tests
* feat: disable for metal tensor uops as well
* feat: disable for most uops
* fix: disable threefry for new uops tests
* feat: multitensor
* fix: typing
* feat: threefry default off
* feat: skip threefry half rand
* feat: restore old
* fix: bad git
* clean: ruff
* feat: bfloat16 fix
* fix: :|
* feat: restore old

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
39 lines · 1.7 KiB · Python

import sys
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import JITRunner
from tinygrad.dtype import DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Context, CI, OSX

# replace every parameter with an uninitialized buffer of the same shape/dtype, so tests
# exercise scheduling and codegen without paying for real weight initialization
def derandomize_model(model):
  with Context(GRAPH=0):
    for p in get_parameters(model):
      p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
      p.realize()

# assert that a jitted function captured exactly expected_len kernels, handling the case
# where the backend batched them into a single *Graph runner
def assert_jit_cache_len(fxn, expected_len):
  assert len(fxn.jit_cache) > 0
  # until we have a better way of typing the prg in JitItem
  if issubclass(type(fxn.jit_cache[0].prg), JITRunner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
    assert len(fxn.jit_cache) == expected_len
  else:
    assert len(fxn.jit_cache) == 1
    # until we have a better way of typing the prg in JitItem
    assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
    assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

# can this device handle this dtype? encodes the known backend limitations below
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
  if dtype == dtypes.bfloat16:
    # NOTE: this requires bf16 buffer support
    return device in {"RHIP", "HSA"}
  if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
  # for CI GPU, cl_khr_fp16 isn't supported
  # for CI LLVM, it segfaults because it can't link to the casting function
  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
  # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
  if dtype == dtypes.half:
    if device in ["GPU", "LLVM", "CUDA"]: return not CI
    if device == "PYTHON": return sys.version_info >= (3, 12)
  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
  return True
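
# --- usage sketch (not part of the 39-line helpers file above) ---
# A minimal, hedged example of how these helpers are typically combined in a tinygrad
# test. The toy jitted function, the tensor shapes, and the expected kernel count of 1
# are illustrative assumptions, not values taken from the real test suite.
if __name__ == "__main__":
  from tinygrad import TinyJit, nn

  # derandomize_model: keep parameter shapes/dtypes but skip real weight initialization
  model = nn.Linear(4, 4)
  derandomize_model(model)

  # assert_jit_cache_len: check how many kernels the JIT captured for a toy function
  @TinyJit
  def step(x: Tensor) -> Tensor: return (x + 1).realize()
  for _ in range(3): step(Tensor.rand(4).realize())
  assert_jit_cache_len(step, 1)  # assumed: a single elementwise kernel

  # is_dtype_supported: gate dtype-specific checks on the current backend
  if is_dtype_supported(dtypes.half): print(f"half is supported on {Device.DEFAULT}")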