chore: fix directory structure

This commit is contained in:
Gabriel Cruz
2025-05-15 09:10:52 -03:00
parent 0a59984730
commit baff57251e
8 changed files with 213 additions and 184 deletions

View File

@@ -1,9 +1,6 @@
import json, strutils, tables, times import json, strutils, tables, times
import bearssl import bearssl
import jwt/[claims, jose, utils, crypto]
from jwt/private/crypto import nil
import jwt/private/[claims, jose, utils]
type type
InvalidToken* = object of ValueError InvalidToken* = object of ValueError

144
jwt/claims.nim Normal file
View File

@@ -0,0 +1,144 @@
import json, strutils, times, tables
import utils
type
ClaimKind* = enum
ISS
SUB
NBF
EXP
AUD
IAT
JTI
GENERAL
Claim* = ref ClaimObj
ClaimObj* {.acyclic.} = object
node*: JsonNode
kind*: ClaimKind
proc newClaims*(
claims: varargs[tuple[key: string, val: Claim]]
): TableRef[string, Claim] =
result = newTable[string, Claim](claims)
proc newClaim*(k: ClaimKind, node: JsonNode): Claim =
new result
result.kind = k
result.node = node
# ISS
proc newISS*(node: JsonNode): Claim =
checkJsonNodeKind(node, JString)
return newClaim(ISS, node)
# SUB
proc newSUB*(node: JsonNode): Claim =
checkJsonNodeKind(node, JString)
return newClaim(SUB, node)
# AUD
proc newAUD*(node: JsonNode): Claim =
if node.kind != JArray and node.kind != JString:
raise newException(ValueError, "Invalid kind")
return newClaim(AUD, node)
proc newAUD*(recipients: seq[string]): Claim =
var node = newJArray()
for r in recipients:
node.add(%r)
result = newAUD(node)
proc newAUD*(recipient: string): Claim =
return newAUD(@[recipient])
proc newAUD*(recipients: varargs[string]): Claim =
return newAUD(@recipients)
# Claims that have any kind of time
proc newTimeClaim*(k: ClaimKind, j: JsonNode): Claim =
# Check that the json kind is int..
checkJsonNodeKind(j, JInt)
return newClaim(k, j)
proc newTimeClaim*(k: ClaimKind, s: string): Claim =
return newTimeClaim(k, %parseInt(s))
proc newTimeClaim*(k: ClaimKind, i: int64): Claim =
return newTimeClaim(k, %i)
# Returns the claimKeyms value as a time
proc getClaimTime*(c: Claim): Time =
result = fromUnix(c.node.num)
# NBF
proc newNBF*(s: string): Claim =
return newTimeClaim(NBF, s)
proc newNBF*(j: JsonNode): Claim =
return newTimeClaim(NBF, j)
proc newNBF*(i: int64): Claim =
return newTimeClaim(NBF, i)
# EXP
proc newEXP*(s: string): Claim =
return newTimeClaim(EXP, s)
proc newEXP*(j: JsonNode): Claim =
return newTimeClaim(EXP, j)
proc newEXP*(i: int64): Claim =
return newTimeClaim(EXP, i)
# IAT
proc newIAT*(s: string): Claim =
return newTimeClaim(IAT, s)
proc newIAT*(j: JsonNode): Claim =
return newTimeClaim(IAT, j)
proc newIAT*(i: int64): Claim =
return newTimeClaim(IAT, i)
# JTI
proc newJTI*(j: JsonNode): Claim =
assert j.kind == JString
return newClaim(JTI, j)
proc newJTI*(s: string): Claim =
return newJTI(%s)
proc toClaims*(j: JsonNode): TableRef[string, Claim] =
result = newClaims()
for claimKey, claimNode in j:
case claimKey
of "iss":
result[claimKey] = newISS(claimNode)
of "sub":
result[claimKey] = newSUB(claimNode)
of "aud":
result[claimKey] = newAUD(claimNode)
of "nbf":
result[claimKey] = newNBF(claimNode)
of "exp":
result[claimKey] = newEXP(claimNode)
of "iat":
result[claimKey] = newIAT(claimNode)
of "jti":
result[claimKey] = newJTI(claimNode)
else:
result[claimKey] = newClaim(GENERAL, claimNode)
proc `%`*(c: Claim): JsonNode =
result = c.node
proc `%`*(claims: TableRef[string, Claim]): JsonNode =
result = newJObject()
for k, v in claims:
result[k] = %v
proc toBase64*(claims: TableRef[string, Claim]): string =
let asJson = %claims
result = encodeUrlSafe($asJson)

View File

@@ -3,7 +3,7 @@ import bearssl, bearssl_pkey_decoder
# This pragma should be the same as in nim-bearssl/decls.nim # This pragma should be the same as in nim-bearssl/decls.nim
{.pragma: bearSslFunc, cdecl, gcsafe, noSideEffect, raises: [].} {.pragma: bearSslFunc, cdecl, gcsafe, noSideEffect, raises: [].}
proc bearHMAC*(digestVtable: ptr HashClass; key, d: string): seq[byte] = proc bearHMAC*(digestVtable: ptr HashClass, key, d: string): seq[byte] =
var hKey: HmacKeyContext var hKey: HmacKeyContext
var hCtx: HmacContext var hCtx: HmacContext
hmacKeyInit(hKey, digestVtable, key.cstring, key.len.uint) hmacKeyInit(hKey, digestVtable, key.cstring, key.len.uint)
@@ -16,15 +16,18 @@ proc bearHMAC*(digestVtable: ptr HashClass; key, d: string): seq[byte] =
proc invalidPemKey() = proc invalidPemKey() =
raise newException(ValueError, "Invalid PEM encoding") raise newException(ValueError, "Invalid PEM encoding")
proc pemDecoderLoop(pem: string, prc: proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}, ctx: pointer) = proc pemDecoderLoop(
pem: string,
prc: proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.},
ctx: pointer,
) =
var pemCtx: PemDecoderContext var pemCtx: PemDecoderContext
pemDecoderInit(pemCtx) pemDecoderInit(pemCtx)
var length = len(pem) var length = len(pem)
var offset = 0 var offset = 0
var inobj = false var inobj = false
while length > 0: while length > 0:
var tlen = pemDecoderPush(pemCtx, var tlen = pemDecoderPush(pemCtx, unsafeAddr pem[offset], length.uint).int
unsafeAddr pem[offset], length.uint).int
offset = offset + tlen offset = offset + tlen
length = length - tlen length = length - tlen
@@ -44,13 +47,23 @@ proc pemDecoderLoop(pem: string, prc: proc(ctx: pointer, pbytes: pointer, nbytes
proc decodeFromPem(skCtx: var SkeyDecoderContext, pem: string) = proc decodeFromPem(skCtx: var SkeyDecoderContext, pem: string) =
skeyDecoderInit(skCtx) skeyDecoderInit(skCtx)
pemDecoderLoop(pem, cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](skeyDecoderPush), addr skCtx) pemDecoderLoop(
if skeyDecoderLastError(skCtx) != 0: invalidPemKey() pem,
cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](skeyDecoderPush),
addr skCtx,
)
if skeyDecoderLastError(skCtx) != 0:
invalidPemKey()
proc decodeFromPem(pkCtx: var PkeyDecoderContext, pem: string) = proc decodeFromPem(pkCtx: var PkeyDecoderContext, pem: string) =
pkeyDecoderInit(addr pkCtx) pkeyDecoderInit(addr pkCtx)
pemDecoderLoop(pem, cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](pkeyDecoderPush), addr pkCtx) pemDecoderLoop(
if pkeyDecoderLastError(addr pkCtx) != 0: invalidPemKey() pem,
cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](pkeyDecoderPush),
addr pkCtx,
)
if pkeyDecoderLastError(addr pkCtx) != 0:
invalidPemKey()
proc calcHash(alg: ptr HashClass, data: string, output: var array[64, byte]) = proc calcHash(alg: ptr HashClass, data: string, output: var array[64, byte]) =
var ctx: array[512, byte] var ctx: array[512, byte]
@@ -61,14 +74,17 @@ proc calcHash(alg: ptr HashClass, data: string, output: var array[64, byte]) =
alg.update(pCtx, unsafeAddr data[0], data.len.uint) alg.update(pCtx, unsafeAddr data[0], data.len.uint)
alg.`out`(pCtx, addr output[0]) alg.`out`(pCtx, addr output[0])
proc bearSignRSPem*(data, key: string, alg: ptr HashClass, hashOid: cstring, hashLen: int): seq[byte] = proc bearSignRSPem*(
data, key: string, alg: ptr HashClass, hashOid: cstring, hashLen: int
): seq[byte] =
# Step 1. Extract RSA key from `key` in PEM format # Step 1. Extract RSA key from `key` in PEM format
var skCtx: SkeyDecoderContext var skCtx: SkeyDecoderContext
decodeFromPem(skCtx, key) decodeFromPem(skCtx, key)
if skeyDecoderKeyType(skCtx) != KEYTYPE_RSA: if skeyDecoderKeyType(skCtx) != KEYTYPE_RSA:
invalidPemKey() invalidPemKey()
template pk(): RsaPrivateKey = skCtx.key.rsa template pk(): RsaPrivateKey =
skCtx.key.rsa
# Step 2. Hash! # Step 2. Hash!
var digest: array[64, byte] var digest: array[64, byte]
@@ -78,16 +94,24 @@ proc bearSignRSPem*(data, key: string, alg: ptr HashClass, hashOid: cstring, has
result = newSeqUninitialized[byte](sigLen) result = newSeqUninitialized[byte](sigLen)
let s = rsaPkcs1SignGetDefault() let s = rsaPkcs1SignGetDefault()
assert(not s.isNil) assert(not s.isNil)
if s(cast[ptr byte](hashOid), addr digest[0], hashLen.uint, addr pk, addr result[0]) != 1: if s(cast[ptr byte](hashOid), addr digest[0], hashLen.uint, addr pk, addr result[0]) !=
1:
raise newException(ValueError, "Could not sign") raise newException(ValueError, "Could not sign")
proc bearVerifyRSPem*(data, key: string, sig: openarray[byte], alg: ptr HashClass, hashOid: cstring, hashLen: int): bool = proc bearVerifyRSPem*(
data, key: string,
sig: openarray[byte],
alg: ptr HashClass,
hashOid: cstring,
hashLen: int,
): bool =
# Step 1. Extract RSA key from `key` in PEM format # Step 1. Extract RSA key from `key` in PEM format
var pkCtx: PkeyDecoderContext var pkCtx: PkeyDecoderContext
decodeFromPem(pkCtx, key) decodeFromPem(pkCtx, key)
if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_RSA: if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_RSA:
invalidPemKey() invalidPemKey()
template pk(): RsaPublicKey = pkCtx.key.rsa template pk(): RsaPublicKey =
pkCtx.key.rsa
var digest: array[64, byte] var digest: array[64, byte]
calcHash(alg, data, digest) calcHash(alg, data, digest)
@@ -95,7 +119,14 @@ proc bearVerifyRSPem*(data, key: string, sig: openarray[byte], alg: ptr HashClas
let s = rsaPkcs1VrfyGetDefault() let s = rsaPkcs1VrfyGetDefault()
var digest2: array[64, byte] var digest2: array[64, byte]
if s(unsafeAddr sig[0], sig.len.uint, cast[ptr byte](hashOid), hashLen.uint, addr pk, addr digest2[0]) != 1: if s(
unsafeAddr sig[0],
sig.len.uint,
cast[ptr byte](hashOid),
hashLen.uint,
addr pk,
addr digest2[0],
) != 1:
return false return false
digest == digest2 digest == digest2
@@ -107,7 +138,8 @@ proc bearSignECPem*(data, key: string, alg: ptr HashClass): seq[byte] =
if skeyDecoderKeyType(skCtx) != KEYTYPE_EC: if skeyDecoderKeyType(skCtx) != KEYTYPE_EC:
invalidPemKey() invalidPemKey()
template pk(): EcPrivateKey = skCtx.key.ec template pk(): EcPrivateKey =
skCtx.key.ec
# Step 2. Hash! # Step 2. Hash!
var digest: array[64, byte] var digest: array[64, byte]
@@ -123,13 +155,16 @@ proc bearSignECPem*(data, key: string, alg: ptr HashClass): seq[byte] =
assert(sz <= maxSigLen) assert(sz <= maxSigLen)
result.setLen(sz) result.setLen(sz)
proc bearVerifyECPem*(data, key: string, sig: openarray[byte], alg: ptr HashClass, hashLen: int): bool = proc bearVerifyECPem*(
data, key: string, sig: openarray[byte], alg: ptr HashClass, hashLen: int
): bool =
# Step 1. Extract EC Pub key from `key` in PEM format # Step 1. Extract EC Pub key from `key` in PEM format
var pkCtx: PkeyDecoderContext var pkCtx: PkeyDecoderContext
decodeFromPem(pkCtx, key) decodeFromPem(pkCtx, key)
if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_EC: if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_EC:
invalidPemKey() invalidPemKey()
template pk(): EcPublicKey = pkCtx.key.ec template pk(): EcPublicKey =
pkCtx.key.ec
# bearssl ecdsaVrfy requires pubkey to be prepended with 0x04 byte, do it here # bearssl ecdsaVrfy requires pubkey to be prepended with 0x04 byte, do it here
assert((pk.q == addr pkCtx.key_data) and pk.qlen < sizeof(pkCtx.key_data).uint) assert((pk.q == addr pkCtx.key_data) and pk.qlen < sizeof(pkCtx.key_data).uint)
@@ -142,4 +177,5 @@ proc bearVerifyECPem*(data, key: string, sig: openarray[byte], alg: ptr HashClas
let impl = ecGetDefault() let impl = ecGetDefault()
let s = ecdsaVrfyRawGetDefault() let s = ecdsaVrfyRawGetDefault()
result = s(impl, addr digest[0], hashLen.uint, addr pk, unsafeAddr sig[0], sig.len.uint) == 1 result =
s(impl, addr digest[0], hashLen.uint, addr pk, unsafeAddr sig[0], sig.len.uint) == 1

View File

@@ -1,5 +1,4 @@
import json, strutils import json, strutils
import utils import utils
type type
@@ -23,7 +22,6 @@ proc strToSignatureAlgorithm(s: string): SignatureAlgorithm =
except ValueError: except ValueError:
raise newException(UnsupportedAlgorithm, "$# isn't supported" % s) raise newException(UnsupportedAlgorithm, "$# isn't supported" % s)
proc toHeader*(j: JsonNode): JsonNode = proc toHeader*(j: JsonNode): JsonNode =
# Check that the keys are present so we dont blow up. # Check that the keys are present so we dont blow up.
result = newJObject() result = newJObject()
@@ -43,6 +41,5 @@ proc `%`*(alg: SignatureAlgorithm): JsonNode =
let s = $alg let s = $alg
return %s return %s
proc toBase64*(h: JsonNode): string = proc toBase64*(h: JsonNode): string =
result = encodeUrlSafe($h) result = encodeUrlSafe($h)

View File

@@ -1,141 +0,0 @@
import json, strutils, times, tables
import utils
type
ClaimKind* = enum
ISS,
SUB,
NBF,
EXP,
AUD,
IAT,
JTI,
GENERAL
Claim* = ref ClaimObj
ClaimObj* {.acyclic.} = object
node*: JsonNode
kind*: ClaimKind
proc newClaims*(claims: varargs[tuple[key: string, val: Claim]]): TableRef[string, Claim] =
result = newTable[string, Claim](claims)
proc newClaim*(k: ClaimKind, node: JsonNode): Claim =
new result
result.kind = k
result.node = node
# ISS
proc newISS*(node: JsonNode): Claim =
checkJsonNodeKind(node, JString)
return newClaim(ISS, node)
# SUB
proc newSUB*(node: JsonNode): Claim =
checkJsonNodeKind(node, JString)
return newClaim(SUB, node)
# AUD
proc newAUD*(node: JsonNode): Claim =
if node.kind != JArray and node.kind != JString:
raise newException(ValueError, "Invalid kind")
return newClaim(AUD, node)
proc newAUD*(recipients: seq[string]): Claim =
var node = newJArray()
for r in recipients:
node.add(%r)
result = newAUD(node)
proc newAUD*(recipient: string): Claim = return newAUD(@[recipient])
proc newAUD*(recipients: varargs[string]): Claim = return newAUD(@recipients)
# Claims that have any kind of time
proc newTimeClaim*(k: ClaimKind, j: JsonNode): Claim =
# Check that the json kind is int..
checkJsonNodeKind(j, JInt)
return newClaim(k, j)
proc newTimeClaim*(k: ClaimKind, s: string): Claim =
return newTimeClaim(k, %parseInt(s))
proc newTimeClaim*(k: ClaimKind, i: int64): Claim =
return newTimeClaim(k, %i)
# Returns the claimKeyms value as a time
proc getClaimTime*(c: Claim): Time =
result = fromUnix(c.node.num)
# NBF
proc newNBF*(s: string): Claim = return newTimeClaim(NBF, s)
proc newNBF*(j: JsonNode): Claim = return newTimeClaim(NBF, j)
proc newNBF*(i: int64): Claim = return newTimeClaim(NBF, i)
# EXP
proc newEXP*(s: string): Claim = return newTimeClaim(EXP, s)
proc newEXP*(j: JsonNode): Claim = return newTimeClaim(EXP, j)
proc newEXP*(i: int64): Claim = return newTimeClaim(EXP, i)
# IAT
proc newIAT*(s: string): Claim = return newTimeClaim(IAT, s)
proc newIAT*(j: JsonNode): Claim = return newTimeClaim(IAT, j)
proc newIAT*(i: int64): Claim = return newTimeClaim(IAT, i)
# JTI
proc newJTI*(j: JsonNode): Claim =
assert j.kind == JString
return newClaim(JTI, j)
proc newJTI*(s: string): Claim =
return newJTI(%s)
proc toClaims*(j: JsonNode): TableRef[string, Claim] =
result = newClaims()
for claimKey, claimNode in j:
case claimKey:
of "iss":
result[claimKey] = newISS(claimNode)
of "sub":
result[claimKey] = newSUB(claimNode)
of "aud":
result[claimKey] = newAUD(claimNode)
of "nbf":
result[claimKey] = newNBF(claimNode)
of "exp":
result[claimKey] = newEXP(claimNode)
of "iat":
result[claimKey] = newIAT(claimNode)
of "jti":
result[claimKey] = newJTI(claimNode)
else:
result[claimKey] = newClaim(GENERAL, claimNode)
proc `%`*(c: Claim): JsonNode =
result = c.node
proc `%`*(claims: TableRef[string, Claim]): JsonNode =
result = newJObject()
for k, v in claims:
result[k] = %v
proc toBase64*(claims: TableRef[string, Claim]): string =
let asJson = %claims
result = encodeUrlSafe($asJson)

View File

@@ -1,14 +1,11 @@
import json, strutils import json, strutils
import base64
from base64 import nil
proc checkJsonNodeKind*(node: JsonNode, kind: JsonNodeKind) = proc checkJsonNodeKind*(node: JsonNode, kind: JsonNodeKind) =
# Check that a given JsonNode has a given kind, raise ValueError if not # Check that a given JsonNode has a given kind, raise ValueError if not
if node.kind != kind: if node.kind != kind:
raise newException(ValueError, "Invalid kind") raise newException(ValueError, "Invalid kind")
proc checkKeysExists*(node: JsonNode, keys: varargs[string]) = proc checkKeysExists*(node: JsonNode, keys: varargs[string]) =
for key in keys: for key in keys:
if not node.hasKey(key): if not node.hasKey(key):
@@ -18,7 +15,7 @@ proc encodeUrlSafe*(s: openarray[byte]): string =
when (NimMajor >= 1 and (NimMinor >= 1 or NimPatch >= 2)) or NimMajor >= 2: when (NimMajor >= 1 and (NimMinor >= 1 or NimPatch >= 2)) or NimMajor >= 2:
result = base64.encode(s) result = base64.encode(s)
else: else:
result = base64.encode(s, newLine="") result = base64.encode(s, newLine = "")
while result.endsWith("="): while result.endsWith("="):
result.setLen(result.len - 1) result.setLen(result.len - 1)
result = result.replace('+', '-').replace('/', '_') result = result.replace('+', '-').replace('/', '_')

View File

@@ -1,18 +1,18 @@
import json, unittest import json, unittest
import jwt
import ../jwt
suite "Claim ops": suite "Claim ops":
test "Create claims from JSON": test "Create claims from JSON":
let asJson = %{ let asJson =
"iss": %"jane", %{
"sub": %"john", "iss": %"jane",
"nbf": %1234, "sub": %"john",
"iat": %1234, "nbf": %1234,
"exp": %1234, "iat": %1234,
"jti": %"token-id", "exp": %1234,
"foo": %{"bar": %1} "jti": %"token-id",
} "foo": %{"bar": %1},
}
let claims = asJson.toClaims let claims = asJson.toClaims
let toJson = %claims let toJson = %claims

View File

@@ -1,6 +1,5 @@
import json, times, unittest import json, times, unittest
import jwt
import ../jwt
proc getToken(claims: JsonNode = newJObject(), header: JsonNode = newJObject()): JWT = proc getToken(claims: JsonNode = newJObject(), header: JsonNode = newJObject()): JWT =
for k, v in %*{"alg": "HS512", "typ": "JWT"}: for k, v in %*{"alg": "HS512", "typ": "JWT"}: