mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
check dims before execution on nv (#4756)
* check dims before execution on nv * fix linter
This commit is contained in:
14
test/external/external_test_nv.py
vendored
14
test/external/external_test_nv.py
vendored
@@ -5,7 +5,8 @@ from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.runtime.ops_nv import NVDevice, HWComputeQueue
|
||||
from tinygrad.engine.search import Opt, OptOps
|
||||
from test.test_linearizer_failures import helper_test_lin
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.realize import get_runner, CompiledRunner
|
||||
from test.external.fuzz_linearizer import get_fuzz_rawbufs
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
|
||||
@@ -29,6 +30,17 @@ class TestNV(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
|
||||
helper_test_lin(Linearizer(ast), opts=opts, failed_platforms=["NV"])
|
||||
|
||||
def test_error_on_huge_dims(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
lin = Linearizer(ast)
|
||||
for opt in opts: lin.apply_opt(opt)
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
prg = CompiledRunner(lin.to_program())
|
||||
prg(rawbufs, {}, wait=True)
|
||||
self.assertEqual(str(cm.exception), "This is a runtime error message")
|
||||
|
||||
def test_buf4_usage(self):
|
||||
TestNV.along = Tensor([105615], device="NV").realize()
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user