mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
When new users sign up, check if they previously filled out the Tally beta application form and, if so, pre-populate their CoPilotUnderstanding with business data extracted from that form. This gives the CoPilot (Otto) immediate context about the user on their very first chat interaction.

### Changes 🏗️

- **`backend/util/settings.py`**: Added `tally_api_key` to `Secrets` class
- **`backend/.env.default`**: Added `TALLY_API_KEY=` env var entry
- **`backend/data/tally.py`** (new): Core Tally integration module
  - Redis-cached email index of form submissions (1h TTL) with incremental refresh via `startDate`
  - Paginated Tally API fetching with Bearer token auth
  - Email matching (case-insensitive) against submission data
  - LLM extraction (gpt-4o-mini via OpenRouter) of `BusinessUnderstandingInput` fields
  - Fire-and-forget orchestrator that is idempotent and never raises
- **`backend/api/features/v1.py`**: Added background task in `get_or_create_user_route` to trigger Tally lookup on login (skips if understanding already exists)
- **`backend/data/tally_test.py`** (new): 15 unit tests covering index building, email case-insensitivity, cache hit/miss, format helpers, idempotency, graceful degradation, and error resilience

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All 15 unit tests pass (`poetry run pytest backend/data/tally_test.py --noconftest -xvs`)
  - [x] Lint clean (`poetry run ruff check` on changed files)
  - [x] Type check clean (`poetry run pyright` on new files)
  - [ ] Manual: Set `TALLY_API_KEY` in `.env`, create a new user, verify CoPilotUnderstanding is populated
  - [ ] Manual: Verify user creation succeeds when Tally API key is missing or API is down

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my changes
- [x] I have included a list of my configuration changes in the PR description (under **Changes**)
  - Added `TALLY_API_KEY=` to `.env.default` (optional, empty by default; the feature is a no-op without it)

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

This PR adds a Tally form integration that pre-populates `CoPilotUnderstanding` for new users by matching their signup email against cached Tally beta application form submissions, then using an LLM (gpt-4o-mini via OpenRouter) to extract structured business data.

- **New module `tally.py`** implements Redis-cached email indexing of Tally form submissions with incremental refresh, email matching, LLM extraction, and an idempotent fire-and-forget orchestrator
- **`v1.py`** adds a background task on the `get_or_create_user_route` to trigger Tally lookup on every login (idempotency check is inside the called function)
- **`settings.py` / `.env.default`** adds `tally_api_key` as an optional secret; the feature is a no-op without it
- **`tally_test.py`** adds 15 unit tests with thorough mocking coverage
- **Bug: TTL mismatch**: `_LAST_FETCH_TTL` (2h) > `_INDEX_TTL` (1h) creates a window where incremental refresh loses all previously indexed emails because the base index has expired but `last_fetch` persists. This will cause silent data loss for users whose form submissions were indexed before the cache expiry
- **Bug: `str.format()` on LLM prompt**: form data containing `{` or `}` will crash the prompt formatting, silently preventing understanding population for those users
</details>

<details><summary><h3>Confidence Score: 2/5</h3></summary>

- This PR has two logic bugs that will cause silent data loss in production; recommend fixing before merge.
- The TTL mismatch between `_LAST_FETCH_TTL` and `_INDEX_TTL` will intermittently cause incomplete caches, silently dropping users from the email index. The `str.format()` issue will cause failures for any form submission containing curly braces. Both bugs are caught by the top-level exception handler, so they won't crash the service, but they will silently prevent the feature from working correctly for affected users. The overall architecture is sound and well-tested for normal paths.
- `autogpt_platform/backend/backend/data/tally.py`: contains both the TTL mismatch bug in `_refresh_cache` and the `str.format()` issue in `extract_business_understanding`
</details>

<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant API as v1.py (get_or_create_user_route)
    participant Tally as tally.py (populate_understanding_from_tally)
    participant DB as Database (understanding)
    participant Redis
    participant TallyAPI as Tally API
    participant LLM as OpenRouter (gpt-4o-mini)

    User->>API: POST /auth/user (JWT)
    API->>API: get_or_create_user(user_data)
    API-->>User: Return user (immediate)
    API->>Tally: asyncio.create_task(populate_understanding_from_tally)
    Tally->>DB: get_business_understanding(user_id)
    alt Understanding exists
        DB-->>Tally: existing understanding
        Note over Tally: Skip (idempotent)
    else No understanding
        DB-->>Tally: None
        Tally->>Tally: Check tally_api_key configured
        Tally->>Redis: Check cached email index
        alt Cache hit
            Redis-->>Tally: email_index + questions
        else Cache miss
            Redis-->>Tally: None
            Tally->>TallyAPI: GET /forms/{id}/submissions (paginated)
            TallyAPI-->>Tally: submissions + questions
            Tally->>Tally: Build email index
            Tally->>Redis: Cache index (1h TTL)
        end
        Tally->>Tally: Lookup email in index
        alt Email found
            Tally->>Tally: format_submission_for_llm()
            Tally->>LLM: Extract BusinessUnderstandingInput
            LLM-->>Tally: JSON structured data
            Tally->>DB: upsert_business_understanding(user_id, input)
        end
    end
```

</details>

<sub>Last reviewed commit: 92d2da4</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Otto (AGPT) <otto@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
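
The TTL mismatch described in the review summary can be illustrated with a small, self-contained simulation. This is only a sketch under stated assumptions, not the code in `tally.py`: `FakeCache`, `refresh`, `demo`, and the key names are hypothetical stand-ins, and only the two TTL values (1h for the index, 2h for the last-fetch marker) come from the description above. It shows how an incremental refresh that trusts a surviving `last_fetch` marker after the index has expired rebuilds the index from only the newest submissions, and how falling back to a full fetch when the index is missing avoids that.

```python
from __future__ import annotations

INDEX_TTL = 3600        # 1h, as stated in the PR description
LAST_FETCH_TTL = 7200   # 2h, as stated in the review summary
INDEX_KEY = "index"             # hypothetical key name
LAST_FETCH_KEY = "last_fetch"   # hypothetical key name


class FakeCache:
    """In-memory stand-in for Redis with per-key expiry and a manual clock."""

    def __init__(self) -> None:
        self.now = 0
        self._store: dict[str, tuple[object, int]] = {}

    def set(self, key: str, value: object, ttl: int) -> None:
        self._store[key] = (value, self.now + ttl)

    def get(self, key: str):
        entry = self._store.get(key)
        if entry is None or entry[1] <= self.now:
            return None  # missing or expired
        return entry[0]


def refresh(cache: FakeCache, submissions: list[tuple[int, str]], safe: bool) -> dict:
    """Incrementally index (timestamp, email) pairs newer than last_fetch.

    With safe=False the surviving last_fetch marker is trusted even when the
    index itself has expired. With safe=True a missing index forces a full fetch.
    """
    index = cache.get(INDEX_KEY)
    last_fetch = cache.get(LAST_FETCH_KEY)
    if index is None:
        index = {}
        if safe:
            last_fetch = None  # no base index, so ignore the marker
    since = -1 if last_fetch is None else last_fetch
    for submitted_at, email in submissions:
        if submitted_at > since:
            index[email.lower()] = submitted_at
    cache.set(INDEX_KEY, index, INDEX_TTL)
    cache.set(LAST_FETCH_KEY, cache.now, LAST_FETCH_TTL)
    return index


def demo(safe: bool) -> dict:
    cache = FakeCache()
    cache.now = 100
    refresh(cache, [(10, "Alice@example.com")], safe)  # initial full fetch
    cache.now += 5400  # 1.5h later: index expired, last_fetch still alive
    return refresh(
        cache, [(10, "Alice@example.com"), (5000, "bob@example.com")], safe
    )
```

Calling `demo(safe=False)` reproduces the reported window: the second refresh only sees submissions newer than the stale marker, so the earlier entry is dropped. One way to close the window, assumed here rather than taken from the PR, is to treat a missing index as a signal for a full fetch, as `demo(safe=True)` does; keeping the marker's TTL no longer than the index TTL would be another option.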
711 lines
21 KiB
Python
import json
from datetime import datetime, timezone
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
import starlette.datastructures
from fastapi import HTTPException, UploadFile
from pytest_snapshot.plugin import Snapshot

from backend.data.credit import AutoTopUpConfig
from backend.data.graph import GraphModel

from .v1 import upload_file, v1_router

app = fastapi.FastAPI()
app.include_router(v1_router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user, setup_test_user):
    """Setup auth overrides for all tests in this module"""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    # setup_test_user fixture already executed and user is created in database
    # It returns the user_id which we don't need to await

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


# Auth endpoints tests
def test_get_or_create_user_route(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test get or create user endpoint"""
    mock_user = Mock()
    mock_user.created_at = datetime.now(timezone.utc)
    mock_user.model_dump.return_value = {
        "id": test_user_id,
        "email": "test@example.com",
        "name": "Test User",
    }

    mocker.patch(
        "backend.api.features.v1.get_or_create_user",
        return_value=mock_user,
    )

    response = client.post("/auth/user")

    assert response.status_code == 200
    response_data = response.json()

    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "auth_user",
    )


def test_update_user_email_route(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test update user email endpoint"""
    mocker.patch(
        "backend.api.features.v1.update_user_email",
        return_value=None,
    )

    response = client.post("/auth/user/email", json="newemail@example.com")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["email"] == "newemail@example.com"

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "auth_email",
    )


# Blocks endpoints tests
def test_get_graph_blocks(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test get blocks endpoint"""
    # Mock block
    mock_block = Mock()
    mock_block.to_dict.return_value = {
        "id": "test-block",
        "name": "Test Block",
        "description": "A test block",
        "disabled": False,
    }
    mock_block.id = "test-block"
    mock_block.disabled = False

    # Mock get_blocks
    mocker.patch(
        "backend.api.features.v1.get_blocks",
        return_value={"test-block": lambda: mock_block},
    )

    # Mock block costs
    mocker.patch(
        "backend.data.credit.get_block_cost",
        return_value=[{"cost": 10, "type": "credit"}],
    )

    response = client.get("/blocks")

    assert response.status_code == 200
    response_data = response.json()
    assert len(response_data) == 1
    assert response_data[0]["id"] == "test-block"

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "blks_all",
    )


def test_execute_graph_block(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test execute block endpoint"""
    # Mock block
    mock_block = Mock()
    mock_block.disabled = False

    async def mock_execute(*args, **kwargs):
        yield "output1", {"data": "result1"}
        yield "output2", {"data": "result2"}

    mock_block.execute = mock_execute

    mocker.patch(
        "backend.api.features.v1.get_block",
        return_value=mock_block,
    )

    # Mock user for user_context
    mock_user = Mock()
    mock_user.timezone = "UTC"

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        return_value=mock_user,
    )

    request_data = {
        "input_name": "test_input",
        "input_value": "test_value",
    }

    response = client.post("/blocks/test-block/execute", json=request_data)

    assert response.status_code == 200
    response_data = response.json()

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "blks_exec",
    )


def test_execute_graph_block_not_found(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test execute block with non-existent block"""
    mocker.patch(
        "backend.api.features.v1.get_block",
        return_value=None,
    )

    response = client.post("/blocks/nonexistent-block/execute", json={})

    assert response.status_code == 404
    assert "not found" in response.json()["detail"]


# Credits endpoints tests
def test_get_user_credits(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test get user credits endpoint"""
    mock_credit_model = Mock()
    mock_credit_model.get_credits = AsyncMock(return_value=1000)
    mocker.patch(
        "backend.api.features.v1.get_user_credit_model",
        return_value=mock_credit_model,
    )

    response = client.get("/credits")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["credits"] == 1000

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "cred_bal",
    )


def test_request_top_up(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test request top up endpoint"""
    mock_credit_model = Mock()
    mock_credit_model.top_up_intent = AsyncMock(
        return_value="https://checkout.example.com/session123"
    )
    mocker.patch(
        "backend.api.features.v1.get_user_credit_model",
        return_value=mock_credit_model,
    )

    request_data = {"credit_amount": 500}

    response = client.post("/credits", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert "checkout_url" in response_data

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "cred_topup_req",
    )


def test_get_auto_top_up(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test get auto top-up configuration endpoint"""
    mock_config = AutoTopUpConfig(threshold=100, amount=500)

    mocker.patch(
        "backend.api.features.v1.get_auto_top_up",
        return_value=mock_config,
    )

    response = client.get("/credits/auto-top-up")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["threshold"] == 100
    assert response_data["amount"] == 500

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "cred_topup_cfg",
    )


def test_configure_auto_top_up(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
) -> None:
    """Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
    # Mock the set_auto_top_up function to avoid database operations
    mocker.patch(
        "backend.api.features.v1.set_auto_top_up",
        return_value=None,
    )

    # Mock credit model to avoid Stripe API calls
    mock_credit_model = mocker.AsyncMock()
    mock_credit_model.get_credits.return_value = 50  # Current balance below threshold
    mock_credit_model.top_up_credits.return_value = None

    mocker.patch(
        "backend.api.features.v1.get_user_credit_model",
        return_value=mock_credit_model,
    )

    # Test data
    request_data = {
        "threshold": 100,
        "amount": 500,
    }

    response = client.post("/credits/auto-top-up", json=request_data)

    # This should succeed with our fix, but would have failed before with the enum casting error
    assert response.status_code == 200
    assert response.json() == "Auto top-up settings updated"


def test_configure_auto_top_up_validation_errors(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test configure auto top-up endpoint validation"""
    # Mock set_auto_top_up to avoid database operations for successful case
    mocker.patch("backend.api.features.v1.set_auto_top_up")

    # Mock credit model to avoid Stripe API calls for the successful case
    mock_credit_model = mocker.AsyncMock()
    mock_credit_model.get_credits.return_value = 50
    mock_credit_model.top_up_credits.return_value = None

    mocker.patch(
        "backend.api.features.v1.get_user_credit_model",
        return_value=mock_credit_model,
    )

    # Test negative threshold
    response = client.post(
        "/credits/auto-top-up", json={"threshold": -1, "amount": 500}
    )
    assert response.status_code == 422  # Validation error

    # Test amount too small (but not 0)
    response = client.post(
        "/credits/auto-top-up", json={"threshold": 100, "amount": 100}
    )
    assert response.status_code == 422  # Validation error

    # Test amount = 0 (should be allowed)
    response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
    assert response.status_code == 200  # Should succeed


# Graphs endpoints tests
def test_get_graphs(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test get graphs endpoint"""
    mock_graph = GraphModel(
        id="graph-123",
        version=1,
        is_active=True,
        name="Test Graph",
        description="A test graph",
        user_id=test_user_id,
        created_at=datetime(2025, 9, 4, 13, 37),
    )

    mocker.patch(
        "backend.data.graph.list_graphs_paginated",
        return_value=Mock(graphs=[mock_graph]),
    )

    response = client.get("/graphs")

    assert response.status_code == 200
    response_data = response.json()
    assert len(response_data) == 1
    assert response_data[0]["id"] == "graph-123"

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "grphs_all",
    )


def test_get_graph(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test get single graph endpoint"""
    mock_graph = GraphModel(
        id="graph-123",
        version=1,
        is_active=True,
        name="Test Graph",
        description="A test graph",
        user_id=test_user_id,
        created_at=datetime(2025, 9, 4, 13, 37),
    )

    mocker.patch(
        "backend.api.features.v1.graph_db.get_graph",
        return_value=mock_graph,
    )

    response = client.get("/graphs/graph-123")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["id"] == "graph-123"

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "grph_single",
    )


def test_get_graph_not_found(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test get graph with non-existent ID"""
    mocker.patch(
        "backend.api.features.v1.graph_db.get_graph",
        return_value=None,
    )

    response = client.get("/graphs/nonexistent-graph")

    assert response.status_code == 404
    assert "not found" in response.json()["detail"]


def test_delete_graph(
    mocker: pytest_mock.MockFixture,
    snapshot: Snapshot,
    test_user_id: str,
) -> None:
    """Test delete graph endpoint"""
    # Mock active graph for deactivation
    mock_graph = GraphModel(
        id="graph-123",
        version=1,
        is_active=True,
        name="Test Graph",
        description="A test graph",
        user_id=test_user_id,
        created_at=datetime(2025, 9, 4, 13, 37),
    )

    mocker.patch(
        "backend.api.features.v1.graph_db.get_graph",
        return_value=mock_graph,
    )
    mocker.patch(
        "backend.api.features.v1.on_graph_deactivate",
        return_value=None,
    )
    mocker.patch(
        "backend.api.features.v1.graph_db.delete_graph",
        return_value=3,  # Number of versions deleted
    )

    response = client.delete("/graphs/graph-123")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["version_counts"] == 3

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "grphs_del",
    )


# Invalid request tests
def test_invalid_json_request() -> None:
    """Test endpoint with invalid JSON"""
    response = client.post(
        "/auth/user/email",
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )
    assert response.status_code == 422


def test_missing_required_field() -> None:
    """Test endpoint with missing required field"""
    response = client.post("/credits", json={})  # Missing credit_amount
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_upload_file_success(test_user_id: str):
    """Test successful file upload."""
    # Create mock upload file
    file_content = b"test file content"
    file_obj = BytesIO(file_content)
    upload_file_mock = UploadFile(
        filename="test.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    # Mock dependencies
    with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
        "backend.api.features.v1.get_cloud_storage_handler"
    ) as mock_handler_getter:

        mock_scan.return_value = None
        mock_handler = AsyncMock()
        mock_handler.store_file.return_value = "gcs://test-bucket/uploads/123/test.txt"
        mock_handler_getter.return_value = mock_handler

        # Mock file.read()
        upload_file_mock.read = AsyncMock(return_value=file_content)

        result = await upload_file(
            file=upload_file_mock,
            user_id=test_user_id,
            provider="gcs",
            expiration_hours=24,
        )

        # Verify result
        assert result.file_uri == "gcs://test-bucket/uploads/123/test.txt"
        assert result.file_name == "test.txt"
        assert result.size == len(file_content)
        assert result.content_type == "text/plain"
        assert result.expires_in_hours == 24

        # Verify virus scan was called
        mock_scan.assert_called_once_with(file_content, filename="test.txt")

        # Verify cloud storage operations
        mock_handler.store_file.assert_called_once_with(
            content=file_content,
            filename="test.txt",
            provider="gcs",
            expiration_hours=24,
            user_id=test_user_id,
        )


@pytest.mark.asyncio
async def test_upload_file_no_filename(test_user_id: str):
    """Test file upload without filename."""
    file_content = b"test content"
    file_obj = BytesIO(file_content)
    upload_file_mock = UploadFile(
        filename=None,
        file=file_obj,
        headers=starlette.datastructures.Headers(
            {"content-type": "application/octet-stream"}
        ),
    )

    with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
        "backend.api.features.v1.get_cloud_storage_handler"
    ) as mock_handler_getter:

        mock_scan.return_value = None
        mock_handler = AsyncMock()
        mock_handler.store_file.return_value = (
            "gcs://test-bucket/uploads/123/uploaded_file"
        )
        mock_handler_getter.return_value = mock_handler

        upload_file_mock.read = AsyncMock(return_value=file_content)

        result = await upload_file(file=upload_file_mock, user_id=test_user_id)

        assert result.file_name == "uploaded_file"
        assert result.content_type == "application/octet-stream"

        # Verify virus scan was called with default filename
        mock_scan.assert_called_once_with(file_content, filename="uploaded_file")


@pytest.mark.asyncio
async def test_upload_file_invalid_expiration(test_user_id: str):
    """Test file upload with invalid expiration hours."""
    file_obj = BytesIO(b"content")
    upload_file_mock = UploadFile(
        filename="test.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    # Test expiration too short
    with pytest.raises(HTTPException) as exc_info:
        await upload_file(
            file=upload_file_mock, user_id=test_user_id, expiration_hours=0
        )
    assert exc_info.value.status_code == 400
    assert "between 1 and 48" in exc_info.value.detail

    # Test expiration too long
    with pytest.raises(HTTPException) as exc_info:
        await upload_file(
            file=upload_file_mock, user_id=test_user_id, expiration_hours=49
        )
    assert exc_info.value.status_code == 400
    assert "between 1 and 48" in exc_info.value.detail


@pytest.mark.asyncio
async def test_upload_file_virus_scan_failure(test_user_id: str):
    """Test file upload when virus scan fails."""
    file_content = b"malicious content"
    file_obj = BytesIO(file_content)
    upload_file_mock = UploadFile(
        filename="virus.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    with patch("backend.api.features.v1.scan_content_safe") as mock_scan:
        # Mock virus scan to raise exception
        mock_scan.side_effect = RuntimeError("Virus detected!")

        upload_file_mock.read = AsyncMock(return_value=file_content)

        with pytest.raises(RuntimeError, match="Virus detected!"):
            await upload_file(file=upload_file_mock, user_id=test_user_id)


@pytest.mark.asyncio
async def test_upload_file_cloud_storage_failure(test_user_id: str):
    """Test file upload when cloud storage fails."""
    file_content = b"test content"
    file_obj = BytesIO(file_content)
    upload_file_mock = UploadFile(
        filename="test.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
        "backend.api.features.v1.get_cloud_storage_handler"
    ) as mock_handler_getter:

        mock_scan.return_value = None
        mock_handler = AsyncMock()
        mock_handler.store_file.side_effect = RuntimeError("Storage error!")
        mock_handler_getter.return_value = mock_handler

        upload_file_mock.read = AsyncMock(return_value=file_content)

        with pytest.raises(RuntimeError, match="Storage error!"):
            await upload_file(file=upload_file_mock, user_id=test_user_id)


@pytest.mark.asyncio
async def test_upload_file_size_limit_exceeded(test_user_id: str):
    """Test file upload when file size exceeds the limit."""
    # Create a file that exceeds the default 256MB limit
    large_file_content = b"x" * (257 * 1024 * 1024)  # 257MB
    file_obj = BytesIO(large_file_content)
    upload_file_mock = UploadFile(
        filename="large_file.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    upload_file_mock.read = AsyncMock(return_value=large_file_content)

    with pytest.raises(HTTPException) as exc_info:
        await upload_file(file=upload_file_mock, user_id=test_user_id)

    assert exc_info.value.status_code == 400
    assert "exceeds the maximum allowed size of 256MB" in exc_info.value.detail


@pytest.mark.asyncio
async def test_upload_file_gcs_not_configured_fallback(test_user_id: str):
    """Test file upload fallback to base64 when GCS is not configured."""
    file_content = b"test file content"
    file_obj = BytesIO(file_content)
    upload_file_mock = UploadFile(
        filename="test.txt",
        file=file_obj,
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    with patch("backend.api.features.v1.scan_content_safe") as mock_scan, patch(
        "backend.api.features.v1.get_cloud_storage_handler"
    ) as mock_handler_getter:

        mock_scan.return_value = None
        mock_handler = AsyncMock()
        mock_handler.config.gcs_bucket_name = ""  # Simulate no GCS bucket configured
        mock_handler_getter.return_value = mock_handler

        upload_file_mock.read = AsyncMock(return_value=file_content)

        result = await upload_file(file=upload_file_mock, user_id=test_user_id)

        # Verify fallback behavior
        assert result.file_name == "test.txt"
        assert result.size == len(file_content)
        assert result.content_type == "text/plain"
        assert result.expires_in_hours == 24

        # Verify file_uri is base64 data URI
        expected_data_uri = "data:text/plain;base64,dGVzdCBmaWxlIGNvbnRlbnQ="
        assert result.file_uri == expected_data_uri

        # Verify virus scan was called
        mock_scan.assert_called_once_with(file_content, filename="test.txt")

        # Verify cloud storage methods were NOT called
        mock_handler.store_file.assert_not_called()