Slightly less slow remote copyin (#10404)

Concatenating bytes is slow, so skip it when the data is already present in `self._h`.

Also, don't convert the memoryview to bytes (a copy that costs about 100 ms) until it is actually needed.

This mitigates the cost of shard copying before shrink.

master:
```
*** REMOTE     6 copy 1073.74M,  REMOTE <- METAL           arg  2 mem  2.15 GB tm    806.84ms/   829.61ms (     0.00 GFLOPS    1.3|1.3     GB/s)
*** REMOTE:    7 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  3.22 GB tm    797.41ms/  1627.02ms (     0.00 GFLOPS    1.3|1.3     GB/s)
*** REMOTE:    8 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  4.29 GB tm    677.89ms/  2304.91ms (     0.00 GFLOPS    1.6|1.6     GB/s)
*** REMOTE:    9 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  5.37 GB tm    659.81ms/  2964.72ms (     0.00 GFLOPS    1.6|1.6     GB/s)
*** REMOTE:   10 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  6.44 GB tm    679.21ms/  3643.93ms (     0.00 GFLOPS    1.6|1.6     GB/s)
*** REMOTE:   11 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  7.52 GB tm    673.90ms/  4317.83ms
```

this:
```
*** REMOTE     6 copy 1073.74M,  REMOTE <- METAL           arg  2 mem  2.15 GB tm    867.06ms/   895.58ms (     0.00 GFLOPS    1.2|1.2     GB/s)
*** REMOTE:    7 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  3.22 GB tm    433.35ms/  1328.93ms (     0.00 GFLOPS    2.5|2.5     GB/s)
*** REMOTE:    8 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  4.29 GB tm    433.19ms/  1762.12ms (     0.00 GFLOPS    2.5|2.5     GB/s)
*** REMOTE:    9 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  5.37 GB tm    432.71ms/  2194.83ms (     0.00 GFLOPS    2.5|2.5     GB/s)
*** REMOTE:   10 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  6.44 GB tm    433.68ms/  2628.51ms (     0.00 GFLOPS    2.5|2.5     GB/s)
*** REMOTE:   11 copy 1073.74M, REMOTE: <- METAL           arg  2 mem  7.52 GB tm    432.91ms/  3061.42ms
```

The remaining ~430 ms per copy is almost entirely sha256 time.
This commit is contained in:
uuuvn
2025-05-19 04:20:43 +05:00
committed by GitHub
parent e55ee28b29
commit 33cf33902a

View File

@@ -116,9 +116,10 @@ class BatchRequest:
# Start an empty batch: no queued requests and no stored payloads.
def __init__(self):
# requests added through q(), kept in the order they were appended
self._q: list[RemoteRequest] = []
# payload records written by h(), keyed by the hex sha256 digest of the data
self._h: dict[str, bytes] = {}
# Pre-change version of h() (the lines this commit removes): hashes d with sha256 and
# stores raw digest + 8-byte little-endian length + data under the hex digest.
def h(self, d:bytes) -> str:
binhash = hashlib.sha256(d).digest()
# NOTE: the record (including a full copy of d) is rebuilt on every call, even if this digest is already in self._h
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
def h(self, d:bytes|memoryview) -> str:
datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
if datahash not in self._h:
self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(d))+bytes(d)
return datahash
# Append request x to the end of this batch's queue.
def q(self, x:RemoteRequest): self._q.append(x)
def serialize(self) -> bytes:
@@ -246,7 +247,7 @@ class RemoteAllocator(Allocator['RemoteDevice']):
return buffer_num
# TODO: options should not be here in any Allocator
# Queue a BufferFree request for buffer number `opaque` on the device; `options` is not used here.
def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque))
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src))))
def _copyin(self, dest:int, src:memoryview):
  """Queue a CopyIn request that fills buffer number `dest` with the contents of `src`.

  `src` is handed to the connection's request batch `h()` as-is (no bytes() copy is made
  here); the value `h()` returns is what the CopyIn request carries.
  NOTE(review): `self.dev.conn.req` is presumably a BatchRequest, confirm against the caller.
  """
  enqueue = self.dev.q
  request = CopyIn(dest, self.dev.conn.req.h(src))
  enqueue(request)
def _copyout(self, dest:memoryview, src:int):
resp = self.dev.q(CopyOut(src), wait=True)
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"