refactor(streamstate): more consistent actions when entering states (#110)

This commit is contained in:
vladopajic
2025-08-27 15:50:31 +02:00
committed by GitHub
parent cae13c2d22
commit 14de00a704
6 changed files with 35 additions and 25 deletions

View File

@@ -113,6 +113,7 @@ proc reset*(fs: var FrameSorter) =
fs.buffer.clear()
fs.incoming.clear()
fs.emitPos = 0
# resetting FS should leave fs.closed (if it was set)
proc isComplete*(fs: FrameSorter): bool =
if fs.closed:

View File

@@ -8,6 +8,7 @@ type BaseStreamState* = ref object of StreamState
incoming*: AsyncQueue[seq[byte]]
connection*: Ngtcp2Connection
frameSorter*: FrameSorter
finSent*: bool
proc setUserData*(state: BaseStreamState, stream: stream.Stream) =
state.connection.setStreamUserData(stream.id, unsafeAddr state[])
@@ -22,3 +23,11 @@ proc allowMoreIncomingBytes*(state: BaseStreamState, amount: uint64) =
return
state.connection.extendStreamOffset(stream.id, amount)
state.connection.send()
proc sendFin*(state: BaseStreamState, stream: stream.Stream) =
if not state.finSent:
state.finSent = true
discard state.connection.send(stream.id, @[], true)
proc reset*(state: BaseStreamState, stream: stream.Stream) =
state.connection.shutdownStream(stream.id)

View File

@@ -15,6 +15,7 @@ proc newClosedStreamState*(
connection: base.connection,
incoming: base.incoming,
frameSorter: base.frameSorter,
finSent: base.finSent,
wasReset: wasReset,
)
@@ -22,7 +23,14 @@ method enter*(state: ClosedStreamState, stream: Stream) =
procCall enter(StreamState(state), stream)
state.stream = Opt.some(stream)
state.setUserData(stream)
if state.wasReset:
state.frameSorter.reset()
state.frameSorter.close()
if state.wasReset:
state.reset(stream)
else:
state.sendFin(stream)
stream.closed.fire()
method leave*(state: ClosedStreamState) =
doAssert false, "ClosedStreamState state should never leave"

View File

@@ -45,19 +45,18 @@ method read*(state: OpenStreamState): Future[seq[byte]] {.async.} =
return await state.read()
method write*(state: OpenStreamState, bytes: seq[byte]): Future[void] =
state.connection.send(state.stream.get.id, bytes)
method close*(state: OpenStreamState) {.async.} =
# Bidirectional streams, close() only closes the send side of the stream.
let stream = state.stream.valueOr:
return
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
state.connection.send(stream.id, bytes)
method close*(state: OpenStreamState) {.async.} =
let stream = state.stream.valueOr:
return
stream.switch(newReceiveStreamState(state))
method closeWrite*(state: OpenStreamState) {.async.} =
let stream = state.stream.valueOr:
return
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
stream.switch(newReceiveStreamState(state))
method closeRead*(state: OpenStreamState) {.async.} =
@@ -79,14 +78,9 @@ method receive*(state: OpenStreamState, offset: uint64, bytes: seq[byte], isFin:
if state.frameSorter.isComplete():
let stream = state.stream.valueOr:
return
stream.closed.fire()
stream.switch(newClosedStreamState(state))
method reset*(state: OpenStreamState) =
let stream = state.stream.valueOr:
return
state.connection.shutdownStream(stream.id)
stream.closed.fire()
state.frameSorter.reset()
stream.switch(newClosedStreamState(state, wasReset = true))

View File

@@ -10,13 +10,17 @@ type ReceiveStreamState* = ref object of BaseStreamState
proc newReceiveStreamState*(base: BaseStreamState): ReceiveStreamState =
ReceiveStreamState(
connection: base.connection, incoming: base.incoming, frameSorter: base.frameSorter
connection: base.connection,
incoming: base.incoming,
frameSorter: base.frameSorter,
finSent: base.finSent,
)
method enter*(state: ReceiveStreamState, stream: Stream) =
procCall enter(StreamState(state), stream)
state.stream = Opt.some(stream)
state.setUserData(stream)
state.sendFin(stream)
method leave*(state: ReceiveStreamState) =
procCall leave(StreamState(state))
@@ -82,14 +86,9 @@ method receive*(
if state.frameSorter.isComplete():
let stream = state.stream.valueOr:
return
stream.closed.fire()
stream.switch(newClosedStreamState(state))
method reset*(state: ReceiveStreamState) =
let stream = state.stream.valueOr:
return
state.connection.shutdownStream(stream.id)
stream.closed.fire()
state.frameSorter.reset()
stream.switch(newClosedStreamState(state, wasReset = true))

View File

@@ -10,7 +10,10 @@ type SendStreamState* = ref object of BaseStreamState
proc newSendStreamState*(base: BaseStreamState): SendStreamState =
SendStreamState(
connection: base.connection, incoming: base.incoming, frameSorter: base.frameSorter
connection: base.connection,
incoming: base.incoming,
frameSorter: base.frameSorter,
finSent: base.finSent,
)
method enter*(state: SendStreamState, stream: Stream) =
@@ -27,18 +30,18 @@ method read*(state: SendStreamState): Future[seq[byte]] {.async.} =
raise newException(ClosedStreamError, "read side is closed")
method write*(state: SendStreamState, bytes: seq[byte]) {.async.} =
await state.connection.send(state.stream.get.id, bytes)
let stream = state.stream.valueOr:
return
await state.connection.send(stream.id, bytes)
method close*(state: SendStreamState) {.async.} =
let stream = state.stream.valueOr:
return
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
stream.switch(newClosedStreamState(state))
method closeWrite*(state: SendStreamState) {.async.} =
let stream = state.stream.valueOr:
return
discard state.connection.send(state.stream.get.id, @[], true) # Send FIN
stream.switch(newClosedStreamState(state))
method closeRead*(stream: SendStreamState) {.async.} =
@@ -58,8 +61,4 @@ method receive*(state: SendStreamState, offset: uint64, bytes: seq[byte], isFin:
method reset*(state: SendStreamState) =
let stream = state.stream.valueOr:
return
state.connection.shutdownStream(stream.id)
stream.closed.fire()
state.frameSorter.reset()
stream.switch(newClosedStreamState(state, wasReset = true))