mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
hcq: fix usb<->cpu mappings (#14827)
* hcq: fix usb<->cpu mappings * non cpu * um
This commit is contained in:
47
test/unit/test_hcq_graph.py
Normal file
47
test/unit/test_hcq_graph.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import unittest
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.runtime.support.usb import USBMMIOInterface
|
||||
from test.mockgpu.usb import MockUSB
|
||||
|
||||
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run")
class TestHCQUnit(unittest.TestCase):
  """Unit tests for HCQ graph batching decisions (HCQGraph.supports_exec_item)."""

  @unittest.skipIf(Device.DEFAULT == "CPU", "requires non-CPU HCQ device")
  def test_supports_exec_item(self):
    # Verify that CPU exec items are only batched into an HCQ graph when every
    # device's timeline signal is mapped into the CPU address space (plain
    # MMIOInterface), and are rejected when the mapping goes through USB MMIO.
    d0, cpu_dev = Device[Device.DEFAULT], Device["CPU"]

    # JIT a tiny function that produces one kernel on the default (GPU) device
    # and one on CPU, so the captured jit_cache contains both kinds of items.
    @TinyJit
    def f(inp, inp_cpu):
      return (inp + 1.0).contiguous().realize(), (inp_cpu + 1.0).contiguous().realize()
    inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
    # Run several times so TinyJit finishes capturing the kernel cache.
    for _ in range(5): f(inp, inp_cpu)

    # Pick one GPU exec item, one CPU exec item, and collect the GPU devices involved.
    gpu_ei, cpu_ei, gpu_devs = None, None, []
    for ji in f.captured.jit_cache:
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.dev._is_cpu(): cpu_ei = ji
        else:
          gpu_ei = ji
          if ji.prg.dev not in gpu_devs: gpu_devs.append(ji.prg.dev)
    assert gpu_ei is not None and cpu_ei is not None and len(gpu_devs) > 0

    # local MMIO: GPU works alone and with CPU in batch (cpu_support=True)
    assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
    assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is True
    assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is True

    # USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False)
    # Temporarily swap the GPU's timeline-signal view for a USB-backed mock;
    # restore it in finally so later tests see the real mapping.
    orig_view = d0.timeline_signal.base_buf.view
    try:
      d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B')
      assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
      assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is False
      assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is False
    finally:
      d0.timeline_signal.base_buf.view = orig_view
|
||||
|
||||
# Allow running this test file directly, e.g. `python test/unit/test_hcq_graph.py`.
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -246,7 +246,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
|
||||
|
||||
# If all of devices are mapped into CPU address space, can use CPU inside the peer group.
|
||||
cpu_support = all(isinstance(d.timeline_signal.base_buf.view, MMIOInterface) for d in all_devs)
|
||||
cpu_support = all(type(d.timeline_signal.base_buf.view) is MMIOInterface for d in all_devs)
|
||||
|
||||
# Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
|
||||
if len(set(d.peer_group for d in all_devs if not (cpu_support and d._is_cpu()))) > 1: return False
|
||||
|
||||
Reference in New Issue
Block a user