Files
tinygrad/extra/lib_test_ast.py
George Hotz 2844482a60 Mypy fun (#541)
* mypy fun

* things are just faster

* running fast

* mypy is fast

* compile.sh

* no gpu hack

* refactor ops_cpu and ops_torch to not subclass

* make weak buffer work

* tensor works

* fix test failing

* cpu/torch cleanups

* no or operator on dict in python 3.8

* that was junk

* fix warnings

* comment and touchup
2023-02-08 09:56:51 -06:00

26 lines
956 B
Python

import sys
import numpy as np
from typing import Dict, Type
from tinygrad.ast import ASTKernel
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.ops import DeviceBuffer, map_buffers
# Module-level state read and written by test_ast():
#   in_test  -- True while a check is running; calls made meanwhile return immediately.
#   test_cnt -- number of checks started so far, printed as each one begins.
in_test, test_cnt = False, 0
def test_ast(k:ASTKernel, device:Type[DeviceBuffer]=CPUBuffer, *, atol:float=1e-4, rtol:float=1e-4):
  """Cross-check a kernel by re-executing its AST on `device` and comparing outputs.

  Steps:
  1. Copy every buffer in ``k.bufs`` to host with ``toCPU()``.
  2. Re-create each buffer on ``device`` with ``device.fromCPU``.
  3. Take the copy of ``k.bufs[0]`` as the reference output.
  4. Remap the AST's buffers to the copies with ``map_buffers`` and run it
     through ``device.exec_ast``.
  5. Compare the result against the reference element-wise.

  NOTE(review): treating ``k.bufs[0]`` as the reference assumes it is the kernel's
  output buffer and that the kernel has already been executed -- confirm at call sites.

  Calls made while a check is already in progress return immediately (guarded by
  the module-level ``in_test`` flag). The flag is always cleared on exit, even when
  the check raises, so a failure does not disable all later checks.

  Args:
    k: the kernel whose AST and buffers are checked.
    device: buffer class used to re-run the AST; must provide ``exec_ast``,
      ``fromCPU``, and instances with ``toCPU``.
    atol: absolute tolerance for the output comparison.
    rtol: relative tolerance for the output comparison.
      NaNs at the same positions in both outputs count as equal.

  Raises:
    AssertionError: if ``device`` has no ``exec_ast`` attribute, or if the outputs
      differ beyond the tolerances (raised by ``np.testing.assert_allclose``).
  """
  global in_test, test_cnt
  # Skip nested calls; presumably device.exec_ast can itself reach test_ast -- TODO confirm.
  if in_test: return
  in_test = True
  try:
    print("testing AST", test_cnt)
    test_cnt += 1
    # Check the capability before spending time copying every buffer.
    assert hasattr(device, 'exec_ast')
    # TODO: this should only copy the base buffer and retain the shapetracker (requires CPU shapetracker implementation)
    cpubufs : Dict[DeviceBuffer, DeviceBuffer] = {x:device.fromCPU(x.toCPU()) for x in k.bufs}
    real_out = cpubufs[k.bufs[0]].toCPU()
    test_out = device.exec_ast(map_buffers(cpubufs, k.ast)).toCPU()
    # The gate and the assertion must use identical criteria; otherwise a "MISMATCH"
    # could be reported (and tracebacklimit changed) without any error being raised.
    if not np.allclose(real_out, test_out, atol=atol, rtol=rtol, equal_nan=True):
      print("MISMATCH")
      # NOTE(review): if ASTKernel.print() prints and returns None, this also prints "None".
      print(k.print())
      # Hide the Python stack so the numpy mismatch report stays readable.
      # This is a deliberate, process-wide setting that is not restored.
      sys.tracebacklimit = 0
      np.testing.assert_allclose(real_out, test_out, atol=atol, rtol=rtol, equal_nan=True)
  finally:
    # Always release the guard so later calls are not silently skipped.
    in_test = False