mirror of
https://github.com/vacp2p/nim-quic.git
synced 2026-01-09 22:08:09 -05:00
refactor(streamstate): removed optional Stream (#152)
This commit is contained in:
@@ -12,7 +12,8 @@ logScope:
|
||||
proc openStream*(
|
||||
connection: Ngtcp2Connection, unidirectional: bool
|
||||
): Stream {.raises: [QuicError].} =
|
||||
let stream = newStream(newOpenStreamState(connection))
|
||||
let stream = newStream()
|
||||
stream.switch(newOpenStreamState(connection, stream))
|
||||
let id =
|
||||
if unidirectional:
|
||||
connection.openUniStream(addr stream[])
|
||||
@@ -25,7 +26,8 @@ proc onStreamOpen(
|
||||
conn: ptr ngtcp2_conn, stream_id: int64, user_data: pointer
|
||||
): cint {.cdecl.} =
|
||||
let connection = cast[Ngtcp2Connection](user_data)
|
||||
let stream = newStream(newOpenStreamState(connection))
|
||||
let stream = newStream()
|
||||
stream.switch(newOpenStreamState(connection, stream))
|
||||
stream.id = stream_id
|
||||
connection.setStreamUserData(stream_id, addr stream[])
|
||||
connection.onIncomingStream(stream)
|
||||
|
||||
@@ -4,38 +4,36 @@ import ../native/connection
|
||||
import ./queue
|
||||
|
||||
type BaseStreamState* = ref object of StreamState
|
||||
stream*: Opt[Stream]
|
||||
queue*: StreamQueue
|
||||
connection*: Ngtcp2Connection
|
||||
stream*: Stream
|
||||
queue*: StreamQueue
|
||||
finSent*: bool
|
||||
|
||||
method onEnter*(state: BaseStreamState) {.raises: [QuicError].} =
|
||||
procCall onEnter(StreamState(state))
|
||||
|
||||
method onLeave*(state: BaseStreamState) =
|
||||
procCall onLeave(StreamState(state))
|
||||
|
||||
method expire*(state: BaseStreamState) {.raises: [].} =
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
stream.closed.fire()
|
||||
state.stream.closed.fire()
|
||||
|
||||
method write*(
|
||||
state: BaseStreamState, bytes: seq[byte]
|
||||
) {.async: (raises: [CancelledError, QuicError]).} =
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
await state.connection.send(stream.id, bytes)
|
||||
await state.connection.send(state.stream.id, bytes)
|
||||
|
||||
proc allowMoreIncomingBytes*(state: BaseStreamState, amount: uint64) =
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
state.connection.extendStreamOffset(stream.id, amount)
|
||||
state.connection.extendStreamOffset(state.stream.id, amount)
|
||||
state.connection.send()
|
||||
|
||||
proc sendFin*(state: BaseStreamState, stream: stream.Stream) =
|
||||
proc sendFin*(state: BaseStreamState) =
|
||||
if not state.finSent:
|
||||
state.finSent = true
|
||||
discard state.connection.send(stream.id, @[], true)
|
||||
discard state.connection.send(state.stream.id, @[], true)
|
||||
|
||||
proc reset*(state: BaseStreamState, stream: stream.Stream) {.raises: [QuicError].} =
|
||||
state.connection.shutdownStream(stream.id)
|
||||
proc reset*(state: BaseStreamState) {.raises: [QuicError].} =
|
||||
state.connection.shutdownStream(state.stream.id)
|
||||
|
||||
proc switch*(state: BaseStreamState, newStream: StreamState) {.raises: [QuicError].} =
|
||||
let stream = state.stream.valueOr:
|
||||
return
|
||||
stream.switch(newStream)
|
||||
proc switch*(state: BaseStreamState, nextState: StreamState) {.raises: [QuicError].} =
|
||||
state.stream.switch(nextState)
|
||||
|
||||
@@ -12,24 +12,24 @@ proc newClosedStreamState*(
|
||||
): ClosedStreamState =
|
||||
ClosedStreamState(
|
||||
connection: base.connection,
|
||||
stream: base.stream,
|
||||
queue: base.queue,
|
||||
finSent: base.finSent,
|
||||
wasReset: wasReset,
|
||||
)
|
||||
|
||||
method enter*(state: ClosedStreamState, stream: Stream) {.raises: [QuicError].} =
|
||||
procCall enter(StreamState(state), stream)
|
||||
state.stream = Opt.some(stream)
|
||||
method onEnter*(state: ClosedStreamState) {.raises: [QuicError].} =
|
||||
procCall onEnter(BaseStreamState(state))
|
||||
if state.wasReset:
|
||||
state.queue.reset()
|
||||
state.queue.close()
|
||||
if state.wasReset:
|
||||
state.reset(stream)
|
||||
state.reset()
|
||||
else:
|
||||
state.sendFin(stream)
|
||||
stream.closed.fire()
|
||||
state.sendFin()
|
||||
state.stream.closed.fire()
|
||||
|
||||
method leave*(state: ClosedStreamState) =
|
||||
method onLeave*(state: ClosedStreamState) =
|
||||
raiseAssert "ClosedStreamState state should never leave"
|
||||
|
||||
method read*(
|
||||
|
||||
@@ -9,16 +9,10 @@ import ./sendstate
|
||||
|
||||
type OpenStreamState* = ref object of BaseStreamState
|
||||
|
||||
proc newOpenStreamState*(connection: Ngtcp2Connection): OpenStreamState =
|
||||
OpenStreamState(connection: connection, queue: initStreamQueue())
|
||||
|
||||
method enter*(state: OpenStreamState, stream: Stream) =
|
||||
procCall enter(StreamState(state), stream)
|
||||
state.stream = Opt.some(stream)
|
||||
|
||||
method leave*(state: OpenStreamState) =
|
||||
procCall leave(StreamState(state))
|
||||
state.stream = Opt.none(Stream)
|
||||
proc newOpenStreamState*(
|
||||
connection: Ngtcp2Connection, stream: Stream
|
||||
): OpenStreamState =
|
||||
OpenStreamState(connection: connection, stream: stream, queue: initStreamQueue())
|
||||
|
||||
method read*(
|
||||
state: OpenStreamState
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import ../../../errors
|
||||
import ../../../basics
|
||||
import ../../stream
|
||||
import ./queue
|
||||
import ./basestate
|
||||
import ./closestate
|
||||
@@ -9,24 +8,21 @@ type ReceiveStreamState* = ref object of BaseStreamState
|
||||
|
||||
proc newReceiveStreamState*(base: BaseStreamState): ReceiveStreamState =
|
||||
ReceiveStreamState(
|
||||
connection: base.connection, queue: base.queue, finSent: base.finSent
|
||||
connection: base.connection,
|
||||
stream: base.stream,
|
||||
queue: base.queue,
|
||||
finSent: base.finSent,
|
||||
)
|
||||
|
||||
method enter*(state: ReceiveStreamState, stream: Stream) =
|
||||
procCall enter(StreamState(state), stream)
|
||||
state.stream = Opt.some(stream)
|
||||
state.sendFin(stream)
|
||||
|
||||
method leave*(state: ReceiveStreamState) =
|
||||
procCall leave(StreamState(state))
|
||||
state.stream = Opt.none(Stream)
|
||||
method onEnter*(state: ReceiveStreamState) {.raises: [QuicError].} =
|
||||
procCall onEnter(BaseStreamState(state))
|
||||
state.sendFin()
|
||||
|
||||
method read*(
|
||||
state: ReceiveStreamState
|
||||
): Future[seq[byte]] {.async: (raises: [CancelledError, QuicError]).} =
|
||||
# Check for immediate EOF conditions
|
||||
if state.queue.isEOF() and state.queue.incoming.len == 0:
|
||||
state.switch(newClosedStreamState(state))
|
||||
return @[] # Return EOF immediately per RFC 9000 "Data Read" state
|
||||
|
||||
let data = await state.queue.incoming.get()
|
||||
@@ -38,7 +34,6 @@ method read*(
|
||||
|
||||
# Empty data (len == 0) and this is EOF
|
||||
if state.queue.isEOF():
|
||||
state.switch(newClosedStreamState(state))
|
||||
return @[] # Return EOF per RFC 9000
|
||||
|
||||
# Empty data but no EOF; continue reading for more data
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import ../../../errors
|
||||
import ../../../basics
|
||||
import ../../stream
|
||||
import ./queue
|
||||
import ./basestate
|
||||
import ./closestate
|
||||
@@ -8,17 +7,17 @@ import ./closestate
|
||||
type SendStreamState* = ref object of BaseStreamState
|
||||
|
||||
proc newSendStreamState*(base: BaseStreamState): SendStreamState =
|
||||
SendStreamState(connection: base.connection, queue: base.queue, finSent: base.finSent)
|
||||
SendStreamState(
|
||||
connection: base.connection,
|
||||
stream: base.stream,
|
||||
queue: base.queue,
|
||||
finSent: base.finSent,
|
||||
)
|
||||
|
||||
method enter*(state: SendStreamState, stream: Stream) {.raises: [QuicError].} =
|
||||
procCall enter(StreamState(state), stream)
|
||||
state.stream = Opt.some(stream)
|
||||
method onEnter*(state: SendStreamState) {.raises: [QuicError].} =
|
||||
procCall onEnter(BaseStreamState(state))
|
||||
state.queue.close()
|
||||
|
||||
method leave*(state: SendStreamState) =
|
||||
procCall leave(StreamState(state))
|
||||
state.stream = Opt.none(Stream)
|
||||
|
||||
method read*(
|
||||
state: SendStreamState
|
||||
): Future[seq[byte]] {.async: (raises: [CancelledError, QuicError]).} =
|
||||
|
||||
@@ -13,11 +13,11 @@ type
|
||||
|
||||
StreamError* = object of QuicError
|
||||
|
||||
method enter*(state: StreamState, stream: Stream) {.base, raises: [QuicError].} =
|
||||
method onEnter*(state: StreamState) {.base, raises: [QuicError].} =
|
||||
doAssert not state.entered, "states are not reentrant"
|
||||
state.entered = true
|
||||
|
||||
method leave*(state: StreamState) {.base, raises: [QuicError].} =
|
||||
method onLeave*(state: StreamState) {.base, raises: [QuicError].} =
|
||||
discard
|
||||
|
||||
method read*(
|
||||
@@ -62,15 +62,16 @@ method receive*(
|
||||
method expire*(state: StreamState) {.base, raises: [].} =
|
||||
raiseAssert "override method: expire"
|
||||
|
||||
proc newStream*(state: StreamState): Stream {.raises: [QuicError].} =
|
||||
let stream = Stream(state: state, closed: newAsyncEvent(), lock: newAsyncLock())
|
||||
state.enter(stream)
|
||||
stream
|
||||
proc newStream*(): Stream =
|
||||
return Stream(closed: newAsyncEvent(), lock: newAsyncLock())
|
||||
|
||||
proc switch*(stream: Stream, newState: StreamState) {.raises: [QuicError].} =
|
||||
stream.state.leave()
|
||||
stream.state = newState
|
||||
stream.state.enter(stream)
|
||||
proc switch*(stream: Stream, nextState: StreamState) {.raises: [QuicError].} =
|
||||
let currentState = stream.state
|
||||
if not isNil(currentState):
|
||||
currentState.onLeave()
|
||||
|
||||
stream.state = nextState
|
||||
nextState.onEnter()
|
||||
|
||||
proc id*(stream: Stream): int64 =
|
||||
stream.id
|
||||
|
||||
Reference in New Issue
Block a user