mirror of
https://github.com/vacp2p/nim-webrtc.git
synced 2026-01-09 13:57:59 -05:00
feat: dtls connection using mbedtls (#10)
* feat: dtls connection using mbedtls * refactor: change according to the stun protocol rework * chore: rename init proc into new * docs: adds object field comments * chore: split dtls.nim into two files & renaming * chore: remove useless code * chore: remove TODOs as they were addressed with a Stun refactorization * fix: oversight on dtls.new * feat: add dtls test * chore: added license & used pragma on testdtls * fix: remove usage of deprecated TrackerCounter * fix: trackers counter * fix: - add windows linking library - make stun stop asynchronous (causing issue on macos) - store private key and certificate * chore: renaming test * docs: update DtlsConn comment * fix: remove code duplicate * chore: update comment * chore: remove duplication mbedtls initialization code in accept/connect and un-expose mbedtls context * feat: add exception management to dtls_transport * fix: check address family before handshake * fix: exhaustive case * fix: do not create dtlsConn if the address family is not IP * chore: remove entropy from MbedTLSCtx * chore: remove asyncspawn of cleanupdtlsconn * chore: ctx is no longer public * test: add a test with more than 2 nodes * chore: started is now useful * chore: update Dtls.stop * chore: removed unecessary todos * docs: add comments on DtlsConn.read and getters * feat: add tracker for dtls connection and transport * chore: privatize local and remote certificate * style: use nph * fix: remove laddr from dtls_conn (not used) * style: sort imports * chore: clean Dtls.stop * fix: remote address is no longer exposed * fix: raddr change oversight * chore: change `verify` name * chore: changed `sendFuture: Future[void]` into `dataToSend: seq[byte]` * chore: avoid sequence copy * chore: change assert message --------- Co-authored-by: diegomrsantos <diegomrsantos@gmail.com>
This commit is contained in:
@@ -10,3 +10,4 @@
|
||||
{.used.}
|
||||
|
||||
import teststun
|
||||
import testdtls
|
||||
|
||||
83
tests/testdtls.nim
Normal file
83
tests/testdtls.nim
Normal file
@@ -0,0 +1,83 @@
|
||||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.used.}
|
||||
|
||||
import chronos
|
||||
import ../webrtc/udp_transport
|
||||
import ../webrtc/stun/stun_transport
|
||||
import ../webrtc/dtls/dtls_transport
|
||||
import ../webrtc/dtls/dtls_connection
|
||||
import ./asyncunit
|
||||
|
||||
suite "DTLS":
|
||||
teardown:
|
||||
checkLeaks()
|
||||
|
||||
asyncTest "Two DTLS nodes connecting to each other, then sending/receiving data":
|
||||
let
|
||||
localAddr1 = initTAddress("127.0.0.1:4444")
|
||||
localAddr2 = initTAddress("127.0.0.1:5555")
|
||||
udp1 = UdpTransport.new(localAddr1)
|
||||
udp2 = UdpTransport.new(localAddr2)
|
||||
stun1 = Stun.new(udp1)
|
||||
stun2 = Stun.new(udp2)
|
||||
dtls1 = Dtls.new(stun1)
|
||||
dtls2 = Dtls.new(stun2)
|
||||
conn1Fut = dtls1.accept()
|
||||
conn2 = await dtls2.connect(localAddr1)
|
||||
conn1 = await conn1Fut
|
||||
|
||||
await conn1.write(@[1'u8, 2, 3, 4])
|
||||
let seq1 = await conn2.read()
|
||||
check seq1 == @[1'u8, 2, 3, 4]
|
||||
|
||||
await conn2.write(@[5'u8, 6, 7, 8])
|
||||
let seq2 = await conn1.read()
|
||||
check seq2 == @[5'u8, 6, 7, 8]
|
||||
await allFutures(conn1.close(), conn2.close())
|
||||
await allFutures(dtls1.stop(), dtls2.stop())
|
||||
await allFutures(stun1.stop(), stun2.stop())
|
||||
await allFutures(udp1.close(), udp2.close())
|
||||
|
||||
asyncTest "Two DTLS nodes connecting to the same DTLS server, sending/receiving data":
|
||||
let
|
||||
localAddr1 = initTAddress("127.0.0.1:4444")
|
||||
localAddr2 = initTAddress("127.0.0.1:5555")
|
||||
localAddr3 = initTAddress("127.0.0.1:6666")
|
||||
udp1 = UdpTransport.new(localAddr1)
|
||||
udp2 = UdpTransport.new(localAddr2)
|
||||
udp3 = UdpTransport.new(localAddr3)
|
||||
stun1 = Stun.new(udp1)
|
||||
stun2 = Stun.new(udp2)
|
||||
stun3 = Stun.new(udp3)
|
||||
dtls1 = Dtls.new(stun1)
|
||||
dtls2 = Dtls.new(stun2)
|
||||
dtls3 = Dtls.new(stun3)
|
||||
servConn1Fut = dtls1.accept()
|
||||
servConn2Fut = dtls1.accept()
|
||||
clientConn1 = await dtls2.connect(localAddr1)
|
||||
clientConn2 = await dtls3.connect(localAddr1)
|
||||
servConn1 = await servConn1Fut
|
||||
servConn2 = await servConn2Fut
|
||||
|
||||
await servConn1.write(@[1'u8, 2, 3, 4])
|
||||
await servConn2.write(@[5'u8, 6, 7, 8])
|
||||
await clientConn1.write(@[9'u8, 10, 11, 12])
|
||||
await clientConn2.write(@[13'u8, 14, 15, 16])
|
||||
check:
|
||||
(await clientConn1.read()) == @[1'u8, 2, 3, 4]
|
||||
(await clientConn2.read()) == @[5'u8, 6, 7, 8]
|
||||
(await servConn1.read()) == @[9'u8, 10, 11, 12]
|
||||
(await servConn2.read()) == @[13'u8, 14, 15, 16]
|
||||
await allFutures(servConn1.close(), servConn2.close())
|
||||
await allFutures(clientConn1.close(), clientConn2.close())
|
||||
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
|
||||
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
|
||||
await allFutures(udp1.close(), udp2.close(), udp3.close())
|
||||
@@ -55,7 +55,7 @@ suite "Stun message encoding/decoding":
|
||||
decoded == msg
|
||||
messageIntegrity.attributeType == AttrMessageIntegrity.uint16
|
||||
fingerprint.attributeType == AttrFingerprint.uint16
|
||||
conn.close()
|
||||
await conn.close()
|
||||
await udp.close()
|
||||
|
||||
asyncTest "Get BindingResponse from BindingRequest + encode & decode":
|
||||
@@ -82,7 +82,7 @@ suite "Stun message encoding/decoding":
|
||||
bindingResponse == decoded
|
||||
messageIntegrity.attributeType == AttrMessageIntegrity.uint16
|
||||
fingerprint.attributeType == AttrFingerprint.uint16
|
||||
conn.close()
|
||||
await conn.close()
|
||||
await udp.close()
|
||||
|
||||
suite "Stun checkForError":
|
||||
@@ -114,7 +114,7 @@ suite "Stun checkForError":
|
||||
|
||||
check:
|
||||
errorMissUsername.getAttribute(ErrorCode).get().getErrorCode() == ECBadRequest
|
||||
conn.close()
|
||||
await conn.close()
|
||||
await udp.close()
|
||||
|
||||
asyncTest "checkForError: UsernameChecker returns false":
|
||||
@@ -136,5 +136,5 @@ suite "Stun checkForError":
|
||||
|
||||
check:
|
||||
error.getAttribute(ErrorCode).get().getErrorCode() == ECUnauthorized
|
||||
conn.close()
|
||||
await conn.close()
|
||||
await udp.close()
|
||||
|
||||
@@ -17,12 +17,15 @@ let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js)
|
||||
let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler
|
||||
let verbose = getEnv("V", "") notin ["", "0"]
|
||||
|
||||
let cfg =
|
||||
var cfg =
|
||||
" --styleCheck:usages --styleCheck:error" &
|
||||
(if verbose: "" else: " --verbosity:0 --hints:off") &
|
||||
" --skipParentCfg --skipUserCfg -f" &
|
||||
" --threads:on --opt:speed"
|
||||
|
||||
when defined(windows):
|
||||
cfg = cfg & " --clib:ws2_32"
|
||||
|
||||
import hashes
|
||||
|
||||
proc runTest(filename: string) =
|
||||
|
||||
280
webrtc/dtls/dtls_connection.nim
Normal file
280
webrtc/dtls/dtls_connection.nim
Normal file
@@ -0,0 +1,280 @@
|
||||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
import
|
||||
mbedtls/[
|
||||
ssl, ssl_cookie, ssl_cache, pk, md, ctr_drbg, rsa, x509, x509_crt, bignum, error,
|
||||
net_sockets, timing,
|
||||
]
|
||||
import ../errors, ../stun/[stun_connection], ./dtls_utils
|
||||
|
||||
logScope:
|
||||
topics = "webrtc dtls_conn"
|
||||
|
||||
const DtlsConnTracker* = "webrtc.dtls.conn"
|
||||
|
||||
type
|
||||
MbedTLSCtx = object
|
||||
ssl: mbedtls_ssl_context
|
||||
config: mbedtls_ssl_config
|
||||
cookie: mbedtls_ssl_cookie_ctx
|
||||
cache: mbedtls_ssl_cache_context
|
||||
timer: mbedtls_timing_delay_context
|
||||
pkey: mbedtls_pk_context
|
||||
srvcert: mbedtls_x509_crt
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
|
||||
DtlsConn* = ref object
|
||||
# DtlsConn is a Dtls connection receiving and sending data using
|
||||
# the underlying Stun Connection
|
||||
conn*: StunConn # The wrapper protocol Stun Connection
|
||||
raddr: TransportAddress # Remote address
|
||||
dataRecv: seq[byte] # data received which will be read by SCTP
|
||||
dataToSend: seq[byte]
|
||||
# This sequence is set by synchronous Mbed-TLS `dtlsSend` callbacks
|
||||
# and sent, if set, once the synchronous functions ends
|
||||
|
||||
# Close connection management
|
||||
closed: bool
|
||||
closeEvent: AsyncEvent
|
||||
|
||||
# Local and Remote certificate, needed by wrapped protocol DataChannel
|
||||
# and by libp2p
|
||||
localCert: seq[byte]
|
||||
remoteCert: seq[byte]
|
||||
|
||||
# Mbed-TLS contexts
|
||||
ctx: MbedTLSCtx
|
||||
|
||||
proc getRemoteCertificateCallback(
|
||||
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
|
||||
): cint {.cdecl.} =
|
||||
# getRemoteCertificateCallback is the procedure called by mbedtls when
|
||||
# receiving the remote certificate. It's usually used to verify the validity
|
||||
# of the certificate, we don't do it. We use this procedure to store the remot
|
||||
# certificate as it's mandatory to have it for the Prologue of the Noise
|
||||
# protocol, aswell as the localCertificate.
|
||||
var self = cast[DtlsConn](ctx)
|
||||
let cert = pcert[]
|
||||
|
||||
self.remoteCert = newSeq[byte](cert.raw.len)
|
||||
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
|
||||
return 0
|
||||
|
||||
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
|
||||
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
|
||||
# we store the message to be sent and it after the end of the function
|
||||
# (see write or dtlsHanshake for example).
|
||||
var self = cast[DtlsConn](ctx)
|
||||
self.dataToSend = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr self.dataToSend[0], buf, len)
|
||||
trace "dtls send", len
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
|
||||
# As we cannot asynchronously await for data to be received, we use a data received
|
||||
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
|
||||
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
|
||||
let self = cast[DtlsConn](ctx)
|
||||
if self.dataRecv.len() == 0:
|
||||
return MBEDTLS_ERR_SSL_WANT_READ
|
||||
|
||||
copyMem(buf, addr self.dataRecv[0], self.dataRecv.len())
|
||||
result = self.dataRecv.len().cint
|
||||
self.dataRecv = @[]
|
||||
trace "dtls receive", len, result
|
||||
|
||||
proc new*(T: type DtlsConn, conn: StunConn): T =
|
||||
## Initialize a Dtls Connection
|
||||
##
|
||||
var self = T(conn: conn)
|
||||
self.raddr = conn.raddr
|
||||
self.closed = false
|
||||
self.closeEvent = newAsyncEvent()
|
||||
return self
|
||||
|
||||
proc dtlsConnInit(self: DtlsConn) =
|
||||
mb_ssl_init(self.ctx.ssl)
|
||||
mb_ssl_config_init(self.ctx.config)
|
||||
mb_ssl_conf_rng(self.ctx.config, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(self.ctx.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(self.ctx.config, self.ctx.srvcert.next, nil)
|
||||
mb_ssl_set_timer_cb(self.ctx.ssl, self.ctx.timer)
|
||||
mb_ssl_set_verify(self.ctx.ssl, getRemoteCertificateCallback, self)
|
||||
mb_ssl_set_bio(self.ctx.ssl, cast[pointer](self), dtlsSend, dtlsRecv, nil)
|
||||
|
||||
proc acceptInit*(
|
||||
self: DtlsConn,
|
||||
ctr_drbg: mbedtls_ctr_drbg_context,
|
||||
pkey: mbedtls_pk_context,
|
||||
srvcert: mbedtls_x509_crt,
|
||||
localCert: seq[byte],
|
||||
) =
|
||||
try:
|
||||
self.ctx.ctr_drbg = ctr_drbg
|
||||
self.ctx.pkey = pkey
|
||||
self.ctx.srvcert = srvcert
|
||||
self.localCert = localCert
|
||||
|
||||
self.dtlsConnInit()
|
||||
mb_ssl_cookie_init(self.ctx.cookie)
|
||||
mb_ssl_cache_init(self.ctx.cache)
|
||||
mb_ssl_config_defaults(
|
||||
self.ctx.config, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT,
|
||||
)
|
||||
mb_ssl_conf_own_cert(self.ctx.config, self.ctx.srvcert, self.ctx.pkey)
|
||||
mb_ssl_cookie_setup(self.ctx.cookie, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
|
||||
mb_ssl_conf_dtls_cookies(self.ctx.config, addr self.ctx.cookie)
|
||||
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
|
||||
mb_ssl_session_reset(self.ctx.ssl)
|
||||
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
except MbedTLSError as exc:
|
||||
raise newException(WebRtcError, "DTLS - Accept initialization: " & exc.msg, exc)
|
||||
|
||||
proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
|
||||
try:
|
||||
self.ctx.ctr_drbg = ctr_drbg
|
||||
self.ctx.pkey = self.ctx.ctr_drbg.generateKey()
|
||||
self.ctx.srvcert = self.ctx.ctr_drbg.generateCertificate(self.ctx.pkey)
|
||||
self.localCert = newSeq[byte](self.ctx.srvcert.raw.len)
|
||||
copyMem(addr self.localCert[0], self.ctx.srvcert.raw.p, self.ctx.srvcert.raw.len)
|
||||
|
||||
self.dtlsConnInit()
|
||||
mb_ssl_config_defaults(
|
||||
self.ctx.config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT,
|
||||
)
|
||||
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
|
||||
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
except MbedTLSError as exc:
|
||||
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
|
||||
|
||||
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
|
||||
## Wait for the Dtls Connection to be closed
|
||||
##
|
||||
await self.closeEvent.wait()
|
||||
|
||||
proc dtlsHandshake*(
|
||||
self: DtlsConn, isServer: bool
|
||||
) {.async: (raises: [CancelledError, WebRtcError]).} =
|
||||
var shouldRead = isServer
|
||||
try:
|
||||
while self.ctx.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
if isServer:
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raiseAssert("Remote address must be IPv4 or IPv6")
|
||||
let (data, _) = await self.conn.read()
|
||||
self.dataRecv = data
|
||||
self.dataToSend = @[]
|
||||
let res = mb_ssl_handshake_step(self.ctx.ssl)
|
||||
if self.dataToSend.len() > 0:
|
||||
await self.conn.write(self.dataToSend)
|
||||
self.dataToSend = @[]
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
|
||||
mb_ssl_session_reset(self.ctx.ssl)
|
||||
shouldRead = isServer
|
||||
continue
|
||||
elif res != 0:
|
||||
raise newException(WebRtcError, "DTLS - " & $(res.mbedtls_high_level_strerr()))
|
||||
except MbedTLSError as exc:
|
||||
trace "Dtls handshake error", errorMsg = exc.msg
|
||||
raise newException(WebRtcError, "DTLS - Handshake error", exc)
|
||||
trackCounter(DtlsConnTracker)
|
||||
|
||||
proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
|
||||
## Close a Dtls Connection
|
||||
##
|
||||
if self.closed:
|
||||
debug "Try to close an already closed DtlsConn"
|
||||
return
|
||||
self.closed = true
|
||||
self.dataToSend = @[]
|
||||
let x = mbedtls_ssl_close_notify(addr self.ctx.ssl)
|
||||
if self.dataToSend.len() > 0:
|
||||
await self.conn.write(self.dataToSend)
|
||||
self.dataToSend = @[]
|
||||
untrackCounter(DtlsConnTracker)
|
||||
self.closeEvent.fire()
|
||||
|
||||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
## Write a message using mbedtls_ssl_write
|
||||
##
|
||||
# Mbed-TLS will wrap the message properly and call `dtlsSend` callback.
|
||||
# `dtlsSend` will store the message to be sent on the higher Stun connection.
|
||||
if self.closed:
|
||||
debug "Try to write on an already closed DtlsConn"
|
||||
return
|
||||
var buf = msg
|
||||
try:
|
||||
self.dataToSend = @[]
|
||||
let write = mb_ssl_write(self.ctx.ssl, buf)
|
||||
if self.dataToSend.len() > 0:
|
||||
await self.conn.write(self.dataToSend)
|
||||
self.dataToSend = @[]
|
||||
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
|
||||
except MbedTLSError as exc:
|
||||
trace "Dtls write error", errorMsg = exc.msg
|
||||
raise exc
|
||||
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
## Read the next received message by StunConn.
|
||||
## Uncypher it using mbedtls_ssl_read.
|
||||
##
|
||||
# First we read the StunConn using the asynchronous `StunConn.read` procedure.
|
||||
# When we received data, we stored it in `DtlsConn.dataRecv` and call `dtlsRecv`
|
||||
# callback using mbedtls in order to decypher it.
|
||||
if self.closed:
|
||||
debug "Try to read on an already closed DtlsConn"
|
||||
return
|
||||
var res = newSeq[byte](8192)
|
||||
while true:
|
||||
let (data, _) = await self.conn.read()
|
||||
self.dataRecv = data
|
||||
let length =
|
||||
mbedtls_ssl_read(addr self.ctx.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
if length == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
continue
|
||||
if length < 0:
|
||||
raise newException(
|
||||
WebRtcError, "DTLS - " & $(length.cint.mbedtls_high_level_strerr())
|
||||
)
|
||||
res.setLen(length)
|
||||
return res
|
||||
|
||||
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
|
||||
## Get the remote certificate
|
||||
##
|
||||
conn.remoteCert
|
||||
|
||||
proc localCertificate*(conn: DtlsConn): seq[byte] =
|
||||
## Get the local certificate
|
||||
##
|
||||
conn.localCert
|
||||
|
||||
proc remoteAddress*(conn: DtlsConn): TransportAddress =
|
||||
## Get the remote address
|
||||
##
|
||||
conn.raddr
|
||||
146
webrtc/dtls/dtls_transport.nim
Normal file
146
webrtc/dtls/dtls_transport.nim
Normal file
@@ -0,0 +1,146 @@
|
||||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import deques, tables, sequtils
|
||||
import
|
||||
chronos,
|
||||
chronicles,
|
||||
mbedtls/[
|
||||
ssl, ssl_cookie, ssl_cache, pk, md, entropy, ctr_drbg, rsa, x509, x509_crt, bignum,
|
||||
error, net_sockets, timing,
|
||||
]
|
||||
import
|
||||
./[dtls_utils, dtls_connection], ../errors, ../stun/[stun_connection, stun_transport]
|
||||
|
||||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# Implementation of a DTLS client and a DTLS Server by using the Mbed-TLS library.
|
||||
# Multiple things here are unintuitive partly because of the callbacks
|
||||
# used by Mbed-TLS which cannot be async.
|
||||
|
||||
const DtlsTransportTracker* = "webrtc.dtls.transport"
|
||||
|
||||
type
|
||||
DtlsConnAndCleanup = object
|
||||
connection: DtlsConn
|
||||
cleanup: Future[void].Raising([])
|
||||
|
||||
Dtls* = ref object of RootObj
|
||||
connections: Table[TransportAddress, DtlsConnAndCleanup]
|
||||
transport: Stun
|
||||
laddr: TransportAddress
|
||||
started: bool
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
entropy: mbedtls_entropy_context
|
||||
|
||||
serverPrivKey: mbedtls_pk_context
|
||||
serverCert: mbedtls_x509_crt
|
||||
localCert: seq[byte]
|
||||
|
||||
proc new*(T: type Dtls, transport: Stun): T =
|
||||
var self = T(
|
||||
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
|
||||
transport: transport,
|
||||
laddr: transport.laddr,
|
||||
started: true,
|
||||
)
|
||||
|
||||
mb_ctr_drbg_init(self.ctr_drbg)
|
||||
mb_entropy_init(self.entropy)
|
||||
mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0)
|
||||
|
||||
self.serverPrivKey = self.ctr_drbg.generateKey()
|
||||
self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey)
|
||||
self.localCert = newSeq[byte](self.serverCert.raw.len)
|
||||
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
|
||||
trackCounter(DtlsTransportTracker)
|
||||
return self
|
||||
|
||||
proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
|
||||
## Stop the Dtls transport. Stop every opened connections.
|
||||
##
|
||||
if not self.started:
|
||||
warn "Already stopped"
|
||||
return
|
||||
|
||||
self.started = false
|
||||
let
|
||||
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
|
||||
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
|
||||
await noCancel allFutures(allCloses)
|
||||
await noCancel allFutures(allCleanup)
|
||||
untrackCounter(DtlsTransportTracker)
|
||||
|
||||
proc localCertificate*(self: Dtls): seq[byte] =
|
||||
## Local certificate getter
|
||||
self.localCert
|
||||
|
||||
proc localAddress*(self: Dtls): TransportAddress =
|
||||
self.laddr
|
||||
|
||||
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
|
||||
# Waiting for a connection to be closed to remove it from the table
|
||||
try:
|
||||
await conn.join()
|
||||
except CancelledError as exc:
|
||||
discard
|
||||
|
||||
self.connections.del(conn.remoteAddress())
|
||||
|
||||
proc accept*(
|
||||
self: Dtls
|
||||
): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} =
|
||||
## Accept a Dtls Connection
|
||||
##
|
||||
if not self.started:
|
||||
raise newException(WebRtcError, "DTLS - Dtls transport not started")
|
||||
var res: DtlsConn
|
||||
|
||||
while true:
|
||||
let
|
||||
stunConn = await self.transport.accept()
|
||||
raddr = stunConn.raddr
|
||||
if raddr.family == AddressFamily.IPv4 or raddr.family == AddressFamily.IPv6:
|
||||
try:
|
||||
res = DtlsConn.new(stunConn)
|
||||
res.acceptInit(
|
||||
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
|
||||
)
|
||||
await res.dtlsHandshake(true)
|
||||
self.connections[raddr] =
|
||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
||||
break
|
||||
except WebRtcError as exc:
|
||||
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
return res
|
||||
|
||||
proc connect*(
|
||||
self: Dtls, raddr: TransportAddress
|
||||
): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} =
|
||||
## Connect to a remote address, creating a Dtls Connection
|
||||
##
|
||||
if not self.started:
|
||||
raise newException(WebRtcError, "DTLS - Dtls transport not started")
|
||||
if raddr.family != AddressFamily.IPv4 and raddr.family != AddressFamily.IPv6:
|
||||
raise newException(WebRtcError, "DTLS - Can only connect to IP address")
|
||||
var res = DtlsConn.new(await self.transport.connect(raddr))
|
||||
res.connectInit(self.ctr_drbg)
|
||||
|
||||
try:
|
||||
await res.dtlsHandshake(false)
|
||||
self.connections[raddr] =
|
||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
||||
except WebRtcError as exc:
|
||||
trace "Handshake fails", raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
raise exc
|
||||
|
||||
return res
|
||||
81
webrtc/dtls/dtls_utils.nim
Normal file
81
webrtc/dtls/dtls_utils.nim
Normal file
@@ -0,0 +1,81 @@
|
||||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import std/times
|
||||
import ../errors
|
||||
|
||||
import mbedtls/[pk, rsa, ctr_drbg, x509_crt, bignum, md, error]
|
||||
|
||||
# This sequence is used for debugging.
|
||||
const mb_ssl_states* =
|
||||
@[
|
||||
"MBEDTLS_SSL_HELLO_REQUEST", "MBEDTLS_SSL_CLIENT_HELLO", "MBEDTLS_SSL_SERVER_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CERTIFICATE", "MBEDTLS_SSL_SERVER_KEY_EXCHANGE",
|
||||
"MBEDTLS_SSL_CERTIFICATE_REQUEST", "MBEDTLS_SSL_SERVER_HELLO_DONE",
|
||||
"MBEDTLS_SSL_CLIENT_CERTIFICATE", "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE",
|
||||
"MBEDTLS_SSL_CERTIFICATE_VERIFY", "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC",
|
||||
"MBEDTLS_SSL_CLIENT_FINISHED", "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC",
|
||||
"MBEDTLS_SSL_SERVER_FINISHED", "MBEDTLS_SSL_FLUSH_BUFFERS",
|
||||
"MBEDTLS_SSL_HANDSHAKE_WRAPUP", "MBEDTLS_SSL_NEW_SESSION_TICKET",
|
||||
"MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", "MBEDTLS_SSL_HELLO_RETRY_REQUEST",
|
||||
"MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", "MBEDTLS_SSL_END_OF_EARLY_DATA",
|
||||
"MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", "MBEDTLS_SSL_HANDSHAKE_OVER",
|
||||
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET",
|
||||
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH",
|
||||
]
|
||||
|
||||
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
|
||||
var res: mbedtls_pk_context
|
||||
mb_pk_init(res)
|
||||
discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))
|
||||
mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537)
|
||||
let x = mb_pk_rsa(res)
|
||||
res
|
||||
|
||||
template generateCertificate*(
|
||||
random: mbedtls_ctr_drbg_context, issuer_key: mbedtls_pk_context
|
||||
): mbedtls_x509_crt =
|
||||
let
|
||||
name = "C=FR,O=Status,CN=webrtc"
|
||||
time_format =
|
||||
try:
|
||||
initTimeFormat("YYYYMMddHHmmss")
|
||||
except TimeFormatParseError as exc:
|
||||
raise newException(WebRtcError, "DTLS - " & exc.msg, exc)
|
||||
time_from = times.now().format(time_format)
|
||||
time_to = (times.now() + times.years(1)).format(time_format)
|
||||
|
||||
var write_cert: mbedtls_x509write_cert
|
||||
var serial_mpi: mbedtls_mpi
|
||||
mb_x509write_crt_init(write_cert)
|
||||
mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256)
|
||||
mb_x509write_crt_set_subject_key(write_cert, issuer_key)
|
||||
mb_x509write_crt_set_issuer_key(write_cert, issuer_key)
|
||||
mb_x509write_crt_set_subject_name(write_cert, name)
|
||||
mb_x509write_crt_set_issuer_name(write_cert, name)
|
||||
mb_x509write_crt_set_validity(write_cert, time_from, time_to)
|
||||
mb_x509write_crt_set_basic_constraints(write_cert, 0, -1)
|
||||
mb_x509write_crt_set_subject_key_identifier(write_cert)
|
||||
mb_x509write_crt_set_authority_key_identifier(write_cert)
|
||||
mb_mpi_init(serial_mpi)
|
||||
let serial_hex = mb_mpi_read_string(serial_mpi, 16)
|
||||
mb_x509write_crt_set_serial(write_cert, serial_mpi)
|
||||
let buf =
|
||||
try:
|
||||
mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random)
|
||||
except MbedTLSError as exc:
|
||||
raise newException(WebRtcError, "DTLS - " & exc.msg, exc)
|
||||
var res: mbedtls_x509_crt
|
||||
mb_x509_crt_parse(res, buf)
|
||||
res
|
||||
@@ -220,14 +220,14 @@ proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
|
||||
##
|
||||
await self.closeEvent.wait()
|
||||
|
||||
proc close*(self: StunConn) =
|
||||
proc close*(self: StunConn) {.async: (raises: []).} =
|
||||
## Close a Stun Connection
|
||||
##
|
||||
if self.closed:
|
||||
debug "Try to close an already closed StunConn"
|
||||
return
|
||||
await self.handlesFut.cancelAndWait()
|
||||
self.closeEvent.fire()
|
||||
self.handlesFut.cancelSoon()
|
||||
self.closed = true
|
||||
untrackCounter(StunConnectionTracker)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import tables
|
||||
import tables, sequtils
|
||||
import chronos, chronicles, bearssl
|
||||
import stun_connection, stun_message, ../udp_transport
|
||||
|
||||
@@ -22,8 +22,9 @@ type
|
||||
Stun* = ref object
|
||||
connections: Table[TransportAddress, StunConn]
|
||||
pendingConn: AsyncQueue[StunConn]
|
||||
readingLoop: Future[void]
|
||||
readingLoop: Future[void].Raising([CancelledError])
|
||||
udp: UdpTransport
|
||||
laddr*: TransportAddress
|
||||
|
||||
usernameProvider: StunUsernameProvider
|
||||
usernameChecker: StunUsernameChecker
|
||||
@@ -84,12 +85,14 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
||||
else:
|
||||
await stunConn.dataRecv.addLast(buf)
|
||||
|
||||
proc stop(self: Stun) =
|
||||
proc stop*(self: Stun) {.async: (raises: []).} =
|
||||
## Stop the Stun transport and close all the connections
|
||||
##
|
||||
for conn in self.connections.values():
|
||||
conn.close()
|
||||
self.readingLoop.cancelSoon()
|
||||
try:
|
||||
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
|
||||
except CancelledError as exc:
|
||||
discard
|
||||
await self.readingLoop.cancelAndWait()
|
||||
untrackCounter(StunTransportTracker)
|
||||
|
||||
proc defaultUsernameProvider(): string = ""
|
||||
@@ -108,12 +111,13 @@ proc new*(
|
||||
##
|
||||
var self = T(
|
||||
udp: udp,
|
||||
laddr: udp.laddr,
|
||||
usernameProvider: usernameProvider,
|
||||
usernameChecker: usernameChecker,
|
||||
passwordProvider: passwordProvider,
|
||||
rng: rng
|
||||
)
|
||||
self.readingLoop = stunReadLoop()
|
||||
self.readingLoop = self.stunReadLoop()
|
||||
self.pendingConn = newAsyncQueue[StunConn](StunMaxPendingConnections)
|
||||
trackCounter(StunTransportTracker)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user