final dname to device [pr] (#7806)

* final dname to device [pr]

* oops, fix nv
This commit is contained in:
George Hotz
2024-11-20 20:20:28 +08:00
committed by GitHub
parent bc977fec53
commit eb0bb7dc0b
12 changed files with 43 additions and 42 deletions

View File

@@ -22,7 +22,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
TestHCQ.b = self.a + 1
si = create_schedule([self.b.lazydata])[-1]
TestHCQ.runner = get_runner(TestHCQ.d0.dname, si.ast)
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
TestHCQ.b.lazydata.buffer.allocate()
# wow that's a lot of abstraction layers
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr)

View File

@@ -21,7 +21,7 @@ class TestNV(unittest.TestCase):
TestNV.a = Tensor([0.,1.], device="NV").realize()
TestNV.b = self.a + 1
si = create_schedule([self.b.lazydata])[-1]
TestNV.d0_runner = get_runner(TestNV.d0.dname, si.ast)
TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast)
TestNV.b.lazydata.buffer.allocate()
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
@@ -44,7 +44,7 @@ class TestNV(unittest.TestCase):
def test_buf4_usage(self):
TestNV.along = Tensor([105615], device="NV").realize()
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
temp_runner = get_runner(TestNV.d0.dname, (ast,))
temp_runner = get_runner(TestNV.d0.device, (ast,))
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
assert abs(val - 0.80647) < 0.001, f"got val {val}"

View File

@@ -228,7 +228,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
validate_lin = test_lin.copy()
validate_lin.opts = validate_device.renderer
if validate_rawbufs is None:
validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs]
validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.device) for x in rawbufs]
(_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
if _msg != "PASS": failures[f"VALIDATE_DEV_{_msg}"].append((validate_lin.ast, validate_lin.applied_opts))