mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
final dname to device [pr] (#7806)
* final dname to device [pr] * oops, fix nv
This commit is contained in:
2
test/external/external_test_hcq.py
vendored
2
test/external/external_test_hcq.py
vendored
@@ -22,7 +22,7 @@ class TestHCQ(unittest.TestCase):
|
||||
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestHCQ.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.dname, si.ast)
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
|
||||
TestHCQ.b.lazydata.buffer.allocate()
|
||||
# wow that's a lot of abstraction layers
|
||||
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr)
|
||||
|
||||
4
test/external/external_test_nv.py
vendored
4
test/external/external_test_nv.py
vendored
@@ -21,7 +21,7 @@ class TestNV(unittest.TestCase):
|
||||
TestNV.a = Tensor([0.,1.], device="NV").realize()
|
||||
TestNV.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
TestNV.d0_runner = get_runner(TestNV.d0.dname, si.ast)
|
||||
TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast)
|
||||
TestNV.b.lazydata.buffer.allocate()
|
||||
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestNV(unittest.TestCase):
|
||||
def test_buf4_usage(self):
|
||||
TestNV.along = Tensor([105615], device="NV").realize()
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
temp_runner = get_runner(TestNV.d0.dname, (ast,))
|
||||
temp_runner = get_runner(TestNV.d0.device, (ast,))
|
||||
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
|
||||
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||
assert abs(val - 0.80647) < 0.001, f"got val {val}"
|
||||
|
||||
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -228,7 +228,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
|
||||
validate_lin = test_lin.copy()
|
||||
validate_lin.opts = validate_device.renderer
|
||||
if validate_rawbufs is None:
|
||||
validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs]
|
||||
validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.device) for x in rawbufs]
|
||||
(_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
|
||||
|
||||
if _msg != "PASS": failures[f"VALIDATE_DEV_{_msg}"].append((validate_lin.ast, validate_lin.applied_opts))
|
||||
|
||||
Reference in New Issue
Block a user