tests: Fix up tests for r2.lib.tracking.

This commit is contained in:
Neil Williams
2013-01-02 11:14:18 -08:00
parent 49f41715da
commit f1f7f36b32
2 changed files with 48 additions and 19 deletions

View File

@@ -66,9 +66,9 @@ def _unpad_message(text):
return unpadded
def _make_cipher(initialization_vector):
def _make_cipher(initialization_vector, secret):
"""Return a block cipher object for use in `encrypt` and `decrypt`."""
return AES.new(g.tracking_secret[:KEY_SIZE], AES.MODE_CBC,
return AES.new(secret[:KEY_SIZE], AES.MODE_CBC,
initialization_vector[:AES.block_size])
@@ -83,15 +83,22 @@ def encrypt(plaintext):
"""
salt = _make_salt()
return _encrypt(salt, plaintext, g.tracking_secret)
def _make_salt():
# we want SALT_SIZE letters of salt text, but we're generating random bytes
# so we'll calculate how many bytes we need to get SALT_SIZE characters of
# base64 output. because of padding, this only works for SALT_SIZE % 4 == 0
assert SALT_SIZE % 4 == 0
salt_byte_count = (SALT_SIZE / 4) * 3
salt_bytes = get_random_bytes(salt_byte_count)
salt = base64.b64encode(salt_bytes)
cipher = _make_cipher(salt)
return base64.b64encode(salt_bytes)
def _encrypt(salt, plaintext, secret):
cipher = _make_cipher(salt, secret)
padded = _pad_message(plaintext)
ciphertext = cipher.encrypt(padded)
@@ -107,10 +114,14 @@ def decrypt(encrypted):
"""
return _decrypt(encrypted, g.tracking_secret)
def _decrypt(encrypted, secret):
encrypted = urllib.unquote_plus(encrypted)
salt, encoded = encrypted[:SALT_SIZE], encrypted[SALT_SIZE:]
ciphertext = base64.b64decode(encoded)
cipher = _make_cipher(salt)
cipher = _make_cipher(salt, secret)
padded = cipher.decrypt(ciphertext)
return _unpad_message(padded)

View File

@@ -22,41 +22,59 @@
import unittest
from r2.lib import tracking
from r2.tests import RedditTestCase
KEY_SIZE = tracking.KEY_SIZE
MESSAGE = "the quick brown fox jumped over..."
BLOCK_O_PADDING = ("\x10\x10\x10\x10\x10\x10\x10\x10"
"\x10\x10\x10\x10\x10\x10\x10\x10")
SECRET = "abcdefghijklmnopqrstuvwxyz"
ENCRYPTED = ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaIbzth1QTzJxzHbHGnJywG5V1uR3tWtSB"
"8hTyIcfg6rUZC4Wo0pT8jkEt9o1c%2FkTn")
class TestPadding(unittest.TestCase):
def test_pad_empty_string(self):
padded = tracking._pad_message("")
from r2.lib.tracking import _pad_message
padded = _pad_message("")
self.assertEquals(padded, BLOCK_O_PADDING)
def test_pad_round_string(self):
padded = tracking._pad_message("x" * KEY_SIZE)
from r2.lib.tracking import _pad_message, KEY_SIZE
padded = _pad_message("x" * KEY_SIZE)
self.assertEquals(len(padded), KEY_SIZE * 2)
self.assertEquals(padded[KEY_SIZE:], BLOCK_O_PADDING)
def test_unpad_empty_message(self):
unpadded = tracking._unpad_message("")
from r2.lib.tracking import _unpad_message
unpadded = _unpad_message("")
self.assertEquals(unpadded, "")
def test_unpad_evil_message(self):
from r2.lib.tracking import _unpad_message
evil = ("a" * 88) + chr(57)
result = tracking._unpad_message(evil)
result = _unpad_message(evil)
self.assertEquals(result, "")
def test_padding_roundtrip(self):
tested = tracking._unpad_message(tracking._pad_message(MESSAGE))
from r2.lib.tracking import _unpad_message, _pad_message
tested = _unpad_message(_pad_message(MESSAGE))
self.assertEquals(MESSAGE, tested)
class TestEncryption(RedditTestCase):
def test_encryption_roundtrip(self):
tested = tracking.decrypt(tracking.encrypt(MESSAGE))
self.assertEquals(MESSAGE, tested)
class TestEncryption(unittest.TestCase):
def test_salt(self):
from r2.lib.tracking import _make_salt, SALT_SIZE
self.assertEquals(len(_make_salt()), SALT_SIZE)
def test_encrypt(self):
from r2.lib.tracking import _encrypt, SALT_SIZE
encrypted = _encrypt(
"a" * SALT_SIZE,
MESSAGE,
SECRET,
)
self.assertEquals(encrypted, ENCRYPTED)
def test_decrypt(self):
from r2.lib.tracking import _decrypt
decrypted = _decrypt(ENCRYPTED, SECRET)
self.assertEquals(MESSAGE, decrypted)