tinygrad/test/helpers.py
George Hotz 2c363b5f0b new style device (#2530)
* cpu tests pass
* torch works
* works
* metal works
* fix ops_disk
* metal jit works
* fix openpilot
* llvm and clang work
* fix webgpu
* docs are rly broken
* LRU works on metal
* delete comment
* revert name to ._buf. LRU only on Compiled
* changes
* allocator
* allocator, getting closer
* lru alloc
* LRUAllocator
* all pass
* metal
* cuda
* test examples
* linearizer
* test fixes
* fix custom + clean realize
* fix hip
* skip tests
* fix tests
* fix size=0
* fix MOCKHIP
* fix thneed
* copy better
* simple
* old style metal copy
* fix thneed
* np reshape
* give cuda a device
2023-11-30 17:07:16 -08:00

26 lines · 925 B · Python

from tinygrad.device import JITRunner
from tinygrad.ops import LazyOp, LoadOps
from tinygrad.nn.state import get_parameters

# for speed: swap RAND loads for EMPTY so tests don't pay for random number generation
def derandomize(x):
  if isinstance(x, LazyOp):
    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.RAND else x.arg)
  # not a LazyOp (a LazyBuffer): rewrite its op tree in place
  x.op = derandomize(x.op)
  return x

def derandomize_model(model):
  for p in get_parameters(model):
    p.lazydata = derandomize(p.lazydata)
    p.realize()

def assert_jit_cache_len(fxn, expected_len):
  assert len(fxn.jit_cache) > 0
  if issubclass(type(fxn.jit_cache[0].prg), JITRunner):
    # per-kernel runners: one JitItem per kernel
    assert len(fxn.jit_cache) == expected_len
  else:
    # a graph executor batches every kernel into a single JitItem
    assert len(fxn.jit_cache) == 1
    # until we have a better way of typing the prg in JitItem
    assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
    assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
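
A minimal usage sketch for derandomize_model, assuming a toy model built from tinygrad.nn.Linear; the class name, shapes, and the test.helpers import path are illustrative assumptions, not part of this file:

from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from test.helpers import derandomize_model  # assumed import path for this helper

class TinyModel:  # hypothetical model for illustration
  def __init__(self): self.l1 = Linear(8, 8)
  def __call__(self, x): return self.l1(x).relu()

model = TinyModel()
derandomize_model(model)                     # RAND loads become EMPTY; realize() skips RNG work
out = model(Tensor.empty(1, 8)).realize()    # values are uninitialized garbage, fine for shape/speed tests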
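
And a hedged sketch of assert_jit_cache_len with a TinyJit-wrapped function; the function, the warm-up call count, and the expected kernel count of 1 are assumptions for illustration:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from test.helpers import assert_jit_cache_len  # assumed import path for this helper

@TinyJit
def double(x): return (x * 2).realize()  # hypothetical jitted function

for _ in range(3): double(Tensor.randn(4, 4).realize())  # warm-up calls so the JIT captures kernels
assert_jit_cache_len(double, 1)  # a single elementwise kernel is the assumed expectation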