mlx: remove to_be, use helpers (#15655)

This commit is contained in:
nimlgen
2026-04-08 20:07:28 +03:00
committed by GitHub
parent 1b44cb2ac6
commit 28b14b0e38
3 changed files with 17 additions and 17 deletions

View File

@@ -6,7 +6,8 @@ from tinygrad.device import Device, BufferSpec
from tinygrad.runtime.support.system import PCIDevice
from tinygrad.runtime.support.memory import AddrSpace
from tinygrad.runtime.ops_amd import AMDComputeQueue
from extra.mlx_driver.mlxdev import MLXDev, MLXQP, to_be
from tinygrad.helpers import to_be32, to_be64
from extra.mlx_driver.mlxdev import MLXDev, MLXQP
BUF_SIZE = 0x1000
MLX_PCI = getenv("MLX_PCI", "0000:41:00.0")
@@ -49,7 +50,7 @@ rq_wqe = qp.qp_buf.view((qp.rq_head & rq_mask) * 16, 16)
rq_wqe[:] = struct.pack('>IIQ', len(test_msg), dev.mkey, dst_paddr)
qp.rq_head += 1
# ring recv doorbell from CPU (DBR offset 0 = recv counter)
dev.dbr[qp.qp_dbr // 4] = to_be('I', qp.rq_head)
dev.dbr[qp.qp_dbr // 4] = to_be32(qp.rq_head)
# build send WQE in SQ from CPU (opcode 0x0a = SEND, ds_count=2)
sq_head = qp.sq_head
@@ -60,7 +61,7 @@ wqe[0:8] = struct.pack('>II', (sq_head << 8) | 0x0a, (qp.qp_info['qpn'] << 8) |
wqe[11] = 0x08 # CE: signal completion
wqe[16:32] = struct.pack('>IIQ', len(test_msg), dev.mkey, src_paddr)
qp.sq_head += 1
doorbell_val = to_be('Q', int.from_bytes(bytes(wqe[0:8]), 'big'))
doorbell_val = to_be64(int.from_bytes(bytes(wqe[0:8]), 'big'))
# map MLX5 UAR and DBR into GPU VA
uar_paddr = dev.pci_dev.bar_info(0)[0] + dev.uar * 0x1000
@@ -72,7 +73,7 @@ print(f"UAR gpu_va=0x{uar_gpu_va:x} DBR gpu_va=0x{dbr_gpu_va:x}")
q = AMDComputeQueue(gpu)
q.wait(gpu.timeline_signal, gpu.timeline_value - 1)
# write DBR (32-bit sq_head) - send doorbell at qp_dbr + 4
q.release_mem(dbr_gpu_va + qp.qp_dbr + 4, to_be('I', qp.sq_head), q.pm4.data_sel__mec_release_mem__send_32_bit_low,
q.release_mem(dbr_gpu_va + qp.qp_dbr + 4, to_be32(qp.sq_head), q.pm4.data_sel__mec_release_mem__send_32_bit_low,
q.pm4.int_sel__mec_release_mem__none)
# write UAR doorbell (64-bit)
q.release_mem(uar_gpu_va + 0x800, doorbell_val, q.pm4.data_sel__mec_release_mem__send_64_bit_data,

View File

@@ -5,8 +5,8 @@ from tinygrad.uop.ops import sint
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQAllocator, HWQueue, HCQBuffer, FileIOInterface
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP, to_be
from tinygrad.helpers import unwrap
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP
from tinygrad.helpers import unwrap, to_be32, to_be64
class RDMACopyQueue(HWQueue):
def __init__(self, dev:RDMADevice):
@@ -19,10 +19,10 @@ class RDMACopyQueue(HWQueue):
def encode_ring(self, hwq:HWQueue, dev:HCQCompiled, iface:MLXIface, qp:MLXQP, cq_buf:HCQBuffer, head:sint, ring_uar:bool=False):
for buf in [iface.dbr_buf, cq_buf] + ([iface.uar_buf] if ring_uar else []): cast(HCQAllocator, dev.allocator).map(buf)
hwq.write(iface.dbr_buf.offset(qp.qp_dbr + (4 if ring_uar else 0)), to_be('I', head + 1))
if ring_uar: hwq.write(iface.uar_buf.offset(0x800), to_be('Q', ((head << 8) | 0x0a) << 32 | ((qp.qp_info['qpn'] << 8) | 2)), b64=True)
hwq.write(iface.dbr_buf.offset(qp.qp_dbr + (4 if ring_uar else 0)), to_be32(head + 1))
if ring_uar: hwq.write(iface.uar_buf.offset(0x800), to_be64(((head << 8) | 0x0a) << 32 | ((qp.qp_info['qpn'] << 8) | 2)), b64=True)
hwq.poll_bit(cq_buf.offset((head & (qp.cq_size - 1)) * 64 + 60, 4), ((head >> (qp.cq_size.bit_length() - 1)) & 1) << 24, mask=0x01000000)
hwq.write(iface.dbr_buf.offset(qp.cq_dbr), to_be('I', (head + 1) & 0xFFFFFF))
hwq.write(iface.dbr_buf.offset(qp.cq_dbr), to_be32((head + 1) & 0xFFFFFF))
return self
def copy(self, dest:HCQBuffer, src:HCQBuffer, sz:int):
@@ -39,8 +39,8 @@ class RDMACopyQueue(HWQueue):
def _submit(self, dev:RDMADevice):
for remote_nic, sq_wqe, rq_wqe in zip(self._q[0::3], self._q[1::3], self._q[2::3]):
src_qp, dest_qp, _, _ = dev.iface.connect(remote_nic)
assert src_qp.head + 1 - to_be('I', src_qp.dev.dbr[src_qp.qp_dbr // 4 + 1]) <= (1 << src_qp.log_sq_size), "SQ ring full"
assert src_qp.head + 1 - to_be('I', dest_qp.dev.dbr[dest_qp.qp_dbr // 4]) <= (1 << dest_qp.log_rq_size), "RQ ring full"
assert src_qp.head + 1 - to_be32(src_qp.dev.dbr[src_qp.qp_dbr // 4 + 1]) <= (1 << src_qp.log_sq_size), "SQ ring full"
assert src_qp.head + 1 - to_be32(dest_qp.dev.dbr[dest_qp.qp_dbr // 4]) <= (1 << dest_qp.log_rq_size), "RQ ring full"
dest_qp.qp_buf.view((src_qp.head & ((1 << dest_qp.log_rq_size) - 1)) * 16, 16)[:] = rq_wqe
sq_view = src_qp.qp_buf.view(src_qp.sq_offset + (src_qp.head & ((1 << src_qp.log_sq_size) - 1)) * 64, 64)
sq_view[:] = struct.pack('>I', (src_qp.head << 8) | 0x0a) + sq_wqe[4:]

View File

@@ -11,7 +11,6 @@ MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits
getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None)) for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
MLX5_CMD_STRUCTS[mlx5.MLX5_CMD_OP_ACCESS_REG] = (mlx5.struct_mlx5_ifc_access_register_in_bits, mlx5.struct_mlx5_ifc_access_register_out_bits)
def to_be(fmt, val): return to_be32(val) if fmt == 'I' else to_be64(val)
def ipv4_to_gid(ip):
  """Build a 16-byte GID from an IPv4 address string.

  Layout: ten zero bytes, then 0xff 0xff, then the four packed address bytes
  produced by socket.inet_aton(ip).
  """
  prefix = b'\x00' * 10 + b'\xff\xff'
  packed_addr = socket.inet_aton(ip)
  return prefix + packed_addr
def udp_sport(lqpn, rqpn):
@@ -64,7 +63,7 @@ class MLXCmdQueue:
for i in range(n):
off, _ = self.mboxes[base + i]
blk = mlx5.struct_mlx5_cmd_prot_block(data=list(data[i*chunk_sz:(i+1)*chunk_sz].ljust(chunk_sz, b'\x00')),
next=to_be('Q', self.mboxes[base+i+1][1]) if i < n-1 else 0, block_num=to_be('I', i), token=tok)
next=to_be64(self.mboxes[base+i+1][1]) if i < n-1 else 0, block_num=to_be32(i), token=tok)
self.queue[off:off + ctypes.sizeof(mlx5.struct_mlx5_cmd_prot_block)] = bytes(blk)
return (self.mboxes[base][0], self.mboxes[base][1], n)
@@ -81,9 +80,9 @@ class MLXCmdQueue:
# prepare mailboxes and build command layout
_, in_ptr, n_in = self.create_mbox_chain(0, tok, inp[16:])
_, out_ptr, n_out = self.create_mbox_chain(n_in, tok, bytes(out_sz))
cmd = mlx5.struct_mlx5_cmd_layout(type=mlx5.MLX5_PCI_CMD_XPORT, inlen=to_be('I', len(inp)), in_ptr=to_be('Q', in_ptr),
cmd = mlx5.struct_mlx5_cmd_layout(type=mlx5.MLX5_PCI_CMD_XPORT, inlen=to_be32(len(inp)), in_ptr=to_be64(in_ptr),
_in=[int.from_bytes(inp[i:i+4], 'little') for i in range(0, 16, 4)],
out_ptr=to_be('Q', out_ptr), outlen=to_be('I', 16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
out_ptr=to_be64(out_ptr), outlen=to_be32(16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
cmd_bytes = bytearray(bytes(cmd))
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF # type: ignore[attr-defined]
@@ -115,8 +114,8 @@ class MLXDev:
self.init_hw(ip)
def rreg(self, off): return to_be('I',self.bar[off // 4])
def wreg(self, off, val): self.bar[off // 4] = to_be('I',val)
def rreg(self, off):
  """Read the element of self.bar at index off // 4 and return it passed through to_be32.

  `off` is divided by 4 to index self.bar (presumably a byte offset into a view
  of 32-bit words; confirm against where self.bar is created).
  """
  raw_word = self.bar[off // 4]
  return to_be32(raw_word)
def wreg(self, off, val):
  """Store to_be32(val) into self.bar at index off // 4.

  `off` is divided by 4 to index self.bar (presumably a byte offset into a view
  of 32-bit words; confirm against where self.bar is created).
  """
  converted = to_be32(val)
  self.bar[off // 4] = converted
def iseg_r(self, field):
  """Read the register at the offset of `field` within mlx5.struct_mlx5_init_seg.

  `field` is the attribute name of a member of that struct; its `.offset` is
  handed to self.rreg and the result is returned.
  """
  member = getattr(mlx5.struct_mlx5_init_seg, field)
  return self.rreg(member.offset)
def iseg_w(self, field, val):
  """Write `val` to the register at the offset of `field` within mlx5.struct_mlx5_init_seg.

  `field` is the attribute name of a member of that struct; its `.offset` and
  `val` are handed to self.wreg.
  """
  member = getattr(mlx5.struct_mlx5_init_seg, field)
  self.wreg(member.offset, val)