mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
all tests pass on strix halo (#13728)
This commit is contained in:
@@ -339,7 +339,7 @@ class TestRangeify(unittest.TestCase):
|
||||
def test_transformer_ffn(self):
|
||||
from tinygrad.apps.llm import TransformerBlock
|
||||
from tinygrad import nn
|
||||
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
|
||||
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5, head_dim=1024, rope_theta=10000.0)
|
||||
for p in nn.state.get_parameters(blk): p.replace(Tensor.empty(p.shape))
|
||||
|
||||
x = Tensor.empty(128, 1024)
|
||||
|
||||
@@ -232,7 +232,7 @@ import gc
|
||||
|
||||
def bufs_allocated() -> int:
|
||||
gc.collect()
|
||||
return sum([isinstance(x, Buffer) for x in gc.get_objects()])
|
||||
return sum([type(x).__name__ == "Buffer" and type(x).__module__ == "tinygrad.device" for x in gc.get_objects()])
|
||||
|
||||
class TestVizGC(BaseTestViz):
|
||||
def test_gc(self):
|
||||
|
||||
Reference in New Issue
Block a user