From c1d2801be212caaa397dd9b26a52f11e8a8398cb Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 3 Feb 2026 09:16:55 -0500 Subject: [PATCH] fix: reject reserved script names for crew folders --- lib/crewai/src/crewai/cli/constants.py | 9 ++- lib/crewai/src/crewai/cli/create_crew.py | 50 +++++++++++- lib/crewai/src/crewai/cli/provider.py | 97 ++++++++++++------------ lib/crewai/tests/cli/test_create_crew.py | 17 +++++ 4 files changed, 118 insertions(+), 55 deletions(-) diff --git a/lib/crewai/src/crewai/cli/constants.py b/lib/crewai/src/crewai/cli/constants.py index a3755b1a6..4de0d0082 100644 --- a/lib/crewai/src/crewai/cli/constants.py +++ b/lib/crewai/src/crewai/cli/constants.py @@ -1,10 +1,13 @@ +from typing import Any + + DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com" -ENV_VARS = { +ENV_VARS: dict[str, list[dict[str, Any]]] = { "openai": [ { "prompt": "Enter your OPENAI API key (press Enter to skip)", @@ -112,7 +115,7 @@ ENV_VARS = { } -PROVIDERS = [ +PROVIDERS: list[str] = [ "openai", "anthropic", "gemini", @@ -127,7 +130,7 @@ PROVIDERS = [ "sambanova", ] -MODELS = { +MODELS: dict[str, list[str]] = { "openai": [ "gpt-4", "gpt-4.1", diff --git a/lib/crewai/src/crewai/cli/create_crew.py b/lib/crewai/src/crewai/cli/create_crew.py index e4d84e8bc..51e2f00ac 100644 --- a/lib/crewai/src/crewai/cli/create_crew.py +++ b/lib/crewai/src/crewai/cli/create_crew.py @@ -3,6 +3,7 @@ import shutil import sys import click +import tomli from crewai.cli.constants import ENV_VARS, MODELS from crewai.cli.provider import ( @@ -13,7 +14,31 @@ from crewai.cli.provider import ( from crewai.cli.utils import copy_template, load_env_vars, write_env_file -def create_folder_structure(name, parent_folder=None): +def get_reserved_script_names() -> set[str]: + """Get reserved script names from pyproject.toml template. + + Returns: + Set of reserved script names that would conflict with crew folder names. + """ + package_dir = Path(__file__).parent + template_path = package_dir / "templates" / "crew" / "pyproject.toml" + + with open(template_path, "r") as f: + template_content = f.read() + + template_content = template_content.replace("{{folder_name}}", "_placeholder_") + template_content = template_content.replace("{{name}}", "placeholder") + template_content = template_content.replace("{{crew_name}}", "Placeholder") + + template_data = tomli.loads(template_content) + script_names = set(template_data.get("project", {}).get("scripts", {}).keys()) + script_names.discard("_placeholder_") + return script_names + + +def create_folder_structure( + name: str, parent_folder: str | None = None +) -> tuple[Path, str, str]: import keyword import re @@ -51,6 +76,14 @@ def create_folder_structure(name, parent_folder=None): f"Project name '{name}' would generate invalid Python module name '{folder_name}'" ) + reserved_names = get_reserved_script_names() + if folder_name in reserved_names: + raise ValueError( + f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. " + f"Reserved names are: {', '.join(sorted(reserved_names))}. " + "Please choose a different name." + ) + class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name) @@ -114,7 +147,9 @@ def create_folder_structure(name, parent_folder=None): return folder_path, folder_name, class_name -def copy_template_files(folder_path, name, class_name, parent_folder): +def copy_template_files( + folder_path: Path, name: str, class_name: str, parent_folder: str | None +) -> None: package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" @@ -155,7 +190,12 @@ def copy_template_files(folder_path, name, class_name, parent_folder): copy_template(src_file, dst_file, name, class_name, folder_path.name) -def create_crew(name, provider=None, skip_provider=False, parent_folder=None): +def create_crew( + name: str, + provider: str | None = None, + skip_provider: bool = False, + parent_folder: str | None = None, +) -> None: folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_vars = load_env_vars(folder_path) if not skip_provider: @@ -189,7 +229,9 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): if selected_provider is None: # User typed 'q' click.secho("Exiting...", fg="yellow") sys.exit(0) - if selected_provider: # Valid selection + if selected_provider and isinstance( + selected_provider, str + ): # Valid selection break click.secho( "No provider selected. Please try again or press 'q' to exit.", fg="red" diff --git a/lib/crewai/src/crewai/cli/provider.py b/lib/crewai/src/crewai/cli/provider.py index ec6edc0cb..6de337b85 100644 --- a/lib/crewai/src/crewai/cli/provider.py +++ b/lib/crewai/src/crewai/cli/provider.py @@ -1,8 +1,10 @@ from collections import defaultdict +from collections.abc import Sequence import json import os from pathlib import Path import time +from typing import Any import certifi import click @@ -11,16 +13,15 @@ import requests from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS -def select_choice(prompt_message, choices): - """ - Presents a list of choices to the user and prompts them to select one. +def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None: + """Presents a list of choices to the user and prompts them to select one. Args: - - prompt_message (str): The message to display to the user before presenting the choices. - - choices (list): A list of options to present to the user. + prompt_message: The message to display to the user before presenting the choices. + choices: A list of options to present to the user. Returns: - - str: The selected choice from the list, or None if the user chooses to quit. + The selected choice from the list, or None if the user chooses to quit. """ provider_models = get_provider_data() @@ -52,16 +53,14 @@ def select_choice(prompt_message, choices): ) -def select_provider(provider_models): - """ - Presents a list of providers to the user and prompts them to select one. +def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool: + """Presents a list of providers to the user and prompts them to select one. Args: - - provider_models (dict): A dictionary of provider models. + provider_models: A dictionary of provider models. Returns: - - str: The selected provider - - None: If user explicitly quits + The selected provider, None if user explicitly quits, or False if no selection. """ predefined_providers = [p.lower() for p in PROVIDERS] all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) @@ -80,16 +79,15 @@ def select_provider(provider_models): return provider.lower() if provider else False -def select_model(provider, provider_models): - """ - Presents a list of models for a given provider to the user and prompts them to select one. +def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None: + """Presents a list of models for a given provider to the user and prompts them to select one. Args: - - provider (str): The provider for which to select a model. - - provider_models (dict): A dictionary of provider models. + provider: The provider for which to select a model. + provider_models: A dictionary of provider models. Returns: - - str: The selected model, or None if the operation is aborted or an invalid selection is made. + The selected model, or None if the operation is aborted or an invalid selection is made. """ predefined_providers = [p.lower() for p in PROVIDERS] @@ -107,16 +105,17 @@ def select_model(provider, provider_models): ) -def load_provider_data(cache_file, cache_expiry): - """ - Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. +def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None: + """Loads provider data from a cache file if it exists and is not expired. + + If the cache is expired or corrupted, it fetches the data from the web. Args: - - cache_file (Path): The path to the cache file. - - cache_expiry (int): The cache expiry time in seconds. + cache_file: The path to the cache file. + cache_expiry: The cache expiry time in seconds. Returns: - - dict or None: The loaded provider data or None if the operation fails. + The loaded provider data or None if the operation fails. """ current_time = time.time() if ( @@ -137,32 +136,31 @@ def load_provider_data(cache_file, cache_expiry): return fetch_provider_data(cache_file) -def read_cache_file(cache_file): - """ - Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. +def read_cache_file(cache_file: Path) -> dict[str, Any] | None: + """Reads and returns the JSON content from a cache file. Args: - - cache_file (Path): The path to the cache file. + cache_file: The path to the cache file. Returns: - - dict or None: The JSON content of the cache file or None if the JSON is invalid. + The JSON content of the cache file or None if the JSON is invalid. """ try: with open(cache_file, "r") as f: - return json.load(f) + data: dict[str, Any] = json.load(f) + return data except json.JSONDecodeError: return None -def fetch_provider_data(cache_file): - """ - Fetches provider data from a specified URL and caches it to a file. +def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None: + """Fetches provider data from a specified URL and caches it to a file. Args: - - cache_file (Path): The path to the cache file. + cache_file: The path to the cache file. Returns: - - dict or None: The fetched provider data or None if the operation fails. + The fetched provider data or None if the operation fails. """ ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where() @@ -180,36 +178,39 @@ def fetch_provider_data(cache_file): return None -def download_data(response): - """ - Downloads data from a given HTTP response and returns the JSON content. +def download_data(response: requests.Response) -> dict[str, Any]: + """Downloads data from a given HTTP response and returns the JSON content. Args: - - response (requests.Response): The HTTP response object. + response: The HTTP response object. Returns: - - dict: The JSON content of the response. + The JSON content of the response. """ total_size = int(response.headers.get("content-length", 0)) block_size = 8192 - data_chunks = [] + data_chunks: list[bytes] = [] + bar: Any with click.progressbar( length=total_size, label="Downloading", show_pos=True - ) as progress_bar: + ) as bar: for chunk in response.iter_content(block_size): if chunk: data_chunks.append(chunk) - progress_bar.update(len(chunk)) + bar.update(len(chunk)) data_content = b"".join(data_chunks) - return json.loads(data_content.decode("utf-8")) + result: dict[str, Any] = json.loads(data_content.decode("utf-8")) + return result -def get_provider_data(): - """ - Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. +def get_provider_data() -> dict[str, list[str]] | None: + """Retrieves provider data from a cache file. + + Filters out models based on provider criteria, and returns a dictionary of providers + mapped to their models. Returns: - - dict or None: A dictionary of providers mapped to their models or None if the operation fails. + A dictionary of providers mapped to their models or None if the operation fails. """ cache_dir = Path.home() / ".crewai" cache_dir.mkdir(exist_ok=True) diff --git a/lib/crewai/tests/cli/test_create_crew.py b/lib/crewai/tests/cli/test_create_crew.py index 638be9b5d..478372f7f 100644 --- a/lib/crewai/tests/cli/test_create_crew.py +++ b/lib/crewai/tests/cli/test_create_crew.py @@ -296,6 +296,23 @@ def test_create_folder_structure_folder_name_validation(): shutil.rmtree(folder_path) +def test_create_folder_structure_rejects_reserved_names(): + """Test that reserved script names are rejected to prevent pyproject.toml conflicts.""" + with tempfile.TemporaryDirectory() as temp_dir: + reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"] + + for reserved_name in reserved_names: + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(reserved_name, parent_folder=temp_dir) + + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir) + + capitalized = reserved_name.capitalize() + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(capitalized, parent_folder=temp_dir) + + @mock.patch("crewai.cli.create_crew.create_folder_structure") @mock.patch("crewai.cli.create_crew.copy_template") @mock.patch("crewai.cli.create_crew.load_env_vars")