mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-01-28 16:28:01 -05:00
tests: Fix up tests for r2.lib.tracking.
This commit is contained in:
@@ -66,9 +66,9 @@ def _unpad_message(text):
|
||||
return unpadded
|
||||
|
||||
|
||||
def _make_cipher(initialization_vector):
|
||||
def _make_cipher(initialization_vector, secret):
|
||||
"""Return a block cipher object for use in `encrypt` and `decrypt`."""
|
||||
return AES.new(g.tracking_secret[:KEY_SIZE], AES.MODE_CBC,
|
||||
return AES.new(secret[:KEY_SIZE], AES.MODE_CBC,
|
||||
initialization_vector[:AES.block_size])
|
||||
|
||||
|
||||
@@ -83,15 +83,22 @@ def encrypt(plaintext):
|
||||
|
||||
"""
|
||||
|
||||
salt = _make_salt()
|
||||
return _encrypt(salt, plaintext, g.tracking_secret)
|
||||
|
||||
|
||||
def _make_salt():
|
||||
# we want SALT_SIZE letters of salt text, but we're generating random bytes
|
||||
# so we'll calculate how many bytes we need to get SALT_SIZE characters of
|
||||
# base64 output. because of padding, this only works for SALT_SIZE % 4 == 0
|
||||
assert SALT_SIZE % 4 == 0
|
||||
salt_byte_count = (SALT_SIZE / 4) * 3
|
||||
|
||||
salt_bytes = get_random_bytes(salt_byte_count)
|
||||
salt = base64.b64encode(salt_bytes)
|
||||
cipher = _make_cipher(salt)
|
||||
return base64.b64encode(salt_bytes)
|
||||
|
||||
|
||||
def _encrypt(salt, plaintext, secret):
|
||||
cipher = _make_cipher(salt, secret)
|
||||
|
||||
padded = _pad_message(plaintext)
|
||||
ciphertext = cipher.encrypt(padded)
|
||||
@@ -107,10 +114,14 @@ def decrypt(encrypted):
|
||||
|
||||
"""
|
||||
|
||||
return _decrypt(encrypted, g.tracking_secret)
|
||||
|
||||
|
||||
def _decrypt(encrypted, secret):
|
||||
encrypted = urllib.unquote_plus(encrypted)
|
||||
salt, encoded = encrypted[:SALT_SIZE], encrypted[SALT_SIZE:]
|
||||
ciphertext = base64.b64decode(encoded)
|
||||
cipher = _make_cipher(salt)
|
||||
cipher = _make_cipher(salt, secret)
|
||||
padded = cipher.decrypt(ciphertext)
|
||||
return _unpad_message(padded)
|
||||
|
||||
|
||||
@@ -22,41 +22,59 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from r2.lib import tracking
|
||||
from r2.tests import RedditTestCase
|
||||
|
||||
|
||||
KEY_SIZE = tracking.KEY_SIZE
|
||||
MESSAGE = "the quick brown fox jumped over..."
|
||||
BLOCK_O_PADDING = ("\x10\x10\x10\x10\x10\x10\x10\x10"
|
||||
"\x10\x10\x10\x10\x10\x10\x10\x10")
|
||||
SECRET = "abcdefghijklmnopqrstuvwxyz"
|
||||
ENCRYPTED = ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaIbzth1QTzJxzHbHGnJywG5V1uR3tWtSB"
|
||||
"8hTyIcfg6rUZC4Wo0pT8jkEt9o1c%2FkTn")
|
||||
|
||||
|
||||
class TestPadding(unittest.TestCase):
|
||||
def test_pad_empty_string(self):
|
||||
padded = tracking._pad_message("")
|
||||
from r2.lib.tracking import _pad_message
|
||||
padded = _pad_message("")
|
||||
self.assertEquals(padded, BLOCK_O_PADDING)
|
||||
|
||||
def test_pad_round_string(self):
|
||||
padded = tracking._pad_message("x" * KEY_SIZE)
|
||||
from r2.lib.tracking import _pad_message, KEY_SIZE
|
||||
padded = _pad_message("x" * KEY_SIZE)
|
||||
self.assertEquals(len(padded), KEY_SIZE * 2)
|
||||
self.assertEquals(padded[KEY_SIZE:], BLOCK_O_PADDING)
|
||||
|
||||
def test_unpad_empty_message(self):
|
||||
unpadded = tracking._unpad_message("")
|
||||
from r2.lib.tracking import _unpad_message
|
||||
unpadded = _unpad_message("")
|
||||
self.assertEquals(unpadded, "")
|
||||
|
||||
def test_unpad_evil_message(self):
|
||||
from r2.lib.tracking import _unpad_message
|
||||
evil = ("a" * 88) + chr(57)
|
||||
result = tracking._unpad_message(evil)
|
||||
result = _unpad_message(evil)
|
||||
self.assertEquals(result, "")
|
||||
|
||||
def test_padding_roundtrip(self):
|
||||
tested = tracking._unpad_message(tracking._pad_message(MESSAGE))
|
||||
from r2.lib.tracking import _unpad_message, _pad_message
|
||||
tested = _unpad_message(_pad_message(MESSAGE))
|
||||
self.assertEquals(MESSAGE, tested)
|
||||
|
||||
|
||||
class TestEncryption(RedditTestCase):
|
||||
def test_encryption_roundtrip(self):
|
||||
tested = tracking.decrypt(tracking.encrypt(MESSAGE))
|
||||
self.assertEquals(MESSAGE, tested)
|
||||
class TestEncryption(unittest.TestCase):
|
||||
def test_salt(self):
|
||||
from r2.lib.tracking import _make_salt, SALT_SIZE
|
||||
self.assertEquals(len(_make_salt()), SALT_SIZE)
|
||||
|
||||
def test_encrypt(self):
|
||||
from r2.lib.tracking import _encrypt, SALT_SIZE
|
||||
encrypted = _encrypt(
|
||||
"a" * SALT_SIZE,
|
||||
MESSAGE,
|
||||
SECRET,
|
||||
)
|
||||
self.assertEquals(encrypted, ENCRYPTED)
|
||||
|
||||
def test_decrypt(self):
|
||||
from r2.lib.tracking import _decrypt
|
||||
decrypted = _decrypt(ENCRYPTED, SECRET)
|
||||
self.assertEquals(MESSAGE, decrypted)
|
||||
|
||||
Reference in New Issue
Block a user