diff --git a/r2/r2/lib/tracking.py b/r2/r2/lib/tracking.py index 0e1ee5b35..38c72d259 100644 --- a/r2/r2/lib/tracking.py +++ b/r2/r2/lib/tracking.py @@ -66,9 +66,9 @@ def _unpad_message(text): return unpadded -def _make_cipher(initialization_vector): +def _make_cipher(initialization_vector, secret): """Return a block cipher object for use in `encrypt` and `decrypt`.""" - return AES.new(g.tracking_secret[:KEY_SIZE], AES.MODE_CBC, + return AES.new(secret[:KEY_SIZE], AES.MODE_CBC, initialization_vector[:AES.block_size]) @@ -83,15 +83,22 @@ def encrypt(plaintext): """ + salt = _make_salt() + return _encrypt(salt, plaintext, g.tracking_secret) + + +def _make_salt(): # we want SALT_SIZE letters of salt text, but we're generating random bytes # so we'll calculate how many bytes we need to get SALT_SIZE characters of # base64 output. because of padding, this only works for SALT_SIZE % 4 == 0 assert SALT_SIZE % 4 == 0 salt_byte_count = (SALT_SIZE / 4) * 3 - salt_bytes = get_random_bytes(salt_byte_count) - salt = base64.b64encode(salt_bytes) - cipher = _make_cipher(salt) + return base64.b64encode(salt_bytes) + + +def _encrypt(salt, plaintext, secret): + cipher = _make_cipher(salt, secret) padded = _pad_message(plaintext) ciphertext = cipher.encrypt(padded) @@ -107,10 +114,14 @@ def decrypt(encrypted): """ + return _decrypt(encrypted, g.tracking_secret) + + +def _decrypt(encrypted, secret): encrypted = urllib.unquote_plus(encrypted) salt, encoded = encrypted[:SALT_SIZE], encrypted[SALT_SIZE:] ciphertext = base64.b64decode(encoded) - cipher = _make_cipher(salt) + cipher = _make_cipher(salt, secret) padded = cipher.decrypt(ciphertext) return _unpad_message(padded) diff --git a/r2/r2/tests/unit/lib/tracking_test.py b/r2/r2/tests/unit/lib/tracking_test.py index a525a3017..98866e85b 100644 --- a/r2/r2/tests/unit/lib/tracking_test.py +++ b/r2/r2/tests/unit/lib/tracking_test.py @@ -22,41 +22,59 @@ import unittest -from r2.lib import tracking -from r2.tests import RedditTestCase - -KEY_SIZE = tracking.KEY_SIZE MESSAGE = "the quick brown fox jumped over..." BLOCK_O_PADDING = ("\x10\x10\x10\x10\x10\x10\x10\x10" "\x10\x10\x10\x10\x10\x10\x10\x10") +SECRET = "abcdefghijklmnopqrstuvwxyz" +ENCRYPTED = ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaIbzth1QTzJxzHbHGnJywG5V1uR3tWtSB" + "8hTyIcfg6rUZC4Wo0pT8jkEt9o1c%2FkTn") class TestPadding(unittest.TestCase): def test_pad_empty_string(self): - padded = tracking._pad_message("") + from r2.lib.tracking import _pad_message + padded = _pad_message("") self.assertEquals(padded, BLOCK_O_PADDING) def test_pad_round_string(self): - padded = tracking._pad_message("x" * KEY_SIZE) + from r2.lib.tracking import _pad_message, KEY_SIZE + padded = _pad_message("x" * KEY_SIZE) self.assertEquals(len(padded), KEY_SIZE * 2) self.assertEquals(padded[KEY_SIZE:], BLOCK_O_PADDING) def test_unpad_empty_message(self): - unpadded = tracking._unpad_message("") + from r2.lib.tracking import _unpad_message + unpadded = _unpad_message("") self.assertEquals(unpadded, "") def test_unpad_evil_message(self): + from r2.lib.tracking import _unpad_message evil = ("a" * 88) + chr(57) - result = tracking._unpad_message(evil) + result = _unpad_message(evil) self.assertEquals(result, "") def test_padding_roundtrip(self): - tested = tracking._unpad_message(tracking._pad_message(MESSAGE)) + from r2.lib.tracking import _unpad_message, _pad_message + tested = _unpad_message(_pad_message(MESSAGE)) self.assertEquals(MESSAGE, tested) -class TestEncryption(RedditTestCase): - def test_encryption_roundtrip(self): - tested = tracking.decrypt(tracking.encrypt(MESSAGE)) - self.assertEquals(MESSAGE, tested) +class TestEncryption(unittest.TestCase): + def test_salt(self): + from r2.lib.tracking import _make_salt, SALT_SIZE + self.assertEquals(len(_make_salt()), SALT_SIZE) + + def test_encrypt(self): + from r2.lib.tracking import _encrypt, SALT_SIZE + encrypted = _encrypt( + "a" * SALT_SIZE, + MESSAGE, + SECRET, + ) + self.assertEquals(encrypted, ENCRYPTED) + + def test_decrypt(self): + from r2.lib.tracking import _decrypt + decrypted = _decrypt(ENCRYPTED, SECRET) + self.assertEquals(MESSAGE, decrypted)