fix(backend): Convert pyclamd to aioclamd for anti-virus scan concurrency improvement (#10258)

Currently, we are using PyClamd to run a file anti-virus scan for all
the files uploaded into the platform. We split the file into small
chunks and serially check the chunks for the virus scan. The socket is
not thread-safe, and we need to create multiple sockets across many
threads to leverage concurrency. To make this step concurrent and keep
it fully async, we need to migrate PyClamd to aioclamd.

### Changes 🏗️

Convert pyclamd to aioclamd, leverage chunk parallelism scan with a
semaphore limiting the concurrency limit.

#### Side Note
Shout-out to @tedyu for raising this improvement idea.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Execute file upload into the platform
This commit is contained in:
Zamil Majdy
2025-06-30 14:09:30 -07:00
committed by GitHub
parent 89a5ba69e5
commit 9a6ae90d12
5 changed files with 179 additions and 114 deletions

View File

@@ -254,6 +254,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="Whether virus scanning is enabled or not",
)
clamav_max_concurrency: int = Field(
default=10,
description="The maximum number of concurrent scans to perform",
)
clamav_mark_failed_scans_as_clean: bool = Field(
default=False,
description="Whether to mark failed scans as clean or not",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod

View File

@@ -1,9 +1,10 @@
import asyncio
import io
import logging
import time
from typing import Optional, Tuple
import pyclamd
import aioclamd
from pydantic import BaseModel
from pydantic_settings import BaseSettings
@@ -21,37 +22,47 @@ class VirusScanResult(BaseModel):
class VirusScannerSettings(BaseSettings):
# Tunables for the scanner layer (NOT the ClamAV daemon).
clamav_service_host: str = "localhost"
clamav_service_port: int = 3310
clamav_service_timeout: int = 60
clamav_service_enabled: bool = True
max_scan_size: int = 100 * 1024 * 1024 # 100 MB
chunk_size: int = 25 * 1024 * 1024 # 25 MB (safe for 50MB stream limit)
min_chunk_size: int = 128 * 1024 # 128 KB minimum
max_retries: int = 8 # halve chunk ≤ max_retries times
# If the service is disabled, all files are considered clean.
mark_failed_scans_as_clean: bool = False
# Client-side protective limits
max_scan_size: int = 2 * 1024 * 1024 * 1024 # 2 GB guard-rail in memory
min_chunk_size: int = 128 * 1024 # 128 KB hard floor
max_retries: int = 8 # halve ≤ max_retries times
# Concurrency throttle toward the ClamAV daemon. Do *NOT* simply turn this
# up to the number of CPU cores; keep it ≤ (MaxThreads / pods) 1.
max_concurrency: int = 5
class VirusScannerService:
"""
Thin async wrapper around ClamAV. Creates a fresh `ClamdNetworkSocket`
per chunk (the class is *not* thread-safe) and falls back to smaller
chunks when the daemon rejects the stream size.
"""Fully-async ClamAV wrapper using **aioclamd**.
• Reuses a single `ClamdAsyncClient` connection (aioclamd keeps the socket open).
• Throttles concurrent `INSTREAM` calls with an `asyncio.Semaphore` so we don't exhaust daemon worker threads or file descriptors.
• Falls back to progressively smaller chunk sizes when the daemon rejects a stream as too large.
"""
def __init__(self, settings: VirusScannerSettings) -> None:
self.settings = settings
def _new_client(self) -> pyclamd.ClamdNetworkSocket:
return pyclamd.ClamdNetworkSocket(
host=self.settings.clamav_service_host,
port=self.settings.clamav_service_port,
timeout=self.settings.clamav_service_timeout,
self._client = aioclamd.ClamdAsyncClient(
host=settings.clamav_service_host,
port=settings.clamav_service_port,
timeout=settings.clamav_service_timeout,
)
self._sem = asyncio.Semaphore(settings.max_concurrency)
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _parse_raw(raw: Optional[dict]) -> Tuple[bool, Optional[str]]:
"""
Convert pyclamd output to (infected?, threat_name).
Convert aioclamd output to (infected?, threat_name).
Returns (False, None) for clean.
"""
if not raw:
@@ -59,24 +70,22 @@ class VirusScannerService:
status, threat = next(iter(raw.values()))
return status == "FOUND", threat
async def _scan_chunk(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
loop = asyncio.get_running_loop()
client = self._new_client()
try:
raw = await loop.run_in_executor(None, client.scan_stream, chunk)
return self._parse_raw(raw)
# ClamAV aborts the socket when >StreamMaxLength → BrokenPipe/Reset.
except (BrokenPipeError, ConnectionResetError) as exc:
raise RuntimeError("size-limit") from exc
except Exception as exc:
if "INSTREAM size limit exceeded" in str(exc):
async def _instream(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
"""Scan **one** chunk with concurrency control."""
async with self._sem:
try:
raw = await self._client.instream(io.BytesIO(chunk))
return self._parse_raw(raw)
except (BrokenPipeError, ConnectionResetError) as exc:
raise RuntimeError("size-limit") from exc
raise
except Exception as exc:
if "INSTREAM size limit exceeded" in str(exc):
raise RuntimeError("size-limit") from exc
raise
# --------------------------------------------------------------------- #
# ------------------------------------------------------------------ #
# Public API
# --------------------------------------------------------------------- #
# ------------------------------------------------------------------ #
async def scan_file(
self, content: bytes, *, filename: str = "unknown"
@@ -84,81 +93,74 @@ class VirusScannerService:
"""
Scan `content`. Returns a result object or raises on infrastructure
failure (unreachable daemon, etc.).
The algorithm always tries whole-file first. If the daemon refuses
on size grounds, it falls back to chunked parallel scanning.
"""
if not self.settings.clamav_service_enabled:
logger.warning("Virus scanning disabled accepting %s", filename)
logger.warning(f"Virus scanning disabled accepting {filename}")
return VirusScanResult(
is_clean=True, scan_time_ms=0, file_size=len(content)
)
if len(content) > self.settings.max_scan_size:
logger.warning(
f"File {filename} ({len(content)} bytes) exceeds max scan size ({self.settings.max_scan_size}), skipping virus scan"
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
)
return VirusScanResult(
is_clean=True, # Assume clean for oversized files
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=0,
threat_name=None,
)
loop = asyncio.get_running_loop()
if not await loop.run_in_executor(None, self._new_client().ping):
# Ensure daemon is reachable (small RTT check)
if not await self._client.ping():
raise RuntimeError("ClamAV service is unreachable")
start = time.monotonic()
chunk_size = self.settings.chunk_size
for retry in range(self.settings.max_retries + 1):
chunk_size = len(content) # Start with full content length
for retry in range(self.settings.max_retries):
# For small files, don't check min_chunk_size limit
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
break
logger.debug(
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576} MB (retry {retry + 1}/{self.settings.max_retries})"
)
try:
logger.debug(
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576}MB"
)
# Scan all chunks with current chunk size
for offset in range(0, len(content), chunk_size):
chunk_data = content[offset : offset + chunk_size]
infected, threat = await self._scan_chunk(chunk_data)
tasks = [
asyncio.create_task(self._instream(content[o : o + chunk_size]))
for o in range(0, len(content), chunk_size)
]
for coro in asyncio.as_completed(tasks):
infected, threat = await coro
if infected:
for t in tasks:
if not t.done():
t.cancel()
return VirusScanResult(
is_clean=False,
threat_name=threat,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
# All chunks clean
return VirusScanResult(
is_clean=True,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
except RuntimeError as exc:
if (
str(exc) == "size-limit"
and chunk_size > self.settings.min_chunk_size
):
if str(exc) == "size-limit":
chunk_size //= 2
logger.info(
f"Chunk size too large for {filename}, reducing to {chunk_size // 1_048_576}MB (retry {retry + 1}/{self.settings.max_retries + 1})"
)
continue
else:
# Either not a size-limit error, or we've hit minimum chunk size
logger.error(f"Cannot scan {filename}: {exc}")
raise
# If we can't scan even with minimum chunk size, log warning and allow file
logger.error(f"Cannot scan {filename}: {exc}")
raise
# Phase 3 give up but warn
logger.warning(
f"Unable to virus scan {filename} ({len(content)} bytes) - chunk size limits exceeded. "
f"Allowing file but recommend manual review."
f"Unable to virus scan {filename} ({len(content)} bytes) even with minimum chunk size ({self.settings.min_chunk_size} bytes). Recommend manual review."
)
return VirusScanResult(
is_clean=True, # Allow file when scanning impossible
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
threat_name=None,
)
@@ -172,6 +174,8 @@ def get_virus_scanner() -> VirusScannerService:
clamav_service_host=settings.config.clamav_service_host,
clamav_service_port=settings.config.clamav_service_port,
clamav_service_enabled=settings.config.clamav_service_enabled,
max_concurrency=settings.config.clamav_max_concurrency,
mark_failed_scans_as_clean=settings.config.clamav_mark_failed_scans_as_clean,
)
_scanner = VirusScannerService(_settings)
return _scanner

View File

@@ -1,5 +1,4 @@
import asyncio
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -22,6 +21,7 @@ class TestVirusScannerService:
clamav_service_port=3310,
clamav_service_enabled=True,
max_scan_size=10 * 1024 * 1024, # 10MB for testing
mark_failed_scans_as_clean=False, # For testing, failed scans should be clean
)
@pytest.fixture
@@ -54,25 +54,51 @@ class TestVirusScannerService:
# Create content larger than max_scan_size
large_content = b"x" * (scanner.settings.max_scan_size + 1)
# Large files are allowed but marked as clean with a warning
# Large files behavior depends on mark_failed_scans_as_clean setting
result = await scanner.scan_file(large_content, filename="large_file.txt")
assert result.is_clean is True
assert result.is_clean == scanner.settings.mark_failed_scans_as_clean
assert result.file_size == len(large_content)
assert result.scan_time_ms == 0
@pytest.mark.asyncio
async def test_scan_file_too_large_both_configurations(self):
"""Test large file handling with both mark_failed_scans_as_clean configurations"""
large_content = b"x" * (10 * 1024 * 1024 + 1) # Larger than 10MB
# Test with mark_failed_scans_as_clean=True
settings_clean = VirusScannerSettings(
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=True
)
scanner_clean = VirusScannerService(settings_clean)
result_clean = await scanner_clean.scan_file(
large_content, filename="large_file.txt"
)
assert result_clean.is_clean is True
# Test with mark_failed_scans_as_clean=False
settings_dirty = VirusScannerSettings(
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=False
)
scanner_dirty = VirusScannerService(settings_dirty)
result_dirty = await scanner_dirty.scan_file(
large_content, filename="large_file.txt"
)
assert result_dirty.is_clean is False
# Note: ping method was removed from current implementation
@pytest.mark.asyncio
@patch("pyclamd.ClamdNetworkSocket")
async def test_scan_clean_file(self, mock_clamav_class, scanner):
def mock_scan_stream(_):
time.sleep(0.001) # Small delay to ensure timing > 0
async def test_scan_clean_file(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return None # No virus detected
mock_client = Mock()
mock_client.ping.return_value = True
mock_client.scan_stream = mock_scan_stream
mock_clamav_class.return_value = mock_client
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"clean file content"
result = await scanner.scan_file(content, filename="clean.txt")
@@ -83,16 +109,17 @@ class TestVirusScannerService:
assert result.scan_time_ms > 0
@pytest.mark.asyncio
@patch("pyclamd.ClamdNetworkSocket")
async def test_scan_infected_file(self, mock_clamav_class, scanner):
def mock_scan_stream(_):
time.sleep(0.001) # Small delay to ensure timing > 0
async def test_scan_infected_file(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return {"stream": ("FOUND", "Win.Test.EICAR_HDB-1")}
mock_client = Mock()
mock_client.ping.return_value = True
mock_client.scan_stream = mock_scan_stream
mock_clamav_class.return_value = mock_client
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"infected file content"
result = await scanner.scan_file(content, filename="infected.txt")
@@ -103,11 +130,12 @@ class TestVirusScannerService:
assert result.scan_time_ms > 0
@pytest.mark.asyncio
@patch("pyclamd.ClamdNetworkSocket")
async def test_scan_clamav_unavailable_fail_safe(self, mock_clamav_class, scanner):
async def test_scan_clamav_unavailable_fail_safe(self, scanner):
mock_client = Mock()
mock_client.ping.return_value = False
mock_clamav_class.return_value = mock_client
mock_client.ping = AsyncMock(return_value=False)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"test content"
@@ -115,12 +143,13 @@ class TestVirusScannerService:
await scanner.scan_file(content, filename="test.txt")
@pytest.mark.asyncio
@patch("pyclamd.ClamdNetworkSocket")
async def test_scan_error_fail_safe(self, mock_clamav_class, scanner):
async def test_scan_error_fail_safe(self, scanner):
mock_client = Mock()
mock_client.ping.return_value = True
mock_client.scan_stream.side_effect = Exception("Scanning error")
mock_clamav_class.return_value = mock_client
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=Exception("Scanning error"))
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"test content"
@@ -150,16 +179,17 @@ class TestVirusScannerService:
assert result.file_size == 1024
@pytest.mark.asyncio
@patch("pyclamd.ClamdNetworkSocket")
async def test_concurrent_scans(self, mock_clamav_class, scanner):
def mock_scan_stream(_):
time.sleep(0.001) # Small delay to ensure timing > 0
async def test_concurrent_scans(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return None
mock_client = Mock()
mock_client.ping.return_value = True
mock_client.scan_stream = mock_scan_stream
mock_clamav_class.return_value = mock_client
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content1 = b"file1 content"
content2 = b"file2 content"

View File

@@ -17,6 +17,18 @@ aiormq = ">=6.8,<6.9"
exceptiongroup = ">=1,<2"
yarl = "*"
[[package]]
name = "aioclamd"
version = "1.0.0"
description = "Asynchronous client for virus scanning with ClamAV"
optional = false
python-versions = ">=3.7,<4.0"
groups = ["main"]
files = [
{file = "aioclamd-1.0.0-py3-none-any.whl", hash = "sha256:4727da3953a4b38be4c2de1acb6b3bb3c94c1c171dcac780b80234ee6253f3d9"},
{file = "aioclamd-1.0.0.tar.gz", hash = "sha256:7b14e94e3a2285cc89e2f4d434e2a01f322d3cb95476ce2dda015a7980876047"},
]
[[package]]
name = "aiodns"
version = "3.4.0"
@@ -3856,17 +3868,6 @@ cffi = ">=1.5.0"
[package.extras]
idna = ["idna (>=2.1)"]
[[package]]
name = "pyclamd"
version = "0.4.0"
description = "pyClamd is a python interface to Clamd (Clamav daemon)."
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "pyClamd-0.4.0.tar.gz", hash = "sha256:ddd588577e5db123760b6ddaac46b5c4b1d9044a00b5d9422de59f83a55c20fe"},
]
[[package]]
name = "pycodestyle"
version = "2.13.0"
@@ -5017,6 +5018,27 @@ statsig = ["statsig (>=0.55.3)"]
tornado = ["tornado (>=6)"]
unleash = ["UnleashClient (>=6.0.1)"]
[[package]]
name = "setuptools"
version = "80.9.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "sgmllib3k"
version = "1.0.0"
@@ -6380,4 +6402,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "bd117a21d817a2a735ed923c383713dd08469938ef5f7d07c4222da1acca2b5c"
content-hash = "b5c1201f27ee8d05d5d8c89702123df4293f124301d1aef7451591a351872260"

View File

@@ -68,8 +68,9 @@ zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0"
pyclamd = "^0.4.0"
tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"