mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Slightly less slow remote copyin (#10404)
bytes concat is slow, don't do it if data is already present in self._h also don't cast memoryview into bytes (copy, +100ms) before it's needed this mitigates shard copying before shrink master: ``` *** REMOTE 6 copy 1073.74M, REMOTE <- METAL arg 2 mem 2.15 GB tm 806.84ms/ 829.61ms ( 0.00 GFLOPS 1.3|1.3 GB/s) *** REMOTE: 7 copy 1073.74M, REMOTE: <- METAL arg 2 mem 3.22 GB tm 797.41ms/ 1627.02ms ( 0.00 GFLOPS 1.3|1.3 GB/s) *** REMOTE: 8 copy 1073.74M, REMOTE: <- METAL arg 2 mem 4.29 GB tm 677.89ms/ 2304.91ms ( 0.00 GFLOPS 1.6|1.6 GB/s) *** REMOTE: 9 copy 1073.74M, REMOTE: <- METAL arg 2 mem 5.37 GB tm 659.81ms/ 2964.72ms ( 0.00 GFLOPS 1.6|1.6 GB/s) *** REMOTE: 10 copy 1073.74M, REMOTE: <- METAL arg 2 mem 6.44 GB tm 679.21ms/ 3643.93ms ( 0.00 GFLOPS 1.6|1.6 GB/s) *** REMOTE: 11 copy 1073.74M, REMOTE: <- METAL arg 2 mem 7.52 GB tm 673.90ms/ 4317.83ms ``` this: ``` *** REMOTE 6 copy 1073.74M, REMOTE <- METAL arg 2 mem 2.15 GB tm 867.06ms/ 895.58ms ( 0.00 GFLOPS 1.2|1.2 GB/s) *** REMOTE: 7 copy 1073.74M, REMOTE: <- METAL arg 2 mem 3.22 GB tm 433.35ms/ 1328.93ms ( 0.00 GFLOPS 2.5|2.5 GB/s) *** REMOTE: 8 copy 1073.74M, REMOTE: <- METAL arg 2 mem 4.29 GB tm 433.19ms/ 1762.12ms ( 0.00 GFLOPS 2.5|2.5 GB/s) *** REMOTE: 9 copy 1073.74M, REMOTE: <- METAL arg 2 mem 5.37 GB tm 432.71ms/ 2194.83ms ( 0.00 GFLOPS 2.5|2.5 GB/s) *** REMOTE: 10 copy 1073.74M, REMOTE: <- METAL arg 2 mem 6.44 GB tm 433.68ms/ 2628.51ms ( 0.00 GFLOPS 2.5|2.5 GB/s) *** REMOTE: 11 copy 1073.74M, REMOTE: <- METAL arg 2 mem 7.52 GB tm 432.91ms/ 3061.42ms ``` The 430ms is basically all sha256 time.
This commit is contained in:
@@ -116,9 +116,10 @@ class BatchRequest:
|
||||
def __init__(self):
|
||||
self._q: list[RemoteRequest] = []
|
||||
self._h: dict[str, bytes] = {}
|
||||
def h(self, d:bytes) -> str:
|
||||
binhash = hashlib.sha256(d).digest()
|
||||
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
|
||||
def h(self, d:bytes|memoryview) -> str:
|
||||
datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
|
||||
if datahash not in self._h:
|
||||
self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(d))+bytes(d)
|
||||
return datahash
|
||||
def q(self, x:RemoteRequest): self._q.append(x)
|
||||
def serialize(self) -> bytes:
|
||||
@@ -246,7 +247,7 @@ class RemoteAllocator(Allocator['RemoteDevice']):
|
||||
return buffer_num
|
||||
# TODO: options should not be here in any Allocator
|
||||
def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque))
|
||||
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src))))
|
||||
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(src)))
|
||||
def _copyout(self, dest:memoryview, src:int):
|
||||
resp = self.dev.q(CopyOut(src), wait=True)
|
||||
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
|
||||
|
||||
Reference in New Issue
Block a user