mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
mlx: remove to_be, use helpers (#15655)
This commit is contained in:
@@ -6,7 +6,8 @@ from tinygrad.device import Device, BufferSpec
|
||||
from tinygrad.runtime.support.system import PCIDevice
|
||||
from tinygrad.runtime.support.memory import AddrSpace
|
||||
from tinygrad.runtime.ops_amd import AMDComputeQueue
|
||||
from extra.mlx_driver.mlxdev import MLXDev, MLXQP, to_be
|
||||
from tinygrad.helpers import to_be32, to_be64
|
||||
from extra.mlx_driver.mlxdev import MLXDev, MLXQP
|
||||
|
||||
BUF_SIZE = 0x1000
|
||||
MLX_PCI = getenv("MLX_PCI", "0000:41:00.0")
|
||||
@@ -49,7 +50,7 @@ rq_wqe = qp.qp_buf.view((qp.rq_head & rq_mask) * 16, 16)
|
||||
rq_wqe[:] = struct.pack('>IIQ', len(test_msg), dev.mkey, dst_paddr)
|
||||
qp.rq_head += 1
|
||||
# ring recv doorbell from CPU (DBR offset 0 = recv counter)
|
||||
dev.dbr[qp.qp_dbr // 4] = to_be('I', qp.rq_head)
|
||||
dev.dbr[qp.qp_dbr // 4] = to_be32(qp.rq_head)
|
||||
|
||||
# build send WQE in SQ from CPU (opcode 0x0a = SEND, ds_count=2)
|
||||
sq_head = qp.sq_head
|
||||
@@ -60,7 +61,7 @@ wqe[0:8] = struct.pack('>II', (sq_head << 8) | 0x0a, (qp.qp_info['qpn'] << 8) |
|
||||
wqe[11] = 0x08 # CE: signal completion
|
||||
wqe[16:32] = struct.pack('>IIQ', len(test_msg), dev.mkey, src_paddr)
|
||||
qp.sq_head += 1
|
||||
doorbell_val = to_be('Q', int.from_bytes(bytes(wqe[0:8]), 'big'))
|
||||
doorbell_val = to_be64(int.from_bytes(bytes(wqe[0:8]), 'big'))
|
||||
|
||||
# map MLX5 UAR and DBR into GPU VA
|
||||
uar_paddr = dev.pci_dev.bar_info(0)[0] + dev.uar * 0x1000
|
||||
@@ -72,7 +73,7 @@ print(f"UAR gpu_va=0x{uar_gpu_va:x} DBR gpu_va=0x{dbr_gpu_va:x}")
|
||||
q = AMDComputeQueue(gpu)
|
||||
q.wait(gpu.timeline_signal, gpu.timeline_value - 1)
|
||||
# write DBR (32-bit sq_head) - send doorbell at qp_dbr + 4
|
||||
q.release_mem(dbr_gpu_va + qp.qp_dbr + 4, to_be('I', qp.sq_head), q.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
q.release_mem(dbr_gpu_va + qp.qp_dbr + 4, to_be32(qp.sq_head), q.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
q.pm4.int_sel__mec_release_mem__none)
|
||||
# write UAR doorbell (64-bit)
|
||||
q.release_mem(uar_gpu_va + 0x800, doorbell_val, q.pm4.data_sel__mec_release_mem__send_64_bit_data,
|
||||
|
||||
@@ -5,8 +5,8 @@ from tinygrad.uop.ops import sint
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQAllocator, HWQueue, HCQBuffer, FileIOInterface
|
||||
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta
|
||||
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace
|
||||
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP, to_be
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP
|
||||
from tinygrad.helpers import unwrap, to_be32, to_be64
|
||||
|
||||
class RDMACopyQueue(HWQueue):
|
||||
def __init__(self, dev:RDMADevice):
|
||||
@@ -19,10 +19,10 @@ class RDMACopyQueue(HWQueue):
|
||||
|
||||
def encode_ring(self, hwq:HWQueue, dev:HCQCompiled, iface:MLXIface, qp:MLXQP, cq_buf:HCQBuffer, head:sint, ring_uar:bool=False):
|
||||
for buf in [iface.dbr_buf, cq_buf] + ([iface.uar_buf] if ring_uar else []): cast(HCQAllocator, dev.allocator).map(buf)
|
||||
hwq.write(iface.dbr_buf.offset(qp.qp_dbr + (4 if ring_uar else 0)), to_be('I', head + 1))
|
||||
if ring_uar: hwq.write(iface.uar_buf.offset(0x800), to_be('Q', ((head << 8) | 0x0a) << 32 | ((qp.qp_info['qpn'] << 8) | 2)), b64=True)
|
||||
hwq.write(iface.dbr_buf.offset(qp.qp_dbr + (4 if ring_uar else 0)), to_be32(head + 1))
|
||||
if ring_uar: hwq.write(iface.uar_buf.offset(0x800), to_be64(((head << 8) | 0x0a) << 32 | ((qp.qp_info['qpn'] << 8) | 2)), b64=True)
|
||||
hwq.poll_bit(cq_buf.offset((head & (qp.cq_size - 1)) * 64 + 60, 4), ((head >> (qp.cq_size.bit_length() - 1)) & 1) << 24, mask=0x01000000)
|
||||
hwq.write(iface.dbr_buf.offset(qp.cq_dbr), to_be('I', (head + 1) & 0xFFFFFF))
|
||||
hwq.write(iface.dbr_buf.offset(qp.cq_dbr), to_be32((head + 1) & 0xFFFFFF))
|
||||
return self
|
||||
|
||||
def copy(self, dest:HCQBuffer, src:HCQBuffer, sz:int):
|
||||
@@ -39,8 +39,8 @@ class RDMACopyQueue(HWQueue):
|
||||
def _submit(self, dev:RDMADevice):
|
||||
for remote_nic, sq_wqe, rq_wqe in zip(self._q[0::3], self._q[1::3], self._q[2::3]):
|
||||
src_qp, dest_qp, _, _ = dev.iface.connect(remote_nic)
|
||||
assert src_qp.head + 1 - to_be('I', src_qp.dev.dbr[src_qp.qp_dbr // 4 + 1]) <= (1 << src_qp.log_sq_size), "SQ ring full"
|
||||
assert src_qp.head + 1 - to_be('I', dest_qp.dev.dbr[dest_qp.qp_dbr // 4]) <= (1 << dest_qp.log_rq_size), "RQ ring full"
|
||||
assert src_qp.head + 1 - to_be32(src_qp.dev.dbr[src_qp.qp_dbr // 4 + 1]) <= (1 << src_qp.log_sq_size), "SQ ring full"
|
||||
assert src_qp.head + 1 - to_be32(dest_qp.dev.dbr[dest_qp.qp_dbr // 4]) <= (1 << dest_qp.log_rq_size), "RQ ring full"
|
||||
dest_qp.qp_buf.view((src_qp.head & ((1 << dest_qp.log_rq_size) - 1)) * 16, 16)[:] = rq_wqe
|
||||
sq_view = src_qp.qp_buf.view(src_qp.sq_offset + (src_qp.head & ((1 << src_qp.log_sq_size) - 1)) * 64, 64)
|
||||
sq_view[:] = struct.pack('>I', (src_qp.head << 8) | 0x0a) + sq_wqe[4:]
|
||||
|
||||
@@ -11,7 +11,6 @@ MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits
|
||||
getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None)) for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
|
||||
MLX5_CMD_STRUCTS[mlx5.MLX5_CMD_OP_ACCESS_REG] = (mlx5.struct_mlx5_ifc_access_register_in_bits, mlx5.struct_mlx5_ifc_access_register_out_bits)
|
||||
|
||||
def to_be(fmt, val): return to_be32(val) if fmt == 'I' else to_be64(val)
|
||||
def ipv4_to_gid(ip):
  """Return the 16-byte GID encoding of the dotted-quad IPv4 address `ip`.

  The result is ten zero bytes, then 0xff 0xff, then the four packed address
  bytes, i.e. the IPv4-mapped IPv6 layout (::ffff:a.b.c.d).
  Raises OSError if `ip` is not a string socket.inet_aton accepts.
  """
  mapped_prefix = bytes(10) + b'\xff\xff'
  return mapped_prefix + socket.inet_aton(ip)
|
||||
|
||||
def udp_sport(lqpn, rqpn):
|
||||
@@ -64,7 +63,7 @@ class MLXCmdQueue:
|
||||
for i in range(n):
|
||||
off, _ = self.mboxes[base + i]
|
||||
blk = mlx5.struct_mlx5_cmd_prot_block(data=list(data[i*chunk_sz:(i+1)*chunk_sz].ljust(chunk_sz, b'\x00')),
|
||||
next=to_be('Q', self.mboxes[base+i+1][1]) if i < n-1 else 0, block_num=to_be('I', i), token=tok)
|
||||
next=to_be64(self.mboxes[base+i+1][1]) if i < n-1 else 0, block_num=to_be32(i), token=tok)
|
||||
self.queue[off:off + ctypes.sizeof(mlx5.struct_mlx5_cmd_prot_block)] = bytes(blk)
|
||||
return (self.mboxes[base][0], self.mboxes[base][1], n)
|
||||
|
||||
@@ -81,9 +80,9 @@ class MLXCmdQueue:
|
||||
# prepare mailboxes and build command layout
|
||||
_, in_ptr, n_in = self.create_mbox_chain(0, tok, inp[16:])
|
||||
_, out_ptr, n_out = self.create_mbox_chain(n_in, tok, bytes(out_sz))
|
||||
cmd = mlx5.struct_mlx5_cmd_layout(type=mlx5.MLX5_PCI_CMD_XPORT, inlen=to_be('I', len(inp)), in_ptr=to_be('Q', in_ptr),
|
||||
cmd = mlx5.struct_mlx5_cmd_layout(type=mlx5.MLX5_PCI_CMD_XPORT, inlen=to_be32(len(inp)), in_ptr=to_be64(in_ptr),
|
||||
_in=[int.from_bytes(inp[i:i+4], 'little') for i in range(0, 16, 4)],
|
||||
out_ptr=to_be('Q', out_ptr), outlen=to_be('I', 16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
|
||||
out_ptr=to_be64(out_ptr), outlen=to_be32(16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
|
||||
cmd_bytes = bytearray(bytes(cmd))
|
||||
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF # type: ignore[attr-defined]
|
||||
|
||||
@@ -115,8 +114,8 @@ class MLXDev:
|
||||
|
||||
self.init_hw(ip)
|
||||
|
||||
def rreg(self, off): return to_be('I',self.bar[off // 4])
|
||||
def wreg(self, off, val): self.bar[off // 4] = to_be('I',val)
|
||||
def rreg(self, off): return to_be32(self.bar[off // 4])
|
||||
def wreg(self, off, val): self.bar[off // 4] = to_be32(val)
|
||||
def iseg_r(self, field):
  """Read the register backing `field` of mlx5.struct_mlx5_init_seg.

  The byte offset is taken from the struct's field descriptor and the read
  itself is delegated to self.rreg, whose result is returned unchanged.
  """
  field_off = getattr(mlx5.struct_mlx5_init_seg, field).offset
  return self.rreg(field_off)
|
||||
def iseg_w(self, field, val):
  """Write `val` to the register backing `field` of mlx5.struct_mlx5_init_seg.

  The byte offset is taken from the struct's field descriptor and the write
  itself is delegated to self.wreg. Returns None.
  """
  field_off = getattr(mlx5.struct_mlx5_init_seg, field).offset
  self.wreg(field_off, val)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user