feat: dtls connection using mbedtls (#10)

* feat: dtls connection using mbedtls

* refactor: change according to the stun protocol rework

* chore: rename init proc into new

* docs: adds object field comments

* chore: split dtls.nim into two files & renaming

* chore: remove useless code

* chore: remove TODOs as they were addressed with a Stun refactorization

* fix: oversight on dtls.new

* feat: add dtls test

* chore: added license & used pragma on testdtls

* fix: remove usage of deprecated TrackerCounter

* fix: trackers counter

* fix:
- add windows linking library
- make stun stop asynchronous (causing issue on macos)
- store private key and certificate

* chore: renaming test

* docs: update DtlsConn comment

* fix: remove code duplicate

* chore: update comment

* chore: remove duplication mbedtls initialization code in accept/connect and un-expose mbedtls context

* feat: add exception management to dtls_transport

* fix: check address family before handshake

* fix: exhaustive case

* fix: do not create dtlsConn if the address family is not IP

* chore: remove entropy from MbedTLSCtx

* chore: remove asyncspawn of cleanupdtlsconn

* chore: ctx is no longer public

* test: add a test with more than 2 nodes

* chore: started is now useful

* chore: update Dtls.stop

* chore: removed unecessary todos

* docs: add comments on DtlsConn.read and getters

* feat: add tracker for dtls connection and transport

* chore: privatize local and remote certificate

* style: use nph

* fix: remove laddr from dtls_conn (not used)

* style: sort imports

* chore: clean Dtls.stop

* fix: remote address is no longer exposed

* fix: raddr change oversight

* chore: change `verify` name

* chore: changed `sendFuture: Future[void]` into `dataToSend: seq[byte]`

* chore: avoid sequence copy

* chore: change assert message

---------

Co-authored-by: diegomrsantos <diegomrsantos@gmail.com>
This commit is contained in:
Ludovic Chenut
2024-08-13 15:54:03 +02:00
committed by GitHub
parent 81b91e32a9
commit d75e328e77
9 changed files with 612 additions and 14 deletions

View File

@@ -10,3 +10,4 @@
{.used.}
import teststun
import testdtls

83
tests/testdtls.nim Normal file
View File

@@ -0,0 +1,83 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.used.}
import chronos
import ../webrtc/udp_transport
import ../webrtc/stun/stun_transport
import ../webrtc/dtls/dtls_transport
import ../webrtc/dtls/dtls_connection
import ./asyncunit
suite "DTLS":
teardown:
checkLeaks()
asyncTest "Two DTLS nodes connecting to each other, then sending/receiving data":
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
conn1Fut = dtls1.accept()
conn2 = await dtls2.connect(localAddr1)
conn1 = await conn1Fut
await conn1.write(@[1'u8, 2, 3, 4])
let seq1 = await conn2.read()
check seq1 == @[1'u8, 2, 3, 4]
await conn2.write(@[5'u8, 6, 7, 8])
let seq2 = await conn1.read()
check seq2 == @[5'u8, 6, 7, 8]
await allFutures(conn1.close(), conn2.close())
await allFutures(dtls1.stop(), dtls2.stop())
await allFutures(stun1.stop(), stun2.stop())
await allFutures(udp1.close(), udp2.close())
asyncTest "Two DTLS nodes connecting to the same DTLS server, sending/receiving data":
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
localAddr3 = initTAddress("127.0.0.1:6666")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
udp3 = UdpTransport.new(localAddr3)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
stun3 = Stun.new(udp3)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
dtls3 = Dtls.new(stun3)
servConn1Fut = dtls1.accept()
servConn2Fut = dtls1.accept()
clientConn1 = await dtls2.connect(localAddr1)
clientConn2 = await dtls3.connect(localAddr1)
servConn1 = await servConn1Fut
servConn2 = await servConn2Fut
await servConn1.write(@[1'u8, 2, 3, 4])
await servConn2.write(@[5'u8, 6, 7, 8])
await clientConn1.write(@[9'u8, 10, 11, 12])
await clientConn2.write(@[13'u8, 14, 15, 16])
check:
(await clientConn1.read()) == @[1'u8, 2, 3, 4]
(await clientConn2.read()) == @[5'u8, 6, 7, 8]
(await servConn1.read()) == @[9'u8, 10, 11, 12]
(await servConn2.read()) == @[13'u8, 14, 15, 16]
await allFutures(servConn1.close(), servConn2.close())
await allFutures(clientConn1.close(), clientConn2.close())
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
await allFutures(udp1.close(), udp2.close(), udp3.close())

View File

@@ -55,7 +55,7 @@ suite "Stun message encoding/decoding":
decoded == msg
messageIntegrity.attributeType == AttrMessageIntegrity.uint16
fingerprint.attributeType == AttrFingerprint.uint16
conn.close()
await conn.close()
await udp.close()
asyncTest "Get BindingResponse from BindingRequest + encode & decode":
@@ -82,7 +82,7 @@ suite "Stun message encoding/decoding":
bindingResponse == decoded
messageIntegrity.attributeType == AttrMessageIntegrity.uint16
fingerprint.attributeType == AttrFingerprint.uint16
conn.close()
await conn.close()
await udp.close()
suite "Stun checkForError":
@@ -114,7 +114,7 @@ suite "Stun checkForError":
check:
errorMissUsername.getAttribute(ErrorCode).get().getErrorCode() == ECBadRequest
conn.close()
await conn.close()
await udp.close()
asyncTest "checkForError: UsernameChecker returns false":
@@ -136,5 +136,5 @@ suite "Stun checkForError":
check:
error.getAttribute(ErrorCode).get().getErrorCode() == ECUnauthorized
conn.close()
await conn.close()
await udp.close()

View File

@@ -17,12 +17,15 @@ let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js)
let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler
let verbose = getEnv("V", "") notin ["", "0"]
let cfg =
var cfg =
" --styleCheck:usages --styleCheck:error" &
(if verbose: "" else: " --verbosity:0 --hints:off") &
" --skipParentCfg --skipUserCfg -f" &
" --threads:on --opt:speed"
when defined(windows):
cfg = cfg & " --clib:ws2_32"
import hashes
proc runTest(filename: string) =

View File

@@ -0,0 +1,280 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, chronicles
import
mbedtls/[
ssl, ssl_cookie, ssl_cache, pk, md, ctr_drbg, rsa, x509, x509_crt, bignum, error,
net_sockets, timing,
]
import ../errors, ../stun/[stun_connection], ./dtls_utils
logScope:
topics = "webrtc dtls_conn"
const DtlsConnTracker* = "webrtc.dtls.conn"
type
MbedTLSCtx = object
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
cookie: mbedtls_ssl_cookie_ctx
cache: mbedtls_ssl_cache_context
timer: mbedtls_timing_delay_context
pkey: mbedtls_pk_context
srvcert: mbedtls_x509_crt
ctr_drbg: mbedtls_ctr_drbg_context
DtlsConn* = ref object
# DtlsConn is a Dtls connection receiving and sending data using
# the underlying Stun Connection
conn*: StunConn # The wrapper protocol Stun Connection
raddr: TransportAddress # Remote address
dataRecv: seq[byte] # data received which will be read by SCTP
dataToSend: seq[byte]
# This sequence is set by synchronous Mbed-TLS `dtlsSend` callbacks
# and sent, if set, once the synchronous functions ends
# Close connection management
closed: bool
closeEvent: AsyncEvent
# Local and Remote certificate, needed by wrapped protocol DataChannel
# and by libp2p
localCert: seq[byte]
remoteCert: seq[byte]
# Mbed-TLS contexts
ctx: MbedTLSCtx
proc getRemoteCertificateCallback(
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
): cint {.cdecl.} =
# getRemoteCertificateCallback is the procedure called by mbedtls when
# receiving the remote certificate. It's usually used to verify the validity
# of the certificate, we don't do it. We use this procedure to store the remot
# certificate as it's mandatory to have it for the Prologue of the Noise
# protocol, aswell as the localCertificate.
var self = cast[DtlsConn](ctx)
let cert = pcert[]
self.remoteCert = newSeq[byte](cert.raw.len)
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
return 0
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
# we store the message to be sent and it after the end of the function
# (see write or dtlsHanshake for example).
var self = cast[DtlsConn](ctx)
self.dataToSend = newSeq[byte](len)
if len > 0:
copyMem(addr self.dataToSend[0], buf, len)
trace "dtls send", len
result = len.cint
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
# As we cannot asynchronously await for data to be received, we use a data received
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ
copyMem(buf, addr self.dataRecv[0], self.dataRecv.len())
result = self.dataRecv.len().cint
self.dataRecv = @[]
trace "dtls receive", len, result
proc new*(T: type DtlsConn, conn: StunConn): T =
## Initialize a Dtls Connection
##
var self = T(conn: conn)
self.raddr = conn.raddr
self.closed = false
self.closeEvent = newAsyncEvent()
return self
proc dtlsConnInit(self: DtlsConn) =
mb_ssl_init(self.ctx.ssl)
mb_ssl_config_init(self.ctx.config)
mb_ssl_conf_rng(self.ctx.config, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
mb_ssl_conf_read_timeout(self.ctx.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(self.ctx.config, self.ctx.srvcert.next, nil)
mb_ssl_set_timer_cb(self.ctx.ssl, self.ctx.timer)
mb_ssl_set_verify(self.ctx.ssl, getRemoteCertificateCallback, self)
mb_ssl_set_bio(self.ctx.ssl, cast[pointer](self), dtlsSend, dtlsRecv, nil)
proc acceptInit*(
self: DtlsConn,
ctr_drbg: mbedtls_ctr_drbg_context,
pkey: mbedtls_pk_context,
srvcert: mbedtls_x509_crt,
localCert: seq[byte],
) =
try:
self.ctx.ctr_drbg = ctr_drbg
self.ctx.pkey = pkey
self.ctx.srvcert = srvcert
self.localCert = localCert
self.dtlsConnInit()
mb_ssl_cookie_init(self.ctx.cookie)
mb_ssl_cache_init(self.ctx.cache)
mb_ssl_config_defaults(
self.ctx.config, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT,
)
mb_ssl_conf_own_cert(self.ctx.config, self.ctx.srvcert, self.ctx.pkey)
mb_ssl_cookie_setup(self.ctx.cookie, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
mb_ssl_conf_dtls_cookies(self.ctx.config, addr self.ctx.cookie)
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
mb_ssl_session_reset(self.ctx.ssl)
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - Accept initialization: " & exc.msg, exc)
proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
try:
self.ctx.ctr_drbg = ctr_drbg
self.ctx.pkey = self.ctx.ctr_drbg.generateKey()
self.ctx.srvcert = self.ctx.ctr_drbg.generateCertificate(self.ctx.pkey)
self.localCert = newSeq[byte](self.ctx.srvcert.raw.len)
copyMem(addr self.localCert[0], self.ctx.srvcert.raw.p, self.ctx.srvcert.raw.len)
self.dtlsConnInit()
mb_ssl_config_defaults(
self.ctx.config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT,
)
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
## Wait for the Dtls Connection to be closed
##
await self.closeEvent.wait()
proc dtlsHandshake*(
self: DtlsConn, isServer: bool
) {.async: (raises: [CancelledError, WebRtcError]).} =
var shouldRead = isServer
try:
while self.ctx.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
if isServer:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v6)
else:
raiseAssert("Remote address must be IPv4 or IPv6")
let (data, _) = await self.conn.read()
self.dataRecv = data
self.dataToSend = @[]
let res = mb_ssl_handshake_step(self.ctx.ssl)
if self.dataToSend.len() > 0:
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ:
shouldRead = true
continue
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
mb_ssl_session_reset(self.ctx.ssl)
shouldRead = isServer
continue
elif res != 0:
raise newException(WebRtcError, "DTLS - " & $(res.mbedtls_high_level_strerr()))
except MbedTLSError as exc:
trace "Dtls handshake error", errorMsg = exc.msg
raise newException(WebRtcError, "DTLS - Handshake error", exc)
trackCounter(DtlsConnTracker)
proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
## Close a Dtls Connection
##
if self.closed:
debug "Try to close an already closed DtlsConn"
return
self.closed = true
self.dataToSend = @[]
let x = mbedtls_ssl_close_notify(addr self.ctx.ssl)
if self.dataToSend.len() > 0:
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
untrackCounter(DtlsConnTracker)
self.closeEvent.fire()
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
## Write a message using mbedtls_ssl_write
##
# Mbed-TLS will wrap the message properly and call `dtlsSend` callback.
# `dtlsSend` will store the message to be sent on the higher Stun connection.
if self.closed:
debug "Try to write on an already closed DtlsConn"
return
var buf = msg
try:
self.dataToSend = @[]
let write = mb_ssl_write(self.ctx.ssl, buf)
if self.dataToSend.len() > 0:
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
except MbedTLSError as exc:
trace "Dtls write error", errorMsg = exc.msg
raise exc
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
## Read the next received message by StunConn.
## Uncypher it using mbedtls_ssl_read.
##
# First we read the StunConn using the asynchronous `StunConn.read` procedure.
# When we received data, we stored it in `DtlsConn.dataRecv` and call `dtlsRecv`
# callback using mbedtls in order to decypher it.
if self.closed:
debug "Try to read on an already closed DtlsConn"
return
var res = newSeq[byte](8192)
while true:
let (data, _) = await self.conn.read()
self.dataRecv = data
let length =
mbedtls_ssl_read(addr self.ctx.ssl, cast[ptr byte](addr res[0]), res.len().uint)
if length == MBEDTLS_ERR_SSL_WANT_READ:
continue
if length < 0:
raise newException(
WebRtcError, "DTLS - " & $(length.cint.mbedtls_high_level_strerr())
)
res.setLen(length)
return res
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
## Get the remote certificate
##
conn.remoteCert
proc localCertificate*(conn: DtlsConn): seq[byte] =
## Get the local certificate
##
conn.localCert
proc remoteAddress*(conn: DtlsConn): TransportAddress =
## Get the remote address
##
conn.raddr

View File

@@ -0,0 +1,146 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import deques, tables, sequtils
import
chronos,
chronicles,
mbedtls/[
ssl, ssl_cookie, ssl_cache, pk, md, entropy, ctr_drbg, rsa, x509, x509_crt, bignum,
error, net_sockets, timing,
]
import
./[dtls_utils, dtls_connection], ../errors, ../stun/[stun_connection, stun_transport]
logScope:
topics = "webrtc dtls"
# Implementation of a DTLS client and a DTLS Server by using the Mbed-TLS library.
# Multiple things here are unintuitive partly because of the callbacks
# used by Mbed-TLS which cannot be async.
const DtlsTransportTracker* = "webrtc.dtls.transport"
type
DtlsConnAndCleanup = object
connection: DtlsConn
cleanup: Future[void].Raising([])
Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConnAndCleanup]
transport: Stun
laddr: TransportAddress
started: bool
ctr_drbg: mbedtls_ctr_drbg_context
entropy: mbedtls_entropy_context
serverPrivKey: mbedtls_pk_context
serverCert: mbedtls_x509_crt
localCert: seq[byte]
proc new*(T: type Dtls, transport: Stun): T =
var self = T(
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
transport: transport,
laddr: transport.laddr,
started: true,
)
mb_ctr_drbg_init(self.ctr_drbg)
mb_entropy_init(self.entropy)
mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0)
self.serverPrivKey = self.ctr_drbg.generateKey()
self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey)
self.localCert = newSeq[byte](self.serverCert.raw.len)
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
trackCounter(DtlsTransportTracker)
return self
proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
## Stop the Dtls transport. Stop every opened connections.
##
if not self.started:
warn "Already stopped"
return
self.started = false
let
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
await noCancel allFutures(allCloses)
await noCancel allFutures(allCleanup)
untrackCounter(DtlsTransportTracker)
proc localCertificate*(self: Dtls): seq[byte] =
## Local certificate getter
self.localCert
proc localAddress*(self: Dtls): TransportAddress =
self.laddr
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
except CancelledError as exc:
discard
self.connections.del(conn.remoteAddress())
proc accept*(
self: Dtls
): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} =
## Accept a Dtls Connection
##
if not self.started:
raise newException(WebRtcError, "DTLS - Dtls transport not started")
var res: DtlsConn
while true:
let
stunConn = await self.transport.accept()
raddr = stunConn.raddr
if raddr.family == AddressFamily.IPv4 or raddr.family == AddressFamily.IPv6:
try:
res = DtlsConn.new(stunConn)
res.acceptInit(
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
)
await res.dtlsHandshake(true)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
break
except WebRtcError as exc:
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
self.connections.del(raddr)
return res
proc connect*(
self: Dtls, raddr: TransportAddress
): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} =
## Connect to a remote address, creating a Dtls Connection
##
if not self.started:
raise newException(WebRtcError, "DTLS - Dtls transport not started")
if raddr.family != AddressFamily.IPv4 and raddr.family != AddressFamily.IPv6:
raise newException(WebRtcError, "DTLS - Can only connect to IP address")
var res = DtlsConn.new(await self.transport.connect(raddr))
res.connectInit(self.ctr_drbg)
try:
await res.dtlsHandshake(false)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
except WebRtcError as exc:
trace "Handshake fails", raddr, error = exc.msg
self.connections.del(raddr)
raise exc
return res

View File

@@ -0,0 +1,81 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/times
import ../errors
import mbedtls/[pk, rsa, ctr_drbg, x509_crt, bignum, md, error]
# This sequence is used for debugging.
const mb_ssl_states* =
@[
"MBEDTLS_SSL_HELLO_REQUEST", "MBEDTLS_SSL_CLIENT_HELLO", "MBEDTLS_SSL_SERVER_HELLO",
"MBEDTLS_SSL_SERVER_CERTIFICATE", "MBEDTLS_SSL_SERVER_KEY_EXCHANGE",
"MBEDTLS_SSL_CERTIFICATE_REQUEST", "MBEDTLS_SSL_SERVER_HELLO_DONE",
"MBEDTLS_SSL_CLIENT_CERTIFICATE", "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE",
"MBEDTLS_SSL_CERTIFICATE_VERIFY", "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC",
"MBEDTLS_SSL_CLIENT_FINISHED", "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC",
"MBEDTLS_SSL_SERVER_FINISHED", "MBEDTLS_SSL_FLUSH_BUFFERS",
"MBEDTLS_SSL_HANDSHAKE_WRAPUP", "MBEDTLS_SSL_NEW_SESSION_TICKET",
"MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", "MBEDTLS_SSL_HELLO_RETRY_REQUEST",
"MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", "MBEDTLS_SSL_END_OF_EARLY_DATA",
"MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY",
"MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED",
"MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO",
"MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO",
"MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO",
"MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", "MBEDTLS_SSL_HANDSHAKE_OVER",
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET",
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH",
]
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
var res: mbedtls_pk_context
mb_pk_init(res)
discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))
mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537)
let x = mb_pk_rsa(res)
res
template generateCertificate*(
random: mbedtls_ctr_drbg_context, issuer_key: mbedtls_pk_context
): mbedtls_x509_crt =
let
name = "C=FR,O=Status,CN=webrtc"
time_format =
try:
initTimeFormat("YYYYMMddHHmmss")
except TimeFormatParseError as exc:
raise newException(WebRtcError, "DTLS - " & exc.msg, exc)
time_from = times.now().format(time_format)
time_to = (times.now() + times.years(1)).format(time_format)
var write_cert: mbedtls_x509write_cert
var serial_mpi: mbedtls_mpi
mb_x509write_crt_init(write_cert)
mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256)
mb_x509write_crt_set_subject_key(write_cert, issuer_key)
mb_x509write_crt_set_issuer_key(write_cert, issuer_key)
mb_x509write_crt_set_subject_name(write_cert, name)
mb_x509write_crt_set_issuer_name(write_cert, name)
mb_x509write_crt_set_validity(write_cert, time_from, time_to)
mb_x509write_crt_set_basic_constraints(write_cert, 0, -1)
mb_x509write_crt_set_subject_key_identifier(write_cert)
mb_x509write_crt_set_authority_key_identifier(write_cert)
mb_mpi_init(serial_mpi)
let serial_hex = mb_mpi_read_string(serial_mpi, 16)
mb_x509write_crt_set_serial(write_cert, serial_mpi)
let buf =
try:
mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random)
except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - " & exc.msg, exc)
var res: mbedtls_x509_crt
mb_x509_crt_parse(res, buf)
res

View File

@@ -220,14 +220,14 @@ proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
##
await self.closeEvent.wait()
proc close*(self: StunConn) =
proc close*(self: StunConn) {.async: (raises: []).} =
## Close a Stun Connection
##
if self.closed:
debug "Try to close an already closed StunConn"
return
await self.handlesFut.cancelAndWait()
self.closeEvent.fire()
self.handlesFut.cancelSoon()
self.closed = true
untrackCounter(StunConnectionTracker)

View File

@@ -7,7 +7,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import tables
import tables, sequtils
import chronos, chronicles, bearssl
import stun_connection, stun_message, ../udp_transport
@@ -22,8 +22,9 @@ type
Stun* = ref object
connections: Table[TransportAddress, StunConn]
pendingConn: AsyncQueue[StunConn]
readingLoop: Future[void]
readingLoop: Future[void].Raising([CancelledError])
udp: UdpTransport
laddr*: TransportAddress
usernameProvider: StunUsernameProvider
usernameChecker: StunUsernameChecker
@@ -84,12 +85,14 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
else:
await stunConn.dataRecv.addLast(buf)
proc stop(self: Stun) =
proc stop*(self: Stun) {.async: (raises: []).} =
## Stop the Stun transport and close all the connections
##
for conn in self.connections.values():
conn.close()
self.readingLoop.cancelSoon()
try:
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
except CancelledError as exc:
discard
await self.readingLoop.cancelAndWait()
untrackCounter(StunTransportTracker)
proc defaultUsernameProvider(): string = ""
@@ -108,12 +111,13 @@ proc new*(
##
var self = T(
udp: udp,
laddr: udp.laddr,
usernameProvider: usernameProvider,
usernameChecker: usernameChecker,
passwordProvider: passwordProvider,
rng: rng
)
self.readingLoop = stunReadLoop()
self.readingLoop = self.stunReadLoop()
self.pendingConn = newAsyncQueue[StunConn](StunMaxPendingConnections)
trackCounter(StunTransportTracker)
return self