diff --git a/jwt.nim b/jwt.nim index bbca7d8..461cece 100644 --- a/jwt.nim +++ b/jwt.nim @@ -1,9 +1,6 @@ import json, strutils, tables, times import bearssl - -from jwt/private/crypto import nil - -import jwt/private/[claims, jose, utils] +import jwt/[claims, jose, utils, crypto] type InvalidToken* = object of ValueError diff --git a/jwt/claims.nim b/jwt/claims.nim new file mode 100644 index 0000000..3e3e329 --- /dev/null +++ b/jwt/claims.nim @@ -0,0 +1,144 @@ +import json, strutils, times, tables +import utils + +type + ClaimKind* = enum + ISS + SUB + NBF + EXP + AUD + IAT + JTI + GENERAL + + Claim* = ref ClaimObj + ClaimObj* {.acyclic.} = object + node*: JsonNode + kind*: ClaimKind + +proc newClaims*( + claims: varargs[tuple[key: string, val: Claim]] +): TableRef[string, Claim] = + result = newTable[string, Claim](claims) + +proc newClaim*(k: ClaimKind, node: JsonNode): Claim = + new result + result.kind = k + result.node = node + +# ISS +proc newISS*(node: JsonNode): Claim = + checkJsonNodeKind(node, JString) + return newClaim(ISS, node) + +# SUB +proc newSUB*(node: JsonNode): Claim = + checkJsonNodeKind(node, JString) + return newClaim(SUB, node) + +# AUD +proc newAUD*(node: JsonNode): Claim = + if node.kind != JArray and node.kind != JString: + raise newException(ValueError, "Invalid kind") + return newClaim(AUD, node) + +proc newAUD*(recipients: seq[string]): Claim = + var node = newJArray() + for r in recipients: + node.add(%r) + result = newAUD(node) + +proc newAUD*(recipient: string): Claim = + return newAUD(@[recipient]) + +proc newAUD*(recipients: varargs[string]): Claim = + return newAUD(@recipients) + +# Claims that have any kind of time +proc newTimeClaim*(k: ClaimKind, j: JsonNode): Claim = + # Check that the json kind is int.. + checkJsonNodeKind(j, JInt) + return newClaim(k, j) + +proc newTimeClaim*(k: ClaimKind, s: string): Claim = + return newTimeClaim(k, %parseInt(s)) + +proc newTimeClaim*(k: ClaimKind, i: int64): Claim = + return newTimeClaim(k, %i) + +# Returns the claimKeyms value as a time +proc getClaimTime*(c: Claim): Time = + result = fromUnix(c.node.num) + +# NBF +proc newNBF*(s: string): Claim = + return newTimeClaim(NBF, s) + +proc newNBF*(j: JsonNode): Claim = + return newTimeClaim(NBF, j) + +proc newNBF*(i: int64): Claim = + return newTimeClaim(NBF, i) + +# EXP +proc newEXP*(s: string): Claim = + return newTimeClaim(EXP, s) + +proc newEXP*(j: JsonNode): Claim = + return newTimeClaim(EXP, j) + +proc newEXP*(i: int64): Claim = + return newTimeClaim(EXP, i) + +# IAT +proc newIAT*(s: string): Claim = + return newTimeClaim(IAT, s) + +proc newIAT*(j: JsonNode): Claim = + return newTimeClaim(IAT, j) + +proc newIAT*(i: int64): Claim = + return newTimeClaim(IAT, i) + +# JTI +proc newJTI*(j: JsonNode): Claim = + assert j.kind == JString + return newClaim(JTI, j) + +proc newJTI*(s: string): Claim = + return newJTI(%s) + +proc toClaims*(j: JsonNode): TableRef[string, Claim] = + result = newClaims() + + for claimKey, claimNode in j: + case claimKey + of "iss": + result[claimKey] = newISS(claimNode) + of "sub": + result[claimKey] = newSUB(claimNode) + of "aud": + result[claimKey] = newAUD(claimNode) + of "nbf": + result[claimKey] = newNBF(claimNode) + of "exp": + result[claimKey] = newEXP(claimNode) + of "iat": + result[claimKey] = newIAT(claimNode) + of "jti": + result[claimKey] = newJTI(claimNode) + else: + result[claimKey] = newClaim(GENERAL, claimNode) + +proc `%`*(c: Claim): JsonNode = + result = c.node + +proc `%`*(claims: TableRef[string, Claim]): JsonNode = + result = newJObject() + for k, v in claims: + result[k] = %v + +proc toBase64*(claims: TableRef[string, Claim]): string = + let asJson = %claims + result = encodeUrlSafe($asJson) diff --git a/jwt/private/crypto.nim b/jwt/crypto.nim similarity index 70% rename from jwt/private/crypto.nim rename to jwt/crypto.nim index fd2392f..76f336b 100644 --- a/jwt/private/crypto.nim +++ b/jwt/crypto.nim @@ -3,7 +3,7 @@ import bearssl, bearssl_pkey_decoder # This pragma should be the same as in nim-bearssl/decls.nim {.pragma: bearSslFunc, cdecl, gcsafe, noSideEffect, raises: [].} -proc bearHMAC*(digestVtable: ptr HashClass; key, d: string): seq[byte] = +proc bearHMAC*(digestVtable: ptr HashClass, key, d: string): seq[byte] = var hKey: HmacKeyContext var hCtx: HmacContext hmacKeyInit(hKey, digestVtable, key.cstring, key.len.uint) @@ -16,15 +16,18 @@ proc bearHMAC*(digestVtable: ptr HashClass; key, d: string): seq[byte] = proc invalidPemKey() = raise newException(ValueError, "Invalid PEM encoding") -proc pemDecoderLoop(pem: string, prc: proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}, ctx: pointer) = +proc pemDecoderLoop( + pem: string, + prc: proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}, + ctx: pointer, +) = var pemCtx: PemDecoderContext pemDecoderInit(pemCtx) var length = len(pem) var offset = 0 var inobj = false while length > 0: - var tlen = pemDecoderPush(pemCtx, - unsafeAddr pem[offset], length.uint).int + var tlen = pemDecoderPush(pemCtx, unsafeAddr pem[offset], length.uint).int offset = offset + tlen length = length - tlen @@ -44,13 +47,23 @@ proc pemDecoderLoop(pem: string, prc: proc(ctx: pointer, pbytes: pointer, nbytes proc decodeFromPem(skCtx: var SkeyDecoderContext, pem: string) = skeyDecoderInit(skCtx) - pemDecoderLoop(pem, cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](skeyDecoderPush), addr skCtx) - if skeyDecoderLastError(skCtx) != 0: invalidPemKey() + pemDecoderLoop( + pem, + cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](skeyDecoderPush), + addr skCtx, + ) + if skeyDecoderLastError(skCtx) != 0: + invalidPemKey() proc decodeFromPem(pkCtx: var PkeyDecoderContext, pem: string) = pkeyDecoderInit(addr pkCtx) - pemDecoderLoop(pem, cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](pkeyDecoderPush), addr pkCtx) - if pkeyDecoderLastError(addr pkCtx) != 0: invalidPemKey() + pemDecoderLoop( + pem, + cast[proc(ctx: pointer, pbytes: pointer, nbytes: uint) {.bearSslFunc.}](pkeyDecoderPush), + addr pkCtx, + ) + if pkeyDecoderLastError(addr pkCtx) != 0: + invalidPemKey() proc calcHash(alg: ptr HashClass, data: string, output: var array[64, byte]) = var ctx: array[512, byte] @@ -61,14 +74,17 @@ proc calcHash(alg: ptr HashClass, data: string, output: var array[64, byte]) = alg.update(pCtx, unsafeAddr data[0], data.len.uint) alg.`out`(pCtx, addr output[0]) -proc bearSignRSPem*(data, key: string, alg: ptr HashClass, hashOid: cstring, hashLen: int): seq[byte] = +proc bearSignRSPem*( + data, key: string, alg: ptr HashClass, hashOid: cstring, hashLen: int +): seq[byte] = # Step 1. Extract RSA key from `key` in PEM format var skCtx: SkeyDecoderContext decodeFromPem(skCtx, key) if skeyDecoderKeyType(skCtx) != KEYTYPE_RSA: invalidPemKey() - template pk(): RsaPrivateKey = skCtx.key.rsa + template pk(): RsaPrivateKey = + skCtx.key.rsa # Step 2. Hash! var digest: array[64, byte] @@ -78,16 +94,24 @@ proc bearSignRSPem*(data, key: string, alg: ptr HashClass, hashOid: cstring, has result = newSeqUninitialized[byte](sigLen) let s = rsaPkcs1SignGetDefault() assert(not s.isNil) - if s(cast[ptr byte](hashOid), addr digest[0], hashLen.uint, addr pk, addr result[0]) != 1: + if s(cast[ptr byte](hashOid), addr digest[0], hashLen.uint, addr pk, addr result[0]) != + 1: raise newException(ValueError, "Could not sign") -proc bearVerifyRSPem*(data, key: string, sig: openarray[byte], alg: ptr HashClass, hashOid: cstring, hashLen: int): bool = +proc bearVerifyRSPem*( + data, key: string, + sig: openarray[byte], + alg: ptr HashClass, + hashOid: cstring, + hashLen: int, +): bool = # Step 1. Extract RSA key from `key` in PEM format var pkCtx: PkeyDecoderContext decodeFromPem(pkCtx, key) if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_RSA: invalidPemKey() - template pk(): RsaPublicKey = pkCtx.key.rsa + template pk(): RsaPublicKey = + pkCtx.key.rsa var digest: array[64, byte] calcHash(alg, data, digest) @@ -95,7 +119,14 @@ proc bearVerifyRSPem*(data, key: string, sig: openarray[byte], alg: ptr HashClas let s = rsaPkcs1VrfyGetDefault() var digest2: array[64, byte] - if s(unsafeAddr sig[0], sig.len.uint, cast[ptr byte](hashOid), hashLen.uint, addr pk, addr digest2[0]) != 1: + if s( + unsafeAddr sig[0], + sig.len.uint, + cast[ptr byte](hashOid), + hashLen.uint, + addr pk, + addr digest2[0], + ) != 1: return false digest == digest2 @@ -107,7 +138,8 @@ proc bearSignECPem*(data, key: string, alg: ptr HashClass): seq[byte] = if skeyDecoderKeyType(skCtx) != KEYTYPE_EC: invalidPemKey() - template pk(): EcPrivateKey = skCtx.key.ec + template pk(): EcPrivateKey = + skCtx.key.ec # Step 2. Hash! var digest: array[64, byte] @@ -123,13 +155,16 @@ proc bearSignECPem*(data, key: string, alg: ptr HashClass): seq[byte] = assert(sz <= maxSigLen) result.setLen(sz) -proc bearVerifyECPem*(data, key: string, sig: openarray[byte], alg: ptr HashClass, hashLen: int): bool = +proc bearVerifyECPem*( + data, key: string, sig: openarray[byte], alg: ptr HashClass, hashLen: int +): bool = # Step 1. Extract EC Pub key from `key` in PEM format var pkCtx: PkeyDecoderContext decodeFromPem(pkCtx, key) if pkeyDecoderKeyType(addr pkCtx) != KEYTYPE_EC: invalidPemKey() - template pk(): EcPublicKey = pkCtx.key.ec + template pk(): EcPublicKey = + pkCtx.key.ec # bearssl ecdsaVrfy requires pubkey to be prepended with 0x04 byte, do it here assert((pk.q == addr pkCtx.key_data) and pk.qlen < sizeof(pkCtx.key_data).uint) @@ -142,4 +177,5 @@ proc bearVerifyECPem*(data, key: string, sig: openarray[byte], alg: ptr HashClas let impl = ecGetDefault() let s = ecdsaVrfyRawGetDefault() - result = s(impl, addr digest[0], hashLen.uint, addr pk, unsafeAddr sig[0], sig.len.uint) == 1 + result = + s(impl, addr digest[0], hashLen.uint, addr pk, unsafeAddr sig[0], sig.len.uint) == 1 diff --git a/jwt/private/jose.nim b/jwt/jose.nim similarity index 99% rename from jwt/private/jose.nim rename to jwt/jose.nim index bb63405..fe6e0f9 100644 --- a/jwt/private/jose.nim +++ b/jwt/jose.nim @@ -1,5 +1,4 @@ import json, strutils - import utils type @@ -23,7 +22,6 @@ proc strToSignatureAlgorithm(s: string): SignatureAlgorithm = except ValueError: raise newException(UnsupportedAlgorithm, "$# isn't supported" % s) - proc toHeader*(j: JsonNode): JsonNode = # Check that the keys are present so we dont blow up. result = newJObject() @@ -43,6 +41,5 @@ proc `%`*(alg: SignatureAlgorithm): JsonNode = let s = $alg return %s - proc toBase64*(h: JsonNode): string = result = encodeUrlSafe($h) diff --git a/jwt/private/claims.nim b/jwt/private/claims.nim deleted file mode 100644 index 3c400e4..0000000 --- a/jwt/private/claims.nim +++ /dev/null @@ -1,141 +0,0 @@ -import json, strutils, times, tables - -import utils - - -type - ClaimKind* = enum - ISS, - SUB, - NBF, - EXP, - AUD, - IAT, - JTI, - GENERAL - - Claim* = ref ClaimObj - ClaimObj* {.acyclic.} = object - node*: JsonNode - kind*: ClaimKind - - - -proc newClaims*(claims: varargs[tuple[key: string, val: Claim]]): TableRef[string, Claim] = - result = newTable[string, Claim](claims) - - -proc newClaim*(k: ClaimKind, node: JsonNode): Claim = - new result - result.kind = k - result.node = node - -# ISS -proc newISS*(node: JsonNode): Claim = - checkJsonNodeKind(node, JString) - return newClaim(ISS, node) - -# SUB -proc newSUB*(node: JsonNode): Claim = - checkJsonNodeKind(node, JString) - return newClaim(SUB, node) - -# AUD -proc newAUD*(node: JsonNode): Claim = - if node.kind != JArray and node.kind != JString: - raise newException(ValueError, "Invalid kind") - return newClaim(AUD, node) - -proc newAUD*(recipients: seq[string]): Claim = - var node = newJArray() - for r in recipients: - node.add(%r) - result = newAUD(node) - -proc newAUD*(recipient: string): Claim = return newAUD(@[recipient]) - -proc newAUD*(recipients: varargs[string]): Claim = return newAUD(@recipients) - - -# Claims that have any kind of time -proc newTimeClaim*(k: ClaimKind, j: JsonNode): Claim = - # Check that the json kind is int.. - checkJsonNodeKind(j, JInt) - return newClaim(k, j) - -proc newTimeClaim*(k: ClaimKind, s: string): Claim = - return newTimeClaim(k, %parseInt(s)) - -proc newTimeClaim*(k: ClaimKind, i: int64): Claim = - return newTimeClaim(k, %i) - -# Returns the claimKeyms value as a time -proc getClaimTime*(c: Claim): Time = - result = fromUnix(c.node.num) - -# NBF -proc newNBF*(s: string): Claim = return newTimeClaim(NBF, s) - -proc newNBF*(j: JsonNode): Claim = return newTimeClaim(NBF, j) - -proc newNBF*(i: int64): Claim = return newTimeClaim(NBF, i) - -# EXP -proc newEXP*(s: string): Claim = return newTimeClaim(EXP, s) - -proc newEXP*(j: JsonNode): Claim = return newTimeClaim(EXP, j) - -proc newEXP*(i: int64): Claim = return newTimeClaim(EXP, i) - -# IAT -proc newIAT*(s: string): Claim = return newTimeClaim(IAT, s) - -proc newIAT*(j: JsonNode): Claim = return newTimeClaim(IAT, j) - -proc newIAT*(i: int64): Claim = return newTimeClaim(IAT, i) - -# JTI -proc newJTI*(j: JsonNode): Claim = - assert j.kind == JString - return newClaim(JTI, j) - -proc newJTI*(s: string): Claim = - return newJTI(%s) - - -proc toClaims*(j: JsonNode): TableRef[string, Claim] = - result = newClaims() - - for claimKey, claimNode in j: - case claimKey: - of "iss": - result[claimKey] = newISS(claimNode) - of "sub": - result[claimKey] = newSUB(claimNode) - of "aud": - result[claimKey] = newAUD(claimNode) - of "nbf": - result[claimKey] = newNBF(claimNode) - of "exp": - result[claimKey] = newEXP(claimNode) - of "iat": - result[claimKey] = newIAT(claimNode) - of "jti": - result[claimKey] = newJTI(claimNode) - else: - result[claimKey] = newClaim(GENERAL, claimNode) - - -proc `%`*(c: Claim): JsonNode = - result = c.node - - -proc `%`*(claims: TableRef[string, Claim]): JsonNode = - result = newJObject() - for k, v in claims: - result[k] = %v - - -proc toBase64*(claims: TableRef[string, Claim]): string = - let asJson = %claims - result = encodeUrlSafe($asJson) diff --git a/jwt/private/utils.nim b/jwt/utils.nim similarity index 94% rename from jwt/private/utils.nim rename to jwt/utils.nim index c76ef5d..9254b95 100644 --- a/jwt/private/utils.nim +++ b/jwt/utils.nim @@ -1,14 +1,11 @@ import json, strutils - -from base64 import nil - +import base64 proc checkJsonNodeKind*(node: JsonNode, kind: JsonNodeKind) = # Check that a given JsonNode has a given kind, raise ValueError if not if node.kind != kind: raise newException(ValueError, "Invalid kind") - proc checkKeysExists*(node: JsonNode, keys: varargs[string]) = for key in keys: if not node.hasKey(key): @@ -18,7 +15,7 @@ proc encodeUrlSafe*(s: openarray[byte]): string = when (NimMajor >= 1 and (NimMinor >= 1 or NimPatch >= 2)) or NimMajor >= 2: result = base64.encode(s) else: - result = base64.encode(s, newLine="") + result = base64.encode(s, newLine = "") while result.endsWith("="): result.setLen(result.len - 1) result = result.replace('+', '-').replace('/', '_') diff --git a/tests/t_claims.nim b/tests/t_claims.nim index 88eb629..503cc40 100644 --- a/tests/t_claims.nim +++ b/tests/t_claims.nim @@ -1,18 +1,18 @@ import json, unittest - -import ../jwt +import jwt suite "Claim ops": test "Create claims from JSON": - let asJson = %{ - "iss": %"jane", - "sub": %"john", - "nbf": %1234, - "iat": %1234, - "exp": %1234, - "jti": %"token-id", - "foo": %{"bar": %1} - } + let asJson = + %{ + "iss": %"jane", + "sub": %"john", + "nbf": %1234, + "iat": %1234, + "exp": %1234, + "jti": %"token-id", + "foo": %{"bar": %1}, + } let claims = asJson.toClaims let toJson = %claims diff --git a/tests/t_jwt.nim b/tests/t_jwt.nim index a332e08..f8dfcfc 100644 --- a/tests/t_jwt.nim +++ b/tests/t_jwt.nim @@ -1,6 +1,5 @@ import json, times, unittest - -import ../jwt +import jwt proc getToken(claims: JsonNode = newJObject(), header: JsonNode = newJObject()): JWT = for k, v in %*{"alg": "HS512", "typ": "JWT"}: