diff --git a/lib/crewai/src/crewai/cli/shared/token_manager.py b/lib/crewai/src/crewai/cli/shared/token_manager.py index 4546efd55..02c176924 100644 --- a/lib/crewai/src/crewai/cli/shared/token_manager.py +++ b/lib/crewai/src/crewai/cli/shared/token_manager.py @@ -3,103 +3,56 @@ import json import os from pathlib import Path import sys -from typing import BinaryIO, cast +import tempfile +from typing import Final, Literal, cast from cryptography.fernet import Fernet -if sys.platform == "win32": - import msvcrt -else: - import fcntl +_FERNET_KEY_LENGTH: Final[Literal[44]] = 44 class TokenManager: - def __init__(self, file_path: str = "tokens.enc") -> None: - """ - Initialize the TokenManager class. + """Manages encrypted token storage.""" - :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". + def __init__(self, file_path: str = "tokens.enc") -> None: + """Initialize the TokenManager. + + Args: + file_path: The file path to store encrypted tokens. """ self.file_path = file_path self.key = self._get_or_create_key() self.fernet = Fernet(self.key) - @staticmethod - def _acquire_lock(file_handle: BinaryIO) -> None: - """ - Acquire an exclusive lock on a file handle. - - Args: - file_handle: Open file handle to lock. - """ - if sys.platform == "win32": - msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1) - else: - fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX) - - @staticmethod - def _release_lock(file_handle: BinaryIO) -> None: - """ - Release the lock on a file handle. - - Args: - file_handle: Open file handle to unlock. - """ - if sys.platform == "win32": - msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1) - else: - fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN) - def _get_or_create_key(self) -> bytes: - """ - Get or create the encryption key with file locking to prevent race conditions. + """Get or create the encryption key. Returns: - The encryption key. + The encryption key as bytes. """ - key_filename = "secret.key" - storage_path = self.get_secure_storage_path() + key_filename: str = "secret.key" - key = self.read_secure_file(key_filename) - if key is not None and len(key) == 44: + key = self._read_secure_file(key_filename) + if key is not None and len(key) == _FERNET_KEY_LENGTH: return key - lock_file_path = storage_path / f"{key_filename}.lock" - - try: - lock_file_path.touch() - - with open(lock_file_path, "r+b") as lock_file: - self._acquire_lock(lock_file) - try: - key = self.read_secure_file(key_filename) - if key is not None and len(key) == 44: - return key - - new_key = Fernet.generate_key() - self.save_secure_file(key_filename, new_key) - return new_key - finally: - try: - self._release_lock(lock_file) - except OSError: - pass - except OSError: - key = self.read_secure_file(key_filename) - if key is not None and len(key) == 44: - return key - - new_key = Fernet.generate_key() - self.save_secure_file(key_filename, new_key) + new_key = Fernet.generate_key() + if self._atomic_create_secure_file(key_filename, new_key): return new_key - def save_tokens(self, access_token: str, expires_at: int) -> None: - """ - Save the access token and its expiration time. + key = self._read_secure_file(key_filename) + if key is not None and len(key) == _FERNET_KEY_LENGTH: + return key - :param access_token: The access token to save. - :param expires_at: The UNIX timestamp of the expiration time. + raise RuntimeError("Failed to create or read encryption key") + + def save_tokens(self, access_token: str, expires_at: int) -> None: + """Save the access token and its expiration time. + + Args: + access_token: The access token to save. + expires_at: The UNIX timestamp of the expiration time. """ expiration_time = datetime.fromtimestamp(expires_at) data = { @@ -107,15 +60,15 @@ class TokenManager: "expiration": expiration_time.isoformat(), } encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) - self.save_secure_file(self.file_path, encrypted_data) + self._atomic_write_secure_file(self.file_path, encrypted_data) def get_token(self) -> str | None: - """ - Get the access token if it is valid and not expired. + """Get the access token if it is valid and not expired. - :return: The access token if valid and not expired, otherwise None. + Returns: + The access token if valid and not expired, otherwise None. """ - encrypted_data = self.read_secure_file(self.file_path) + encrypted_data = self._read_secure_file(self.file_path) if encrypted_data is None: return None @@ -126,20 +79,18 @@ class TokenManager: if expiration <= datetime.now(): return None - return cast(str | None, data["access_token"]) + return cast(str | None, data.get("access_token")) def clear_tokens(self) -> None: - """ - Clear the tokens. - """ - self.delete_secure_file(self.file_path) + """Clear the stored tokens.""" + self._delete_secure_file(self.file_path) @staticmethod - def get_secure_storage_path() -> Path: - """ - Get the secure storage path based on the operating system. + def _get_secure_storage_path() -> Path: + """Get the secure storage path based on the operating system. - :return: The secure storage path. + Returns: + The secure storage path. """ if sys.platform == "win32": base_path = os.environ.get("LOCALAPPDATA") @@ -155,44 +106,81 @@ class TokenManager: return storage_path - def save_secure_file(self, filename: str, content: bytes) -> None: - """ - Save the content to a secure file. + def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool: + """Create a file only if it doesn't exist. - :param filename: The name of the file. - :param content: The content to save. + Args: + filename: The name of the file. + content: The content to write. + + Returns: + True if file was created, False if it already exists. """ - storage_path = self.get_secure_storage_path() + storage_path = self._get_secure_storage_path() file_path = storage_path / filename - with open(file_path, "wb") as f: - f.write(content) + try: + fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600) + try: + os.write(fd, content) + finally: + os.close(fd) + return True + except FileExistsError: + return False - os.chmod(file_path, 0o600) + def _atomic_write_secure_file(self, filename: str, content: bytes) -> None: + """Write content to a secure file. - def read_secure_file(self, filename: str) -> bytes | None: + Args: + filename: The name of the file. + content: The content to write. """ - Read the content of a secure file. - - :param filename: The name of the file. - :return: The content of the file if it exists, otherwise None. - """ - storage_path = self.get_secure_storage_path() + storage_path = self._get_secure_storage_path() file_path = storage_path / filename - if not file_path.exists(): + fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.") + fd_closed = False + try: + os.write(fd, content) + os.close(fd) + fd_closed = True + os.chmod(temp_path, 0o600) + os.replace(temp_path, file_path) + except Exception: + if not fd_closed: + os.close(fd) + if os.path.exists(temp_path): + os.unlink(temp_path) + raise + + def _read_secure_file(self, filename: str) -> bytes | None: + """Read the content of a secure file. + + Args: + filename: The name of the file. + + Returns: + The content of the file if it exists, otherwise None. + """ + storage_path = self._get_secure_storage_path() + file_path = storage_path / filename + + try: + with open(file_path, "rb") as f: + return f.read() + except FileNotFoundError: return None - with open(file_path, "rb") as f: - return f.read() + def _delete_secure_file(self, filename: str) -> None: + """Delete a secure file. - def delete_secure_file(self, filename: str) -> None: + Args: + filename: The name of the file. """ - Delete the secure file. - - :param filename: The name of the file. - """ - storage_path = self.get_secure_storage_path() + storage_path = self._get_secure_storage_path() file_path = storage_path / filename - if file_path.exists(): - file_path.unlink(missing_ok=True) + try: + file_path.unlink() + except FileNotFoundError: + pass diff --git a/lib/crewai/tests/cli/test_token_manager.py b/lib/crewai/tests/cli/test_token_manager.py index 6ca859278..5d7fc5790 100644 --- a/lib/crewai/tests/cli/test_token_manager.py +++ b/lib/crewai/tests/cli/test_token_manager.py @@ -1,7 +1,12 @@ +"""Tests for TokenManager with atomic file operations.""" + import json +import os +import tempfile import unittest from datetime import datetime, timedelta -from unittest.mock import MagicMock, patch +from pathlib import Path +from unittest.mock import patch from cryptography.fernet import Fernet @@ -9,15 +14,22 @@ from crewai.cli.shared.token_manager import TokenManager class TestTokenManager(unittest.TestCase): + """Test cases for TokenManager.""" + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") - def setUp(self, mock_get_key): + def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None: + """Set up test fixtures.""" mock_get_key.return_value = Fernet.generate_key() self.token_manager = TokenManager() - @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") - @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") + @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") - def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read): + def test_get_or_create_key_existing( + self, + mock_get_or_create: unittest.mock.MagicMock, + mock_read: unittest.mock.MagicMock, + ) -> None: + """Test that existing key is returned when present.""" mock_key = Fernet.generate_key() mock_get_or_create.return_value = mock_key @@ -26,40 +38,49 @@ class TestTokenManager(unittest.TestCase): self.assertEqual(result, mock_key) - @patch("crewai.cli.shared.token_manager.Fernet.generate_key") - @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") - @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") - @patch("crewai.cli.shared.token_manager.TokenManager._acquire_lock") - @patch("crewai.cli.shared.token_manager.TokenManager._release_lock") - @patch("builtins.open", new_callable=unittest.mock.mock_open) - def test_get_or_create_key_new( - self, mock_open, mock_release_lock, mock_acquire_lock, mock_save, mock_read, mock_generate - ): - mock_key = b"new_key" - mock_read.return_value = None - mock_generate.return_value = mock_key + def test_get_or_create_key_new(self) -> None: + """Test that new key is created when none exists.""" + mock_key = Fernet.generate_key() - result = self.token_manager._get_or_create_key() + with ( + patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read, + patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create, + patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate, + ): + result = self.token_manager._get_or_create_key() - self.assertEqual(result, mock_key) - # read_secure_file is called twice: once for fast path, once inside lock - self.assertEqual(mock_read.call_count, 2) - mock_read.assert_called_with("secret.key") - mock_generate.assert_called_once() - mock_save.assert_called_once_with("secret.key", mock_key) - # Verify lock was acquired and released - mock_acquire_lock.assert_called_once() - mock_release_lock.assert_called_once() + self.assertEqual(result, mock_key) + mock_read.assert_called_with("secret.key") + mock_generate.assert_called_once() + mock_atomic_create.assert_called_once_with("secret.key", mock_key) - @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") - def test_save_tokens(self, mock_save): + def test_get_or_create_key_race_condition(self) -> None: + """Test that another process's key is used when atomic create fails.""" + our_key = Fernet.generate_key() + their_key = Fernet.generate_key() + + with ( + patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read, + patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create, + patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=our_key), + ): + result = self.token_manager._get_or_create_key() + + self.assertEqual(result, their_key) + self.assertEqual(mock_read.call_count, 2) + + @patch("crewai.cli.shared.token_manager.TokenManager._atomic_write_secure_file") + def test_save_tokens( + self, mock_write: unittest.mock.MagicMock + ) -> None: + """Test saving tokens encrypts and writes atomically.""" access_token = "test_token" expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp()) self.token_manager.save_tokens(access_token, expires_at) - mock_save.assert_called_once() - args = mock_save.call_args[0] + mock_write.assert_called_once() + args = mock_write.call_args[0] self.assertEqual(args[0], "tokens.enc") decrypted_data = self.token_manager.fernet.decrypt(args[1]) data = json.loads(decrypted_data) @@ -67,8 +88,11 @@ class TestTokenManager(unittest.TestCase): expiration = datetime.fromisoformat(data["expiration"]) self.assertEqual(expiration, datetime.fromtimestamp(expires_at)) - @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") - def test_get_token_valid(self, mock_read): + @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_valid( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test getting a valid non-expired token.""" access_token = "test_token" expiration = (datetime.now() + timedelta(hours=1)).isoformat() data = {"access_token": access_token, "expiration": expiration} @@ -79,8 +103,11 @@ class TestTokenManager(unittest.TestCase): self.assertEqual(result, access_token) - @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") - def test_get_token_expired(self, mock_read): + @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_expired( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test that expired token returns None.""" access_token = "test_token" expiration = (datetime.now() - timedelta(hours=1)).isoformat() data = {"access_token": access_token, "expiration": expiration} @@ -91,76 +118,177 @@ class TestTokenManager(unittest.TestCase): self.assertIsNone(result) - @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") - @patch("builtins.open", new_callable=unittest.mock.mock_open) - @patch("crewai.cli.shared.token_manager.os.chmod") - def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - filename = "test_file.txt" - content = b"test_content" + @patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_not_found( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test that missing token file returns None.""" + mock_read.return_value = None - self.token_manager.save_secure_file(filename, content) - - mock_path.__truediv__.assert_called_once_with(filename) - mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb") - mock_open().write.assert_called_once_with(content) - mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600) - - @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") - @patch( - "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content" - ) - def test_read_secure_file_exists(self, mock_open, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - mock_path.__truediv__.return_value.exists.return_value = True - filename = "test_file.txt" - - result = self.token_manager.read_secure_file(filename) - - self.assertEqual(result, b"test_content") - mock_path.__truediv__.assert_called_once_with(filename) - mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb") - - @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") - def test_read_secure_file_not_exists(self, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - mock_path.__truediv__.return_value.exists.return_value = False - filename = "test_file.txt" - - result = self.token_manager.read_secure_file(filename) + result = self.token_manager.get_token() self.assertIsNone(result) - mock_path.__truediv__.assert_called_once_with(filename) - - @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") - def test_clear_tokens(self, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path + @patch("crewai.cli.shared.token_manager.TokenManager._delete_secure_file") + def test_clear_tokens( + self, mock_delete: unittest.mock.MagicMock + ) -> None: + """Test clearing tokens deletes the token file.""" self.token_manager.clear_tokens() - mock_path.__truediv__.assert_called_once_with("tokens.enc") - mock_path.__truediv__.return_value.unlink.assert_called_once_with( - missing_ok=True - ) + mock_delete.assert_called_once_with("tokens.enc") - @patch("crewai.cli.shared.token_manager.Fernet.generate_key") - @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") - @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") - @patch("builtins.open", side_effect=OSError(9, "Bad file descriptor")) - def test_get_or_create_key_oserror_fallback( - self, mock_open, mock_save, mock_read, mock_generate - ): - """Test that OSError during file locking falls back to lock-free creation.""" - mock_key = Fernet.generate_key() - mock_read.return_value = None - mock_generate.return_value = mock_key - result = self.token_manager._get_or_create_key() +class TestAtomicFileOperations(unittest.TestCase): + """Test atomic file operations directly.""" - self.assertEqual(result, mock_key) - self.assertGreaterEqual(mock_generate.call_count, 1) - self.assertGreaterEqual(mock_save.call_count, 1) + def setUp(self) -> None: + """Set up test fixtures with temp directory.""" + self.temp_dir = tempfile.mkdtemp() + self.original_get_path = TokenManager._get_secure_storage_path + + # Patch to use temp directory + def mock_get_path() -> Path: + return Path(self.temp_dir) + + TokenManager._get_secure_storage_path = staticmethod(mock_get_path) + + def tearDown(self) -> None: + """Clean up temp directory.""" + TokenManager._get_secure_storage_path = staticmethod(self.original_get_path) + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_create_new_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic create succeeds for new file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + result = tm._atomic_create_secure_file("test.txt", b"content") + + self.assertTrue(result) + file_path = Path(self.temp_dir) / "test.txt" + self.assertTrue(file_path.exists()) + self.assertEqual(file_path.read_bytes(), b"content") + self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_create_existing_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic create fails for existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + # Create file first + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"original") + + result = tm._atomic_create_secure_file("test.txt", b"new content") + + self.assertFalse(result) + self.assertEqual(file_path.read_bytes(), b"original") + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_new_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic write creates new file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + tm._atomic_write_secure_file("test.txt", b"content") + + file_path = Path(self.temp_dir) / "test.txt" + self.assertTrue(file_path.exists()) + self.assertEqual(file_path.read_bytes(), b"content") + self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_overwrites( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic write overwrites existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"original") + + tm._atomic_write_secure_file("test.txt", b"new content") + + self.assertEqual(file_path.read_bytes(), b"new content") + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_no_temp_file_on_success( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test that temp file is cleaned up after successful write.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + tm._atomic_write_secure_file("test.txt", b"content") + + # Check no temp files remain + temp_files = list(Path(self.temp_dir).glob(".test.txt.*")) + self.assertEqual(len(temp_files), 0) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_read_secure_file_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test reading existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"content") + + result = tm._read_secure_file("test.txt") + + self.assertEqual(result, b"content") + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_read_secure_file_not_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test reading non-existent file returns None.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + result = tm._read_secure_file("nonexistent.txt") + + self.assertIsNone(result) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_delete_secure_file_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test deleting existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"content") + + tm._delete_secure_file("test.txt") + + self.assertFalse(file_path.exists()) + + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_delete_secure_file_not_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test deleting non-existent file doesn't raise.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + # Should not raise + tm._delete_secure_file("nonexistent.txt") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/lib/crewai/tests/cli/tools/test_main.py b/lib/crewai/tests/cli/tools/test_main.py index fa1c5fa44..71acea76d 100644 --- a/lib/crewai/tests/cli/tools/test_main.py +++ b/lib/crewai/tests/cli/tools/test_main.py @@ -31,7 +31,7 @@ def tool_command(): with tempfile.TemporaryDirectory() as temp_dir: # Mock the secure storage path to use the temp directory with patch.object( - TokenManager, "get_secure_storage_path", return_value=Path(temp_dir) + TokenManager, "_get_secure_storage_path", return_value=Path(temp_dir) ): TokenManager().save_tokens( "test-token", (datetime.now() + timedelta(seconds=36000)).timestamp()