fix: reject reserved script names for crew folders

This commit is contained in:
Greyson LaLonde
2026-02-03 09:16:55 -05:00
committed by GitHub
parent 6a8483fcb6
commit c1d2801be2
4 changed files with 118 additions and 55 deletions

View File

@@ -1,10 +1,13 @@
from typing import Any
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS = {
ENV_VARS: dict[str, list[dict[str, Any]]] = {
"openai": [
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
@@ -112,7 +115,7 @@ ENV_VARS = {
}
PROVIDERS = [
PROVIDERS: list[str] = [
"openai",
"anthropic",
"gemini",
@@ -127,7 +130,7 @@ PROVIDERS = [
"sambanova",
]
MODELS = {
MODELS: dict[str, list[str]] = {
"openai": [
"gpt-4",
"gpt-4.1",

View File

@@ -3,6 +3,7 @@ import shutil
import sys
import click
import tomli
from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import (
@@ -13,7 +14,31 @@ from crewai.cli.provider import (
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None):
def get_reserved_script_names() -> set[str]:
"""Get reserved script names from pyproject.toml template.
Returns:
Set of reserved script names that would conflict with crew folder names.
"""
package_dir = Path(__file__).parent
template_path = package_dir / "templates" / "crew" / "pyproject.toml"
with open(template_path, "r") as f:
template_content = f.read()
template_content = template_content.replace("{{folder_name}}", "_placeholder_")
template_content = template_content.replace("{{name}}", "placeholder")
template_content = template_content.replace("{{crew_name}}", "Placeholder")
template_data = tomli.loads(template_content)
script_names = set(template_data.get("project", {}).get("scripts", {}).keys())
script_names.discard("_placeholder_")
return script_names
def create_folder_structure(
name: str, parent_folder: str | None = None
) -> tuple[Path, str, str]:
import keyword
import re
@@ -51,6 +76,14 @@ def create_folder_structure(name, parent_folder=None):
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
)
reserved_names = get_reserved_script_names()
if folder_name in reserved_names:
raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. "
f"Reserved names are: {', '.join(sorted(reserved_names))}. "
"Please choose a different name."
)
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
@@ -114,7 +147,9 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder):
def copy_template_files(
folder_path: Path, name: str, class_name: str, parent_folder: str | None
) -> None:
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"
@@ -155,7 +190,12 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
def create_crew(
name: str,
provider: str | None = None,
skip_provider: bool = False,
parent_folder: str | None = None,
) -> None:
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
@@ -189,7 +229,9 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider: # Valid selection
if selected_provider and isinstance(
selected_provider, str
): # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"

View File

@@ -1,8 +1,10 @@
from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any
import certifi
import click
@@ -11,16 +13,15 @@ import requests
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices):
"""
Presents a list of choices to the user and prompts them to select one.
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
"""Presents a list of choices to the user and prompts them to select one.
Args:
- prompt_message (str): The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user.
prompt_message: The message to display to the user before presenting the choices.
choices: A list of options to present to the user.
Returns:
- str: The selected choice from the list, or None if the user chooses to quit.
The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
@@ -52,16 +53,14 @@ def select_choice(prompt_message, choices):
)
def select_provider(provider_models):
"""
Presents a list of providers to the user and prompts them to select one.
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
"""Presents a list of providers to the user and prompts them to select one.
Args:
- provider_models (dict): A dictionary of provider models.
provider_models: A dictionary of provider models.
Returns:
- str: The selected provider
- None: If user explicitly quits
The selected provider, None if user explicitly quits, or False if no selection.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
@@ -80,16 +79,15 @@ def select_provider(provider_models):
return provider.lower() if provider else False
def select_model(provider, provider_models):
"""
Presents a list of models for a given provider to the user and prompts them to select one.
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
"""Presents a list of models for a given provider to the user and prompts them to select one.
Args:
- provider (str): The provider for which to select a model.
- provider_models (dict): A dictionary of provider models.
provider: The provider for which to select a model.
provider_models: A dictionary of provider models.
Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
@@ -107,16 +105,17 @@ def select_model(provider, provider_models):
)
def load_provider_data(cache_file, cache_expiry):
"""
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
"""Loads provider data from a cache file if it exists and is not expired.
If the cache is expired or corrupted, it fetches the data from the web.
Args:
- cache_file (Path): The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds.
cache_file: The path to the cache file.
cache_expiry: The cache expiry time in seconds.
Returns:
- dict or None: The loaded provider data or None if the operation fails.
The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
@@ -137,32 +136,31 @@ def load_provider_data(cache_file, cache_expiry):
return fetch_provider_data(cache_file)
def read_cache_file(cache_file):
"""
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
"""Reads and returns the JSON content from a cache file.
Args:
- cache_file (Path): The path to the cache file.
cache_file: The path to the cache file.
Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
return json.load(f)
data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file):
"""
Fetches provider data from a specified URL and caches it to a file.
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
"""Fetches provider data from a specified URL and caches it to a file.
Args:
- cache_file (Path): The path to the cache file.
cache_file: The path to the cache file.
Returns:
- dict or None: The fetched provider data or None if the operation fails.
The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
@@ -180,36 +178,39 @@ def fetch_provider_data(cache_file):
return None
def download_data(response):
"""
Downloads data from a given HTTP response and returns the JSON content.
def download_data(response: requests.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
- response (requests.Response): The HTTP response object.
response: The HTTP response object.
Returns:
- dict: The JSON content of the response.
The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks = []
data_chunks: list[bytes] = []
bar: Any
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
) as bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
bar.update(len(chunk))
data_content = b"".join(data_chunks)
return json.loads(data_content.decode("utf-8"))
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result
def get_provider_data():
"""
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
def get_provider_data() -> dict[str, list[str]] | None:
"""Retrieves provider data from a cache file.
Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.
Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)

View File

@@ -296,6 +296,23 @@ def test_create_folder_structure_folder_name_validation():
shutil.rmtree(folder_path)
def test_create_folder_structure_rejects_reserved_names():
"""Test that reserved script names are rejected to prevent pyproject.toml conflicts."""
with tempfile.TemporaryDirectory() as temp_dir:
reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"]
for reserved_name in reserved_names:
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(reserved_name, parent_folder=temp_dir)
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir)
capitalized = reserved_name.capitalize()
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(capitalized, parent_folder=temp_dir)
@mock.patch("crewai.cli.create_crew.create_folder_structure")
@mock.patch("crewai.cli.create_crew.copy_template")
@mock.patch("crewai.cli.create_crew.load_env_vars")