mirror of
https://github.com/vacp2p/nim-quic.git
synced 2026-01-09 22:08:09 -05:00
fix: implement RFC 9000 compliant stream closure for large data trans… (#89)
This commit is contained in:
@@ -16,6 +16,7 @@ type OpenStream* = ref object of StreamState
|
||||
frameSorter*: FrameSorter
|
||||
closeFut*: Future[string]
|
||||
writeFinSent*: bool
|
||||
readClosed*: bool
|
||||
|
||||
proc newOpenStream*(connection: Ngtcp2Connection): OpenStream =
|
||||
let incomingQ = newAsyncQueue[seq[byte]]()
|
||||
@@ -25,6 +26,7 @@ proc newOpenStream*(connection: Ngtcp2Connection): OpenStream =
|
||||
closeFut: newFuture[string](),
|
||||
frameSorter: initFrameSorter(incomingQ),
|
||||
writeFinSent: false,
|
||||
readClosed: false,
|
||||
)
|
||||
|
||||
method enter*(state: OpenStream, stream: Stream) =
|
||||
@@ -38,32 +40,53 @@ method leave*(state: OpenStream) =
|
||||
state.stream = Opt.none(Stream)
|
||||
|
||||
method read*(state: OpenStream): Future[seq[byte]] {.async.} =
|
||||
# Check if we have EOF and no more data in queue
|
||||
# RFC 9000 compliant stream reading logic
|
||||
# Priority 1: Check for immediate EOF conditions
|
||||
if state.frameSorter.isEOF() and state.incoming.len == 0:
|
||||
return @[] # Always return EOF when we hit end of stream
|
||||
|
||||
let incomingFut = state.incoming.get()
|
||||
let raceFut = await race(state.closeFut, incomingFut)
|
||||
if raceFut == incomingFut:
|
||||
result = await incomingFut
|
||||
|
||||
# If we got empty data and isEOF, this means EOF
|
||||
if result.len == 0 and state.frameSorter.isEOF():
|
||||
return @[] # Return EOF (empty array)
|
||||
|
||||
# If we got real data, return it
|
||||
if result.len > 0:
|
||||
allowMoreIncomingBytes(state.stream, state.connection, result.len.uint64)
|
||||
else:
|
||||
incomingFut.cancelSoon()
|
||||
let stream = state.stream.valueOr:
|
||||
raise newException(StreamError, "stream is closed")
|
||||
|
||||
if state.frameSorter.isEOF():
|
||||
# Remote sent FIN and no more data - check if we should switch to ClosedStream
|
||||
if state.readClosed:
|
||||
# Both remote FIN received and local read closed - switch to ClosedStream
|
||||
let stream = state.stream.valueOr:
|
||||
return @[] # Already closed
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
return @[] # Return EOF immediately per RFC 9000 "Data Read" state
|
||||
|
||||
let closeReason = await state.closeFut
|
||||
raise newException(StreamError, closeReason)
|
||||
# Priority 2: Check if local read is closed but there's still buffered data
|
||||
if state.readClosed and state.incoming.len == 0:
|
||||
# Local read closed and no buffered data - switch to ClosedStream
|
||||
let stream = state.stream.valueOr:
|
||||
return @[] # Already closed
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
return @[] # Return EOF for locally closed read
|
||||
|
||||
# Priority 3: Get data from incoming queue
|
||||
let data = await state.incoming.get()
|
||||
|
||||
# If we got real data, return it with flow control update
|
||||
if data.len > 0:
|
||||
allowMoreIncomingBytes(state.stream, state.connection, data.len.uint64)
|
||||
return data
|
||||
|
||||
# If we got empty data (len == 0), check if this is EOF
|
||||
if data.len == 0 and state.frameSorter.isEOF():
|
||||
# This is EOF - stream has been closed with FIN bit from remote
|
||||
let stream = state.stream.valueOr:
|
||||
return @[] # Already closed
|
||||
# If local read is also closed, switch to ClosedStream
|
||||
if state.readClosed:
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
return @[] # Return EOF per RFC 9000
|
||||
|
||||
# If local read is closed but we got empty data (not EOF), return EOF
|
||||
if state.readClosed:
|
||||
let stream = state.stream.valueOr:
|
||||
return @[] # Already closed
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
return @[] # Return EOF for locally closed read
|
||||
|
||||
# Empty data but no EOF - this shouldn't happen in normal operation
|
||||
# Continue reading for more data
|
||||
return await state.read()
|
||||
|
||||
method write*(state: OpenStream, bytes: seq[byte]): Future[void] =
|
||||
if state.writeFinSent:
|
||||
@@ -76,17 +99,25 @@ method write*(state: OpenStream, bytes: seq[byte]): Future[void] =
|
||||
state.connection.send(state.stream.get.id, bytes)
|
||||
|
||||
method close*(state: OpenStream) {.async.} =
|
||||
## Close both write and read sides of the stream
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
discard state.connection.send(state.stream.get.id, @[], true)
|
||||
|
||||
# Close write side by sending FIN
|
||||
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
|
||||
state.writeFinSent = true
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
|
||||
# Close read side locally
|
||||
state.readClosed = true
|
||||
|
||||
# Don't switch to ClosedStream immediately - let read() handle the transition
|
||||
# when all buffered data is consumed
|
||||
|
||||
method closeWrite*(state: OpenStream) {.async.} =
|
||||
## Close write side by sending FIN, but keep read side open
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
discard state.connection.send(state.stream.get.id, @[], true) # Sending FIN
|
||||
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
|
||||
state.writeFinSent = true
|
||||
# Note: we don't switch to ClosedStream here - read side stays open for half-close
|
||||
|
||||
@@ -103,6 +134,15 @@ method reset*(state: OpenStream) =
|
||||
method onClose*(state: OpenStream) =
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
|
||||
# Wake up pending read() operations before switching states
|
||||
# This fixes race condition when ngtcp2 calls onClose() while read() is waiting
|
||||
try:
|
||||
state.incoming.putNoWait(@[]) # Send EOF marker to wake up pending reads
|
||||
except AsyncQueueFullError:
|
||||
# Queue is full, that's fine - there's already data to process
|
||||
discard
|
||||
|
||||
stream.switch(newClosedStream(state.incoming, state.frameSorter))
|
||||
|
||||
method isClosed*(state: OpenStream): bool =
|
||||
|
||||
@@ -428,3 +428,457 @@ suite "streams":
|
||||
# Should fail - client closed write side
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
# Bidirectional stream closure tests
|
||||
asyncTest "close() should fully close bidirectional stream in both directions":
|
||||
## RFC 9000: close() should fully close the stream (both read and write)
|
||||
let simulation = simulateNetwork(client, server)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
check not clientStream.isUnidirectional
|
||||
|
||||
# Activate stream
|
||||
await clientStream.write(@[])
|
||||
let serverStream = await server.incomingStream()
|
||||
|
||||
# Client sends data and fully closes stream
|
||||
let clientData = @[1'u8, 2, 3, 4, 5]
|
||||
await clientStream.write(clientData)
|
||||
await clientStream.close() # Full close
|
||||
|
||||
# After close() client should NOT be able to write or read
|
||||
expect QuicError:
|
||||
await clientStream.write(@[6'u8, 7, 8])
|
||||
|
||||
# Server should receive data and EOF
|
||||
let receivedData = await serverStream.read()
|
||||
check receivedData == clientData
|
||||
|
||||
let eof = await serverStream.read()
|
||||
check eof.len == 0 # EOF
|
||||
|
||||
# Server can still write back (until it receives indication that client closed read)
|
||||
# But in QUIC when close() is called, it closes ALL directions
|
||||
# TODO: this depends on specific RFC 9000 implementation
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "bidirectional closeWrite() - both sides close write independently":
|
||||
## Test RFC 9000 bidirectional half-close semantics
|
||||
let simulation = simulateNetwork(client, server)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
await clientStream.write(@[])
|
||||
let serverStream = await server.incomingStream()
|
||||
|
||||
# Both send data
|
||||
let clientData = @[1'u8, 2, 3]
|
||||
let serverData = @[4'u8, 5, 6]
|
||||
|
||||
await clientStream.write(clientData)
|
||||
await serverStream.write(serverData)
|
||||
|
||||
# Both close their write side
|
||||
await clientStream.closeWrite()
|
||||
await serverStream.closeWrite()
|
||||
|
||||
# Neither can write
|
||||
expect QuicError:
|
||||
await clientStream.write(@[7'u8])
|
||||
expect QuicError:
|
||||
await serverStream.write(@[8'u8])
|
||||
|
||||
# But both can read each other's data
|
||||
let receivedByClient = await clientStream.read()
|
||||
let receivedByServer = await serverStream.read()
|
||||
|
||||
check receivedByClient == serverData
|
||||
check receivedByServer == clientData
|
||||
|
||||
# And both receive EOF
|
||||
let eofClient = await clientStream.read()
|
||||
let eofServer = await serverStream.read()
|
||||
|
||||
check eofClient.len == 0
|
||||
check eofServer.len == 0
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "close() after closeWrite() should work correctly":
|
||||
## After closeWrite() calling close() should also close the read side
|
||||
let simulation = simulateNetwork(client, server)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
await clientStream.write(@[])
|
||||
let serverStream = await server.incomingStream()
|
||||
|
||||
let clientData = @[1'u8, 2, 3]
|
||||
await clientStream.write(clientData)
|
||||
await clientStream.closeWrite() # First half-close
|
||||
|
||||
# Server sends response
|
||||
let serverData = @[4'u8, 5, 6]
|
||||
await serverStream.write(serverData)
|
||||
|
||||
# Client reads response
|
||||
let response = await clientStream.read()
|
||||
check response == serverData
|
||||
|
||||
# Now client fully closes stream
|
||||
await clientStream.close()
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "mixed close() and closeWrite() semantics":
|
||||
## One uses close(), other uses closeWrite()
|
||||
let simulation = simulateNetwork(client, server)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
await clientStream.write(@[])
|
||||
let serverStream = await server.incomingStream()
|
||||
|
||||
let clientData = @[1'u8, 2, 3]
|
||||
let serverData = @[4'u8, 5, 6]
|
||||
|
||||
await clientStream.write(clientData)
|
||||
await serverStream.write(serverData)
|
||||
|
||||
# Client does half-close
|
||||
await clientStream.closeWrite()
|
||||
|
||||
# Server does full close
|
||||
await serverStream.close()
|
||||
|
||||
# Client should receive data from server
|
||||
let receivedByClient = await clientStream.read()
|
||||
check receivedByClient == serverData
|
||||
|
||||
# And EOF
|
||||
let eofClient = await clientStream.read()
|
||||
check eofClient.len == 0
|
||||
|
||||
# Server should also receive data from client (before its close())
|
||||
let receivedByServer = await serverStream.read()
|
||||
check receivedByServer == clientData
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "stream state tracking for bidirectional closure":
|
||||
## Check that stream state is properly tracked
|
||||
let simulation = simulateNetwork(client, server)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
await clientStream.write(@[])
|
||||
let serverStream = await server.incomingStream()
|
||||
|
||||
# Initially both streams are open
|
||||
check not clientStream.isClosed()
|
||||
check not serverStream.isClosed()
|
||||
|
||||
# After closeWrite() stream should not be considered fully closed
|
||||
await clientStream.closeWrite()
|
||||
check not clientStream.isClosed() # Half-close ≠ closed
|
||||
|
||||
# After peer also closes its side, stream may be considered closed
|
||||
await serverStream.closeWrite()
|
||||
|
||||
# Read all data to reach final state
|
||||
discard await clientStream.read() # May be empty or EOF
|
||||
discard await serverStream.read() # May be empty or EOF
|
||||
|
||||
# Now both streams should be closed
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
# Large data transfer tests
|
||||
asyncTest "simple 10MB write test":
|
||||
let simulation = simulateNetwork(client, server)
|
||||
let dataSize = 10 * 1024 * 1024 # 10 MB
|
||||
var testData = newSeq[uint8](dataSize)
|
||||
|
||||
# Fill with pattern
|
||||
for i in 0 ..< dataSize:
|
||||
testData[i] = uint8(i mod 256)
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
let serverStreamFuture = server.incomingStream()
|
||||
|
||||
# Activate stream
|
||||
await clientStream.write(@[])
|
||||
|
||||
let serverStream = await serverStreamFuture
|
||||
|
||||
# Server starts reading IMMEDIATELY (parallel with client writing)
|
||||
proc serverReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await serverStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
let serverTask = serverReadData()
|
||||
|
||||
# Client writes data WHILE server is reading
|
||||
await clientStream.write(testData)
|
||||
|
||||
await clientStream.closeWrite()
|
||||
|
||||
# Wait for server to finish reading
|
||||
let receivedData = await serverTask
|
||||
|
||||
check receivedData.len == dataSize
|
||||
|
||||
await serverStream.close()
|
||||
await clientStream.close()
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "bidirectional 10MB + 10MB closeWrite test":
|
||||
let simulation = simulateNetwork(client, server)
|
||||
let dataSize = 10 * 1024 * 1024 # 10 MB each direction
|
||||
|
||||
# Client data pattern
|
||||
var clientData = newSeq[uint8](dataSize)
|
||||
for i in 0 ..< dataSize:
|
||||
clientData[i] = uint8(0xAA) # Client pattern
|
||||
|
||||
# Server data pattern
|
||||
var serverData = newSeq[uint8](dataSize)
|
||||
for i in 0 ..< dataSize:
|
||||
serverData[i] = uint8(0xBB) # Server pattern
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
let serverStreamFuture = server.incomingStream()
|
||||
|
||||
# Activate stream
|
||||
await clientStream.write(@[])
|
||||
|
||||
let serverStream = await serverStreamFuture
|
||||
|
||||
# Start parallel read operations for both directions
|
||||
proc clientReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await clientStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
proc serverReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await serverStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
# Start both read tasks
|
||||
let clientReadTask = clientReadData()
|
||||
let serverReadTask = serverReadData()
|
||||
|
||||
# Client writes 10MB and closes write side
|
||||
await clientStream.write(clientData)
|
||||
await clientStream.closeWrite()
|
||||
|
||||
# Server writes 10MB and closes write side
|
||||
await serverStream.write(serverData)
|
||||
await serverStream.closeWrite()
|
||||
|
||||
# Wait for both read operations to complete
|
||||
let clientReceivedData = await clientReadTask
|
||||
let serverReceivedData = await serverReadTask
|
||||
|
||||
# Verify data sizes
|
||||
check clientReceivedData.len == dataSize
|
||||
check serverReceivedData.len == dataSize
|
||||
|
||||
# Verify data patterns
|
||||
var clientDataValid = true
|
||||
var serverDataValid = true
|
||||
|
||||
for i in 0 ..< min(dataSize, clientReceivedData.len):
|
||||
if clientReceivedData[i] != 0xBB: # Client should receive server pattern
|
||||
clientDataValid = false
|
||||
break
|
||||
|
||||
for i in 0 ..< min(dataSize, serverReceivedData.len):
|
||||
if serverReceivedData[i] != 0xAA: # Server should receive client pattern
|
||||
serverDataValid = false
|
||||
break
|
||||
|
||||
check clientDataValid
|
||||
check serverDataValid
|
||||
|
||||
# Both sides should be able to detect EOF now
|
||||
let clientEOF = await clientStream.read()
|
||||
let serverEOF = await serverStream.read()
|
||||
check clientEOF.len == 0
|
||||
check serverEOF.len == 0
|
||||
|
||||
await serverStream.close()
|
||||
await clientStream.close()
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "mixed semantics: client closeWrite + server close with 10MB":
|
||||
let simulation = simulateNetwork(client, server)
|
||||
let dataSize = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
# Client data pattern
|
||||
var clientData = newSeq[uint8](dataSize)
|
||||
for i in 0 ..< dataSize:
|
||||
clientData[i] = uint8(0xCC) # Client pattern
|
||||
|
||||
# Server data pattern
|
||||
var serverData = newSeq[uint8](dataSize)
|
||||
for i in 0 ..< dataSize:
|
||||
serverData[i] = uint8(0xDD) # Server pattern
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
let serverStreamFuture = server.incomingStream()
|
||||
|
||||
# Activate stream
|
||||
await clientStream.write(@[])
|
||||
|
||||
let serverStream = await serverStreamFuture
|
||||
|
||||
# Start parallel read operations
|
||||
proc clientReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await clientStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
proc serverReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await serverStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
# Start both read tasks
|
||||
let clientReadTask = clientReadData()
|
||||
let serverReadTask = serverReadData()
|
||||
|
||||
# Client writes 10MB and does closeWrite() (half-close)
|
||||
await clientStream.write(clientData)
|
||||
await clientStream.closeWrite()
|
||||
|
||||
# Server writes 10MB and does close() (full-close)
|
||||
await serverStream.write(serverData)
|
||||
await serverStream.close()
|
||||
|
||||
# Wait for both read operations to complete
|
||||
let clientReceivedData = await clientReadTask
|
||||
let serverReceivedData = await serverReadTask
|
||||
|
||||
# Verify data sizes
|
||||
check clientReceivedData.len == dataSize
|
||||
check serverReceivedData.len == dataSize
|
||||
|
||||
# Verify data patterns
|
||||
var clientDataValid = true
|
||||
var serverDataValid = true
|
||||
|
||||
for i in 0 ..< min(dataSize, clientReceivedData.len):
|
||||
if clientReceivedData[i] != 0xDD: # Client should receive server pattern
|
||||
clientDataValid = false
|
||||
break
|
||||
|
||||
for i in 0 ..< min(dataSize, serverReceivedData.len):
|
||||
if serverReceivedData[i] != 0xCC: # Server should receive client pattern
|
||||
serverDataValid = false
|
||||
break
|
||||
|
||||
check clientDataValid
|
||||
check serverDataValid
|
||||
|
||||
# Client should get EOF when trying to read (server did full close)
|
||||
let clientEOF = await clientStream.read()
|
||||
check clientEOF.len == 0
|
||||
|
||||
# Client should still be able to close its read side
|
||||
await clientStream.close()
|
||||
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
asyncTest "reverse order: client starts writing first, server reads parallel":
|
||||
let simulation = simulateNetwork(client, server)
|
||||
let dataSize = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
var testData = newSeq[uint8](dataSize)
|
||||
for i in 0 ..< dataSize:
|
||||
testData[i] = uint8(0xEE) # Pattern for this test
|
||||
|
||||
let clientStream = await client.openStream()
|
||||
let serverStreamFuture = server.incomingStream()
|
||||
|
||||
# Activate stream
|
||||
await clientStream.write(@[])
|
||||
|
||||
let serverStream = await serverStreamFuture
|
||||
|
||||
# CLIENT STARTS WRITING FIRST (non-blocking)
|
||||
let clientWriteTask = proc() {.async.} =
|
||||
await clientStream.write(testData)
|
||||
await clientStream.closeWrite()
|
||||
|
||||
let clientTask = clientWriteTask()
|
||||
|
||||
# Small delay to let client start writing first
|
||||
await sleepAsync(5.milliseconds)
|
||||
|
||||
# SERVER STARTS READING IN PARALLEL (after client already started)
|
||||
proc serverReadData(): Future[seq[uint8]] {.async.} =
|
||||
var receivedData: seq[uint8]
|
||||
var chunkCount = 0
|
||||
while true:
|
||||
let chunk = await serverStream.read()
|
||||
if chunk.len == 0:
|
||||
break
|
||||
receivedData.add(chunk)
|
||||
chunkCount += 1
|
||||
return receivedData
|
||||
|
||||
let serverTask = serverReadData()
|
||||
|
||||
# Wait for both operations to complete
|
||||
await clientTask
|
||||
let receivedData = await serverTask
|
||||
|
||||
# Verify data
|
||||
check receivedData.len == dataSize
|
||||
|
||||
# Verify data pattern
|
||||
var dataValid = true
|
||||
for i in 0 ..< min(dataSize, receivedData.len):
|
||||
if receivedData[i] != 0xEE:
|
||||
dataValid = false
|
||||
break
|
||||
|
||||
check dataValid
|
||||
|
||||
# EOF check
|
||||
let eofCheck = await serverStream.read()
|
||||
check eofCheck.len == 0
|
||||
|
||||
await serverStream.close()
|
||||
await clientStream.close()
|
||||
await simulation.cancelAndWait()
|
||||
|
||||
Reference in New Issue
Block a user