feat(statistics): add token usage tracking to admin dashboard

- Add token_usage columns to usage_logs table via migration
- Implement token aggregation queries in log_crud.py
- Update admin stats endpoint and statistics template
- Clean up proxy.py imports and formatting
Saifeddine ALOUI
2026-02-08 20:35:33 +01:00
parent 1eff6b9183
commit f6d7292f88
7 changed files with 712 additions and 167 deletions

View File

@@ -335,6 +335,8 @@ async def get_system_and_ollama_info(
"queue_status": rate_limits
}
@router.get("/stats", response_class=HTMLResponse, name="admin_stats")
async def admin_stats(
request: Request,
@@ -344,7 +346,7 @@ async def admin_stats(
sort_order: str = Query("desc"),
):
# Whitelist allowed sort values
allowed_sort = ["username", "key_name", "key_prefix", "request_count"]
allowed_sort = ["username", "key_name", "key_prefix", "request_count", "total_tokens", "total_prompt_tokens", "total_completion_tokens"]
if sort_by not in allowed_sort:
sort_by = "request_count"
if sort_order not in ["asc", "desc"]:
@@ -356,9 +358,25 @@ async def admin_stats(
hourly_stats = await log_crud.get_hourly_usage_stats(db)
server_stats = await log_crud.get_server_load_stats(db)
model_stats = await log_crud.get_model_usage_stats(db)
# Calculate token totals for summary
total_prompt_tokens = sum(row.total_prompt_tokens for row in key_usage_stats if hasattr(row, 'total_prompt_tokens'))
total_completion_tokens = sum(row.total_completion_tokens for row in key_usage_stats if hasattr(row, 'total_completion_tokens'))
total_tokens = sum(row.total_tokens for row in key_usage_stats if hasattr(row, 'total_tokens'))
# Prepare token data by model
model_prompt_tokens = [row.total_prompt_tokens for row in model_stats]
model_completion_tokens = [row.total_completion_tokens for row in model_stats]
model_total_tokens = [row.total_tokens for row in model_stats]
context.update({
"key_usage_stats": key_usage_stats,
"daily_labels": [row.date.strftime('%Y-%m-%d') for row in daily_stats],
"daily_labels": [
row.date if isinstance(row.date, str)
else row.date.strftime('%Y-%m-%d') if hasattr(row.date, 'strftime')
else str(row.date)
for row in daily_stats
],
"daily_data": [row.request_count for row in daily_stats],
"hourly_labels": [row['hour'] for row in hourly_stats],
"hourly_data": [row['request_count'] for row in hourly_stats],
@@ -366,11 +384,18 @@ async def admin_stats(
"server_data": [row.request_count for row in server_stats],
"model_labels": [row.model_name for row in model_stats],
"model_data": [row.request_count for row in model_stats],
"model_prompt_tokens": model_prompt_tokens,
"model_completion_tokens": model_completion_tokens,
"model_total_tokens": model_total_tokens,
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"sort_by": sort_by,
"sort_order": sort_order,
})
return templates.TemplateResponse("admin/statistics.html", context)
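The daily_labels normalization above (reused in admin_user_stats further down) guards against the different types the grouped date column can come back as: SQLite tends to return plain strings while other backends return date objects. A minimal sketch of the same logic as a standalone helper — the helper name is illustrative, not part of the codebase:

import datetime

def to_date_label(value) -> str:
    """Normalize a grouped 'date' value to 'YYYY-MM-DD' (illustrative helper)."""
    if isinstance(value, str):
        return value                      # SQLite often hands the string back directly
    if hasattr(value, "strftime"):        # datetime.date / datetime.datetime
        return value.strftime("%Y-%m-%d")
    return str(value)                     # last-resort fallback

assert to_date_label("2026-02-08") == "2026-02-08"
assert to_date_label(datetime.date(2026, 2, 8)) == "2026-02-08"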
@router.get("/help", response_class=HTMLResponse, name="admin_help")
async def admin_help_page(request: Request, admin_user: User = Depends(require_admin_user)):
return templates.TemplateResponse("admin/help.html", get_template_context(request))
@@ -613,7 +638,7 @@ async def admin_pull_model(
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
# Sanitize model name - only allow alphanumeric, dashes, dots, colons
if not re.match(r'^[\w\.\-:]+$', model_name):
if not re.match(r'^[\w\.\-:@]+$', model_name):
flash(request, "Model name contains invalid characters", "error")
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
@@ -658,7 +683,7 @@ async def admin_delete_model(
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
# Sanitize model name
if not re.match(r'^[\w\.\-:]+$', model_name):
if not re.match(r'^[\w\.\-:@]+$', model_name):
flash(request, "Model name contains invalid characters", "error")
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
@@ -700,7 +725,7 @@ async def admin_load_model(
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
# Sanitize model name
if not re.match(r'^[\w\.\-:]+$', model_name):
if not re.match(r'^[\w\.\-:@]+$', model_name):
flash(request, "Model name contains invalid characters", "error")
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
@@ -737,7 +762,7 @@ async def admin_unload_model(
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
# Sanitize model name
if not re.match(r'^[\w\.\-:]+$', model_name):
if not re.match(r'^[\w\.\-:@]+$', model_name):
flash(request, "Model name contains invalid characters", "error")
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
@@ -762,7 +787,7 @@ async def admin_unload_model_dashboard(
server_name: str = Form(...)
):
# Validate and sanitize inputs
if not model_name or len(model_name) > 256 or not re.match(r'^[\w\.\-:]+$', model_name):
if not model_name or len(model_name) > 256 or not re.match(r'^[\w\.\-:@]+$', model_name):
flash(request, "Invalid model name", "error")
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
@@ -797,7 +822,7 @@ async def admin_models_manager_page(
all_model_names = await server_crud.get_all_available_model_names(db)
for model_name in all_model_names:
# Validate model name before processing
if model_name and len(model_name) <= 256 and re.match(r'^[\w\.\-:]+$', model_name):
if model_name and len(model_name) <= 256 and re.match(r'^[\w\.\-:@]+$', model_name):
await model_metadata_crud.get_or_create_metadata(db, model_name=model_name)
context["metadata_list"] = await model_metadata_crud.get_all_metadata(db)
@@ -1311,7 +1336,12 @@ async def admin_user_stats(
context.update({
"user": user,
"daily_labels": [row.date.strftime('%Y-%m-%d') for row in daily_stats],
"daily_labels": [
row.date if isinstance(row.date, str)
else row.date.strftime('%Y-%m-%d') if hasattr(row.date, 'strftime')
else str(row.date)
for row in daily_stats
],
"daily_data": [row.request_count for row in daily_stats],
"hourly_labels": [row['hour'] for row in hourly_stats],
"hourly_data": [row['request_count'] for row in hourly_stats],
@@ -1449,3 +1479,4 @@ async def delete_user_account(request: Request, user_id: int, db: AsyncSession =
await user_crud.delete_user(db, user_id=user_id)
flash(request, f"User '{user.username}' has been deleted.", "success")
return RedirectResponse(url=request.url_for("admin_users"), status_code=status.HTTP_303_SEE_OTHER)

View File

@@ -26,9 +26,8 @@ logger = logging.getLogger(__name__)
router = APIRouter(dependencies=[Depends(ip_filter), Depends(rate_limiter)])
# --- Connection Pool Cache ---
# Cache for server health status to avoid repeated health checks
_server_health_cache: Dict[int, Dict[str, Any]] = {}
_health_cache_ttl_seconds = 5 # Cache health checks for 5 seconds
_health_cache_ttl_seconds = 5
def _is_server_healthy_cached(server_id: int) -> bool:
@@ -38,7 +37,7 @@ def _is_server_healthy_cached(server_id: int) -> bool:
if cache_entry:
if time.time() - cache_entry["timestamp"] < _health_cache_ttl_seconds:
return cache_entry["healthy"]
return True # Default to allowing if no cache (will be checked on request)
return True
def _update_health_cache(server_id: int, healthy: bool):
@@ -50,7 +49,6 @@ def _update_health_cache(server_id: int, healthy: bool):
}
# --- Dependency to get active servers ---
async def get_active_servers(db: AsyncSession = Depends(get_db)) -> List[OllamaServer]:
servers = await server_crud.get_servers(db)
active_servers = [s for s in servers if s.is_active]
@@ -66,27 +64,16 @@ async def get_active_servers(db: AsyncSession = Depends(get_db)) -> List[OllamaS
async def extract_model_from_request(request: Request) -> Optional[str]:
"""
Attempts to extract the model name from the request body.
Common endpoints that contain model info: /api/generate, /api/chat, /api/embeddings, /api/pull, etc.
Returns the model name if found, otherwise None.
"""
try:
# Read the request body
body_bytes = await request.body()
if not body_bytes:
return None
# Parse JSON body
body = json.loads(body_bytes)
# Most Ollama API endpoints use "model" field
if isinstance(body, dict) and "model" in body:
return body["model"]
except (json.JSONDecodeError, UnicodeDecodeError, Exception) as e:
logger.debug(f"Could not extract model from request body: {e}")
return None
@@ -101,30 +88,24 @@ async def _send_backend_request(
):
"""
Internal function to send a single request to a backend server.
This function is wrapped by retry logic.
"""
normalized_url = server.url.rstrip('/')
backend_url = f"{normalized_url}/api/{path}"
request_headers = {}
# Copy headers from original request, but fix critical ones
for k, v in headers.items():
k_lower = k.lower()
# Skip hop-by-hop headers that should not be forwarded
if k_lower in ('host', 'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade'):
continue
# Skip content-length - we'll set it correctly based on actual body
if k_lower == 'content-length':
continue
request_headers[k] = v
# Set correct content-length based on actual body bytes
if body_bytes:
request_headers['content-length'] = str(len(body_bytes))
# Add API key authentication if configured
if server.encrypted_api_key:
api_key = decrypt_data(server.encrypted_api_key)
if api_key:
@@ -141,9 +122,8 @@ async def _send_backend_request(
try:
backend_response = await http_client.send(backend_request, stream=True)
# Consider 5xx errors as failures that should be retried
if backend_response.status_code >= 500:
await backend_response.aclose() # Clean up the response
await backend_response.aclose()
raise Exception(
f"Backend server returned {backend_response.status_code}: "
f"{backend_response.reason_phrase}"
@@ -152,38 +132,90 @@ async def _send_backend_request(
return backend_response
except Exception as e:
# Log and re-raise for retry logic
logger.debug(f"Request to {server.url} failed: {type(e).__name__}: {str(e)[:200]}")
raise
async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer], body_bytes: bytes = b"") -> Tuple[Response, OllamaServer]:
def _extract_tokens_from_chunk(chunk_data: Dict[str, Any]) -> Dict[str, Optional[int]]:
"""Extract token counts from an Ollama response chunk."""
tokens = {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
}
# Ollama format - check various field names
if "prompt_eval_count" in chunk_data:
tokens["prompt_tokens"] = chunk_data.get("prompt_eval_count")
if "prompt_count" in chunk_data:
tokens["prompt_tokens"] = chunk_data.get("prompt_count")
if "eval_count" in chunk_data:
tokens["completion_tokens"] = chunk_data.get("eval_count")
# Calculate total if we have both
if tokens["prompt_tokens"] is not None and tokens["completion_tokens"] is not None:
tokens["total_tokens"] = tokens["prompt_tokens"] + tokens["completion_tokens"]
# vLLM/OpenAI format (translated)
if "usage" in chunk_data and chunk_data["usage"]:
usage = chunk_data["usage"]
if isinstance(usage, dict):
tokens["prompt_tokens"] = usage.get("prompt_tokens")
tokens["completion_tokens"] = usage.get("completion_tokens")
tokens["total_tokens"] = usage.get("total_tokens")
# Final chunk with done=True often has the complete stats
if chunk_data.get("done"):
if "prompt_eval_count" in chunk_data:
tokens["prompt_tokens"] = chunk_data.get("prompt_eval_count")
if "prompt_count" in chunk_data:
tokens["prompt_tokens"] = chunk_data.get("prompt_count")
if "eval_count" in chunk_data:
tokens["completion_tokens"] = chunk_data.get("eval_count")
if tokens["prompt_tokens"] is not None and tokens["completion_tokens"] is not None:
tokens["total_tokens"] = tokens["prompt_tokens"] + tokens["completion_tokens"]
return tokens
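For reference, a quick check of what the extractor returns for the two chunk shapes it targets (the counts are made-up values):

# Final Ollama chunk (done=True) carrying native counters:
ollama_done = {"done": True, "prompt_eval_count": 42, "eval_count": 128}
assert _extract_tokens_from_chunk(ollama_done) == {
    "prompt_tokens": 42, "completion_tokens": 128, "total_tokens": 170,
}

# Translated vLLM/OpenAI-style chunk carrying a usage block:
openai_style = {"usage": {"prompt_tokens": 42, "completion_tokens": 128, "total_tokens": 170}}
assert _extract_tokens_from_chunk(openai_style) == {
    "prompt_tokens": 42, "completion_tokens": 128, "total_tokens": 170,
}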
async def _update_log_with_tokens_async(
log_id: int,
prompt_tokens: Optional[int],
completion_tokens: Optional[int],
total_tokens: Optional[int]
):
"""Fire-and-forget token update."""
try:
from app.database.session import AsyncSessionLocal
async with AsyncSessionLocal() as async_db:
await log_crud.update_usage_log_with_tokens(
async_db, log_id, prompt_tokens, completion_tokens, total_tokens
)
except Exception as e:
logger.debug(f"Failed to update tokens for log {log_id}: {e}")
async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer], body_bytes: bytes = b"",
api_key_id: Optional[int] = None, log_id: Optional[int] = None) -> Tuple[Response, OllamaServer]:
"""
Core reverse proxy logic with retry support. Forwards the request to a backend
Ollama server and streams the response back. Returns the response and the chosen server.
Core reverse proxy logic with retry support and token tracking.
"""
http_client: AsyncClient = request.app.state.http_client
app_settings: AppSettingsModel = request.app.state.settings
# Use retry configuration directly from database settings
# No hardcoded overrides - admins can configure these in the settings UI
retry_config = RetryConfig(
max_retries=app_settings.max_retries,
total_timeout_seconds=app_settings.retry_total_timeout_seconds,
base_delay_ms=app_settings.retry_base_delay_ms
)
# Prepare request headers (exclude 'host' and other hop-by-hop headers)
headers = {k: v for k, v in request.headers.items() if k.lower() not in
('host', 'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade', 'content-length')}
# DEFENSIVE: Log what we received
logger.info(f"_reverse_proxy called with {len(servers)} total server(s), filtering to active...")
# Use a local copy of servers to avoid race conditions with the global list
# Filter to only active servers at request time
# ALSO filter by health cache to skip known-unhealthy servers
candidate_servers = [
s for s in servers
if s.is_active and _is_server_healthy_cached(s.id)
@@ -192,7 +224,6 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
logger.info(f"After filtering: {len(candidate_servers)} active server(s): {[s.name for s in candidate_servers]}")
if not candidate_servers:
# Fallback: try all active servers if health cache filtered everything out
candidate_servers = [s for s in servers if s.is_active]
if not candidate_servers:
logger.error("All candidate servers became inactive during request processing")
@@ -201,13 +232,10 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
detail="No active backend servers available."
)
# OPTIMIZATION: Use a faster selection strategy
# Prefer servers with the lowest recent error rate, then round-robin
if not hasattr(request.app.state, 'backend_server_index'):
request.app.state.backend_server_index = 0
logger.info("Initialized backend_server_index to 0")
# Get current index and increment for next request
current_index = request.app.state.backend_server_index % max(1, len(candidate_servers))
request.app.state.backend_server_index = (current_index + 1) % max(1, len(candidate_servers))
@@ -216,7 +244,6 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
servers_tried = []
for server_attempt in range(len(candidate_servers)):
# Calculate safe index based on current candidate_servers list size
safe_index = (current_index + server_attempt) % len(candidate_servers)
chosen_server = candidate_servers[safe_index]
@@ -224,37 +251,30 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
servers_tried.append(chosen_server.name)
# --- BRANCH: Handle vLLM servers differently ---
if chosen_server.server_type == 'vllm':
logger.info(f"Using vLLM branch for server '{chosen_server.name}'")
try:
# vLLM translation doesn't use the retry logic wrapper in the same way
response = await _proxy_to_vllm(request, chosen_server, path, body_bytes)
_update_health_cache(chosen_server.id, True) # Mark as healthy
response = await _proxy_to_vllm(request, chosen_server, path, body_bytes, api_key_id, log_id)
_update_health_cache(chosen_server.id, True)
return response, chosen_server
except HTTPException:
_update_health_cache(chosen_server.id, False) # Mark as unhealthy
raise # Re-raise HTTP exceptions from the vLLM proxy
_update_health_cache(chosen_server.id, False)
raise
except Exception as e:
logger.warning(f"vLLM server '{chosen_server.name}' failed: {e}. Trying next server.")
_update_health_cache(chosen_server.id, False) # Mark as unhealthy
# Remove failed server from candidates and continue
_update_health_cache(chosen_server.id, False)
candidate_servers = [s for s in candidate_servers if s.id != chosen_server.id]
if not candidate_servers:
logger.error("No more candidate servers after vLLM failure")
break
# Recalculate current_index to stay in bounds with new list size
current_index = safe_index % max(1, len(candidate_servers))
continue
# --- Ollama server logic (with retries) ---
logger.info(f"Using Ollama branch with retry logic for server '{chosen_server.name}'")
# Try direct request first for speed - if it succeeds quickly, no need for retry overhead
first_attempt_start = asyncio.get_event_loop().time()
try:
# Try direct request first without retry wrapper for speed
backend_response = await _send_backend_request(
http_client=http_client,
server=chosen_server,
@@ -265,29 +285,47 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
body_bytes=body_bytes
)
# If we got here quickly, use the response directly
first_attempt_duration = asyncio.get_event_loop().time() - first_attempt_start
if first_attempt_duration < 0.5: # If fast success, skip retry overhead
_update_health_cache(chosen_server.id, True)
response = StreamingResponse(
backend_response.aiter_raw(),
status_code=backend_response.status_code,
headers=backend_response.headers,
)
return response, chosen_server
# If slow but successful, still use it
_update_health_cache(chosen_server.id, True)
response = StreamingResponse(
backend_response.aiter_raw(),
status_code=backend_response.status_code,
headers=backend_response.headers,
)
return response, chosen_server
# Check if this is a streaming response
is_streaming = _is_streaming_response(backend_response)
if is_streaming and log_id:
# Wrap for token tracking
wrapped_response = _wrap_response_for_token_tracking(
backend_response, chosen_server, api_key_id, log_id, path
)
return wrapped_response, chosen_server
else:
# Non-streaming, return as-is (tokens will be extracted if possible)
if log_id and backend_response.status_code == 200:
# Try to extract tokens from non-streaming response
try:
body = await backend_response.aread()
if body:
data = json.loads(body.decode('utf-8'))
tokens = _extract_tokens_from_chunk(data)
if tokens.get("total_tokens") is not None or tokens.get("prompt_tokens") is not None:
asyncio.create_task(_update_log_with_tokens_async(
log_id,
tokens["prompt_tokens"],
tokens["completion_tokens"],
tokens["total_tokens"]
))
# Need to create a new response since we consumed the body
return Response(
content=body,
status_code=backend_response.status_code,
headers=dict(backend_response.headers)
), chosen_server
except Exception:
pass
# Return original response if we couldn't extract tokens
return backend_response, chosen_server
except Exception as first_error:
# First attempt failed, use retry logic with configured settings
_update_health_cache(chosen_server.id, False)
logger.debug(f"Direct attempt failed for '{chosen_server.name}', using retry logic: {first_error}")
@@ -315,12 +353,36 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
f"in {retry_result.total_duration_ms:.1f}ms"
)
response = StreamingResponse(
backend_response.aiter_raw(),
status_code=backend_response.status_code,
headers=backend_response.headers,
)
return response, chosen_server
# Check if streaming
is_streaming = _is_streaming_response(backend_response)
if is_streaming and log_id:
wrapped_response = _wrap_response_for_token_tracking(
backend_response, chosen_server, api_key_id, log_id, path
)
return wrapped_response, chosen_server
else:
if log_id and backend_response.status_code == 200:
try:
body = await backend_response.aread()
if body:
data = json.loads(body.decode('utf-8'))
tokens = _extract_tokens_from_chunk(data)
if tokens.get("total_tokens") is not None or tokens.get("prompt_tokens") is not None:
asyncio.create_task(_update_log_with_tokens_async(
log_id,
tokens["prompt_tokens"],
tokens["completion_tokens"],
tokens["total_tokens"]
))
return Response(
content=body,
status_code=backend_response.status_code,
headers=dict(backend_response.headers)
), chosen_server
except Exception:
pass
return backend_response, chosen_server
else:
_update_health_cache(chosen_server.id, False)
logger.warning(
@@ -328,15 +390,12 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
f"attempts. Trying next server if available."
)
# Remove failed server from candidates
candidate_servers = [s for s in candidate_servers if s.id != chosen_server.id]
if not candidate_servers:
logger.error("No more candidate servers after Ollama failure")
break
# Recalculate current_index to stay in bounds with new list size
current_index = safe_index % max(1, len(candidate_servers))
# All servers exhausted
logger.error(
f"All {len(servers_tried)} backend server(s) failed after retries. "
f"Servers tried: {', '.join(servers_tried)}"
@@ -347,14 +406,145 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
)
def _wrap_response_for_token_tracking(
backend_response: Response,
server: OllamaServer,
api_key_id: Optional[int] = None,
log_id: Optional[int] = None,
path: str = ""
) -> StreamingResponse:
"""Wraps a streaming response to capture token usage from chunks."""
async def token_tracking_stream():
buffer = ""
accumulated_tokens = {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
}
tokens_finalized = False
try:
async for chunk in backend_response.aiter_raw():
try:
chunk_text = chunk.decode('utf-8')
except UnicodeDecodeError:
yield chunk
continue
# CRITICAL: Yield the original chunk immediately to prevent hanging
# But also process it for token tracking
yield chunk
# Process for token tracking (after yielding to not block)
buffer += chunk_text
# Process complete lines
lines = buffer.split('\n')
buffer = lines.pop() if buffer and not chunk_text.endswith('\n') else ""
for line in lines:
if not line.strip():
continue
# Try to parse as JSON (Ollama format)
try:
data_str = line
if line.startswith('data: '):
data_str = line[6:]
if data_str == '[DONE]':
continue
data = json.loads(data_str)
# Extract tokens from this chunk
chunk_tokens = _extract_tokens_from_chunk(data)
# Update accumulated tokens (prefer non-None values)
for key in accumulated_tokens:
if chunk_tokens.get(key) is not None:
accumulated_tokens[key] = chunk_tokens[key]
# If this is the final chunk, update the log
if data.get("done") and log_id and not tokens_finalized:
tokens_finalized = True
# Fire-and-forget token update
asyncio.create_task(_update_log_with_tokens_async(
log_id,
accumulated_tokens["prompt_tokens"],
accumulated_tokens["completion_tokens"],
accumulated_tokens["total_tokens"]
))
except json.JSONDecodeError:
pass # Not JSON, skip token extraction
# Process any remaining buffer
if buffer.strip():
try:
data_str = buffer
if buffer.startswith('data: '):
data_str = buffer[6:]
if data_str and data_str != '[DONE]':
data = json.loads(data_str)
if data.get("done") and log_id and not tokens_finalized:
tokens_finalized = True
chunk_tokens = _extract_tokens_from_chunk(data)
for key in accumulated_tokens:
if chunk_tokens.get(key) is not None:
accumulated_tokens[key] = chunk_tokens[key]
asyncio.create_task(_update_log_with_tokens_async(
log_id,
accumulated_tokens["prompt_tokens"],
accumulated_tokens["completion_tokens"],
accumulated_tokens["total_tokens"]
))
except json.JSONDecodeError:
pass
except Exception as e:
logger.error(f"Error in token tracking stream: {e}")
# Don't re-raise, just stop processing tokens
# Return StreamingResponse with proper headers
response_headers = dict(backend_response.headers)
# Remove content-length since we're streaming
response_headers.pop('content-length', None)
return StreamingResponse(
token_tracking_stream(),
status_code=backend_response.status_code,
headers=response_headers,
media_type=backend_response.headers.get('content-type', 'application/x-ndjson')
)
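The wrapper's core trick is to yield every chunk untouched and only then re-assemble NDJSON/SSE lines in a side buffer, so partial lines split across chunks never break JSON parsing. The technique can be exercised in isolation; a self-contained sketch with a fake chunk iterator (not the production code path):

import asyncio, json

async def fake_chunks():
    # Two NDJSON chunks, the second line split across chunks to show why buffering matters.
    yield b'{"response": "Hel", "done": false}\n{"response": "lo", '
    yield b'"done": false}\n{"done": true, "prompt_eval_count": 5, "eval_count": 9}\n'

async def capture_tokens(chunks):
    buffer, tokens = "", {}
    async for chunk in chunks:
        buffer += chunk.decode("utf-8")
        lines = buffer.split("\n")
        buffer = lines.pop()              # keep the trailing partial line for the next chunk
        for line in lines:
            if not line.strip():
                continue
            data = json.loads(line)
            if data.get("done"):
                tokens = {
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                }
    return tokens

print(asyncio.run(capture_tokens(fake_chunks())))
# {'prompt_tokens': 5, 'completion_tokens': 9}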
def _is_streaming_response(response: Response) -> bool:
"""Check if a response is streaming based on headers."""
content_type = response.headers.get('content-type', '')
transfer_encoding = response.headers.get('transfer-encoding', '')
if 'text/event-stream' in content_type:
return True
if 'chunked' in transfer_encoding.lower():
return True
if 'application/x-ndjson' in content_type:
return True
return False
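A quick sanity check of the heuristic against bare httpx.Response objects (illustrative only; the real caller passes the backend response):

import httpx

ndjson = httpx.Response(200, headers={"content-type": "application/x-ndjson"})
sse = httpx.Response(200, headers={"content-type": "text/event-stream"})
plain = httpx.Response(200, headers={"content-type": "application/json"})

# _is_streaming_response(ndjson) -> True
# _is_streaming_response(sse)    -> True
# _is_streaming_response(plain)  -> False (unless transfer-encoding is chunked)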
async def _proxy_to_vllm(
request: Request,
server: OllamaServer,
path: str,
body_bytes: bytes
body_bytes: bytes,
api_key_id: Optional[int] = None,
log_id: Optional[int] = None
) -> Response:
"""
Handles proxying a request to a vLLM server, including payload and response translation.
Handles proxying a request to a vLLM server with token tracking.
"""
http_client: AsyncClient = request.app.state.http_client
@@ -371,7 +561,6 @@ async def _proxy_to_vllm(
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# Translate path and payload based on the endpoint
if path == "chat":
vllm_path = "v1/chat/completions"
vllm_payload = translate_ollama_to_vllm_chat(ollama_payload)
@@ -387,28 +576,98 @@ async def _proxy_to_vllm(
try:
if is_streaming:
async def stream_generator():
accumulated_tokens = {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
}
tokens_finalized = False
async with http_client.stream("POST", backend_url, json=vllm_payload, timeout=600.0, headers=headers) as vllm_response:
if vllm_response.status_code != 200:
error_body = await vllm_response.aread()
logger.error(f"vLLM server error ({vllm_response.status_code}): {error_body.decode()}")
# Yield a single error chunk in Ollama format
error_chunk = {"error": f"vLLM server error: {error_body.decode()}"}
yield (json.dumps(error_chunk) + '\n').encode('utf-8')
return
async for chunk in vllm_stream_to_ollama_stream(vllm_response.aiter_text(), model_name):
buffer = ""
async for chunk in vllm_response.aiter_raw():
try:
chunk_text = chunk.decode('utf-8')
except UnicodeDecodeError:
yield chunk
continue
# CRITICAL: Yield immediately to prevent hanging
yield chunk
# Process for token tracking
buffer += chunk_text
lines = buffer.split('\n')
buffer = lines.pop() if buffer and not chunk_text.endswith('\n') else ""
for line in lines:
if not line.strip():
continue
# Check for SSE data prefix
data_content = line
if line.startswith('data: '):
data_content = line[6:]
if data_content == '[DONE]':
continue
try:
data = json.loads(data_content)
# Extract usage info if present
if "usage" in data and data["usage"]:
usage = data["usage"]
accumulated_tokens["prompt_tokens"] = usage.get("prompt_tokens")
accumulated_tokens["completion_tokens"] = usage.get("completion_tokens")
accumulated_tokens["total_tokens"] = usage.get("total_tokens")
# Check for done signal
choices = data.get("choices", [])
if choices and choices[0].get("finish_reason"):
if log_id and not tokens_finalized:
tokens_finalized = True
asyncio.create_task(_update_log_with_tokens_async(
log_id,
accumulated_tokens["prompt_tokens"],
accumulated_tokens["completion_tokens"],
accumulated_tokens["total_tokens"]
))
except json.JSONDecodeError:
pass
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={'content-type': 'application/x-ndjson'}
)
else: # Non-streaming
response = await http_client.post(backend_url, json=vllm_payload, timeout=600.0, headers=headers)
response.raise_for_status()
vllm_data = response.json()
# Extract and log tokens for non-streaming response
if log_id:
usage = vllm_data.get("usage", {})
prompt_tokens = usage.get("prompt_tokens")
completion_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")
asyncio.create_task(_update_log_with_tokens_async(
log_id,
prompt_tokens, completion_tokens, total_tokens
))
if path == "embeddings":
ollama_data = translate_vllm_to_ollama_embeddings(vllm_data)
return JSONResponse(content=ollama_data)
# Add non-streaming chat translation if needed
raise NotImplementedError("Non-streaming chat for vLLM not yet implemented.")
except httpx.HTTPStatusError as e:
@@ -427,8 +686,7 @@ async def federate_models(
db: AsyncSession = Depends(get_db)
):
"""
Aggregates models from all configured backends (Ollama and vLLM)
using the cached model data from the database for efficiency.
Aggregates models from all configured backends.
"""
logger.info("--- /tags endpoint: Starting model federation ---")
all_servers = await server_crud.get_servers(db)
@@ -460,10 +718,8 @@ async def federate_models(
model_count_on_server = 0
for model in models_list:
if isinstance(model, dict) and "name" in model:
# --- FIX: Ensure 'model' key exists for compatibility with Ollama clients ---
if "model" not in model:
model["model"] = model["name"]
# --- END FIX ---
all_models[model['name']] = model
model_count_on_server += 1
else:
@@ -473,7 +729,6 @@ async def federate_models(
logger.info(f"/tags: Total unique models before adding 'auto': {len(all_models)}")
# Add the 'auto' model to the list for clients to see, with details for compatibility
all_models["auto"] = {
"name": "auto",
"model": "auto",
@@ -490,9 +745,8 @@ async def federate_models(
}
}
# OPTIMIZATION: Fire-and-forget logging to avoid blocking response
try:
asyncio.create_task(_async_log_usage(db, api_key.id, "/api/tags", 200, None))
asyncio.create_task(_async_log_usage(db, api_key.id, "/api/tags", 200, None, None))
except Exception as e:
logger.debug(f"Failed to queue usage log: {e}")
@@ -502,45 +756,60 @@ async def federate_models(
return {"models": final_model_list}
async def _async_log_usage(db: AsyncSession, api_key_id: int, endpoint: str, status_code: int, server_id: Optional[int], model: Optional[str] = None):
"""Fire-and-forget usage logging to avoid blocking responses."""
async def _async_log_usage(
db: AsyncSession,
api_key_id: int,
endpoint: str,
status_code: int,
server_id: Optional[int],
model: Optional[str] = None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None
) -> Optional[int]:
"""
Fire-and-forget usage logging to avoid blocking responses.
Returns the log ID if created.
"""
try:
# Create a new session for async logging
from app.database.session import AsyncSessionLocal
async with AsyncSessionLocal() as async_db:
await log_crud.create_usage_log(
log_entry = await log_crud.create_usage_log(
db=async_db,
api_key_id=api_key_id,
endpoint=endpoint,
status_code=status_code,
server_id=server_id,
model=model
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens
)
return log_entry.id
except Exception as e:
logger.debug(f"Async usage logging failed: {e}")
return None
async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional[str]:
"""Selects the best model based on metadata and request content."""
# 1. Determine request characteristics
has_images = "images" in body and body["images"]
prompt_content = ""
if "prompt" in body: # generate endpoint
if "prompt" in body:
prompt_content = body["prompt"]
elif "messages" in body: # chat endpoint
elif "messages" in body:
last_message = body["messages"][-1] if body["messages"] else {}
if isinstance(last_message.get("content"), str):
prompt_content = last_message["content"]
elif isinstance(last_message.get("content"), list): # multimodal chat
elif isinstance(last_message.get("content"), list):
text_part = next((p.get("text", "") for p in last_message["content"] if p.get("type") == "text"), "")
prompt_content = text_part
code_keywords = ["def ", "class ", "import ", "const ", "let ", "var ", "function ", "public static void", "int main("]
contains_code = any(kw.lower() in prompt_content.lower() for kw in code_keywords)
# 2. Get all model metadata and available models
all_metadata = await model_metadata_crud.get_all_metadata(db)
all_available_models = await server_crud.get_all_available_model_names(db)
available_metadata = [m for m in all_metadata if m.model_name in all_available_models]
@@ -549,7 +818,6 @@ async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional
logger.warning("Auto-routing failed: No models with metadata are available on active servers.")
return None
# 3. Filter models based on characteristics
candidate_models = available_metadata
if has_images:
@@ -575,7 +843,6 @@ async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional
if not candidate_models:
return None
# 4. The list is already sorted by priority from the CRUD function.
best_model = candidate_models[0]
logger.info(f"Auto-routing selected model '{best_model.model_name}' with priority {best_model.priority}.")
@@ -592,10 +859,8 @@ async def proxy_ollama(
servers: List[OllamaServer] = Depends(get_active_servers),
):
"""
A catch-all route that proxies all other requests to the backend.
Uses smart routing and translates requests for vLLM servers.
A catch-all route that proxies all other requests to the backend with token tracking.
"""
# --- Endpoint Security Check ---
blocked_paths = {p.strip().lstrip('/') for p in settings.blocked_ollama_endpoints.split(',') if p.strip()}
request_path = path.strip().lstrip('/')
@@ -608,7 +873,6 @@ async def proxy_ollama(
detail=f"Access to the endpoint '/api/{request_path}' is disabled by the proxy administrator."
)
# Try to extract model name from request body
body_bytes = await request.body()
model_name = None
body = {}
@@ -621,7 +885,7 @@ async def proxy_ollama(
except (json.JSONDecodeError, Exception):
pass
# Handle 'think' parameter based on model support
# Handle 'think' parameter
if model_name and isinstance(body, dict) and "think" in body:
model_name_lower = model_name.lower()
supported_think_models = ["qwen", "gpt-oss", "deepseek"]
@@ -629,18 +893,16 @@ async def proxy_ollama(
is_supported = any(keyword in model_name_lower for keyword in supported_think_models)
if is_supported:
# Handle special case for gpt-oss which requires string values if boolean `true` is passed
if "gpt-oss" in model_name_lower and body.get("think") is True:
logger.info(f"Translating 'think: true' to 'think: \"medium\"' for GPT-OSS model '{model_name}'")
body["think"] = "medium"
body_bytes = json.dumps(body).encode('utf-8')
else:
# If the model is not supported, remove the 'think' parameter to avoid errors.
logger.warning(f"Model '{model_name}' is not in the known list for 'think' support. Removing 'think' parameter from request to avoid errors.")
logger.warning(f"Model '{model_name}' is not in the known list for 'think' support. Removing 'think' parameter.")
del body["think"]
body_bytes = json.dumps(body).encode('utf-8')
# --- NEW: Handle 'auto' model routing ---
# Handle 'auto' model routing
if model_name == "auto":
chosen_model_name = await _select_auto_model(db, body)
if not chosen_model_name:
@@ -648,14 +910,10 @@ async def proxy_ollama(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Auto-routing could not find an available and suitable model."
)
# Override the model in the request and continue
model_name = chosen_model_name
body["model"] = model_name
body_bytes = json.dumps(body).encode('utf-8')
# Smart routing: filter servers by model availability
# DEFENSIVE: Log the servers we received from dependency
logger.info(f"proxy_ollama: Received {len(servers)} server(s) from get_active_servers dependency: {[s.name for s in servers]}")
candidate_servers = servers
@@ -667,15 +925,11 @@ async def proxy_ollama(
candidate_servers = servers_with_model
logger.info(f"Smart routing: Found {len(servers_with_model)} server(s) with model '{model_name}': {[s.name for s in servers_with_model]}")
else:
# Model not found in any server's catalog, or catalogs not fetched yet
# Fall back to all active servers
logger.warning(
f"Model '{model_name}' not found in any server's catalog. "
f"Falling back to round-robin across all {len(servers)} active server(s). "
f"Make sure to refresh model lists for accurate routing."
f"Falling back to round-robin across all {len(servers)} active server(s)."
)
# DEFENSIVE: Double-check we have servers before calling _reverse_proxy
if not candidate_servers:
logger.error(f"proxy_ollama: No candidate servers available for model '{model_name}'")
raise HTTPException(
@@ -683,15 +937,43 @@ async def proxy_ollama(
detail=f"No servers available for model '{model_name}'. Please check server status and model availability."
)
# Create initial usage log entry (without tokens - will be updated later for streaming)
is_token_trackable_endpoint = path in ("generate", "chat", "embeddings")
log_id = None
if is_token_trackable_endpoint:
log_id = await _async_log_usage(
db, api_key.id, f"/api/{path}", 200, None, model_name,
None, None, None
)
# Proxy to one of the candidate servers
response, chosen_server = await _reverse_proxy(request, path, candidate_servers, body_bytes)
response, chosen_server = await _reverse_proxy(
request, path, candidate_servers, body_bytes,
api_key_id=api_key.id, log_id=log_id
)
# OPTIMIZATION: Fire-and-forget logging to avoid blocking
try:
asyncio.create_task(_async_log_usage(
db, api_key.id, f"/api/{path}", response.status_code, chosen_server.id, model_name
))
except Exception as e:
logger.debug(f"Failed to queue usage log: {e}")
# Update log with server_id if we have a log entry
if log_id and chosen_server:
try:
from app.database.session import AsyncSessionLocal
async with AsyncSessionLocal() as async_db:
from sqlalchemy import update
from app.database.models import UsageLog
await async_db.execute(
update(UsageLog).where(UsageLog.id == log_id).values(server_id=chosen_server.id)
)
await async_db.commit()
except Exception as e:
logger.debug(f"Failed to update server_id for log {log_id}: {e}")
# For endpoints that aren't token-tracked, log the request without token counts
if not is_token_trackable_endpoint:
try:
asyncio.create_task(_async_log_usage(
db, api_key.id, f"/api/{path}", response.status_code, chosen_server.id, model_name
))
except Exception as e:
logger.debug(f"Failed to queue usage log: {e}")
return response
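Taken together, the token-tracking lifecycle for a /api/chat request is roughly: create a placeholder log row before proxying, patch in the chosen server id once a backend answers, and let the stream wrapper (or the non-streaming branch) fill in the token counts afterwards. A condensed, illustrative simulation of that three-phase write pattern using an in-memory stand-in for the usage_logs table — not the literal control flow above:

import asyncio
from typing import Optional

LOGS: dict[int, dict] = {}  # in-memory stand-in for usage_logs

def create_usage_log(api_key_id: int, endpoint: str, model: Optional[str]) -> int:
    log_id = len(LOGS) + 1
    LOGS[log_id] = {"api_key_id": api_key_id, "endpoint": endpoint, "model": model,
                    "server_id": None, "prompt_tokens": None,
                    "completion_tokens": None, "total_tokens": None}
    return log_id

def set_server(log_id: int, server_id: int) -> None:
    LOGS[log_id]["server_id"] = server_id

async def set_tokens(log_id: int, prompt: int, completion: int) -> None:
    LOGS[log_id].update(prompt_tokens=prompt, completion_tokens=completion,
                        total_tokens=prompt + completion)

async def main() -> None:
    log_id = create_usage_log(api_key_id=1, endpoint="/api/chat", model="qwen2")  # phase 1: placeholder row
    set_server(log_id, server_id=3)                                               # phase 2: after routing
    await set_tokens(log_id, prompt=42, completion=128)                           # phase 3: after stream ends
    print(LOGS[log_id])

asyncio.run(main())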

View File

@@ -3,9 +3,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func, select, text, Date
from app.database.models import UsageLog, APIKey, User, OllamaServer
import datetime
from typing import Optional, List
async def create_usage_log(
db: AsyncSession, *, api_key_id: int, endpoint: str, status_code: int, server_id: int | None, model: str | None = None
db: AsyncSession, *,
api_key_id: int,
endpoint: str,
status_code: int,
server_id: Optional[int] = None,
model: Optional[str] = None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None
) -> UsageLog:
# Validate inputs to prevent injection
if not isinstance(api_key_id, int) or api_key_id <= 0:
@@ -16,22 +25,41 @@ async def create_usage_log(
raise ValueError("Invalid status_code")
if model is not None and (not isinstance(model, str) or len(model) > 256):
raise ValueError("Invalid model name")
# Validate token counts
if prompt_tokens is not None:
prompt_tokens = max(0, int(prompt_tokens))
if completion_tokens is not None:
completion_tokens = max(0, int(completion_tokens))
if total_tokens is not None:
total_tokens = max(0, int(total_tokens))
elif prompt_tokens is not None and completion_tokens is not None:
total_tokens = prompt_tokens + completion_tokens
db_log = UsageLog(
api_key_id=api_key_id,
endpoint=endpoint[:512], # Limit length
status_code=status_code,
server_id=server_id,
model=model[:256] if model else None
model=model[:256] if model else None,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens
)
db.add(db_log)
await db.commit()
await db.refresh(db_log)
return db_log
async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count", sort_order: str = "desc"):
async def get_usage_statistics(
db: AsyncSession,
sort_by: str = "request_count",
sort_order: str = "desc"
):
"""
Returns aggregated usage statistics for all API keys, with sorting.
Includes token usage totals.
"""
# Whitelist allowed sort columns to prevent injection
allowed_sort_columns = {
@@ -39,6 +67,7 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
"key_name": APIKey.key_name,
"key_prefix": APIKey.key_prefix,
"request_count": func.count(UsageLog.id),
"total_tokens": func.coalesce(func.sum(UsageLog.total_tokens), 0),
}
# Default to request_count if invalid column provided
@@ -61,6 +90,9 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
APIKey.key_prefix,
APIKey.is_revoked,
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.select_from(APIKey)
.join(User, APIKey.user_id == User.id)
@@ -71,10 +103,11 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
result = await db.execute(stmt)
return result.all()
# --- NEW STATISTICS FUNCTIONS ---
async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
"""Returns total requests per day for the last N days."""
"""Returns total requests and tokens per day for the last N days."""
# Validate days parameter
try:
days = int(days)
@@ -92,7 +125,10 @@ async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
stmt = (
select(
date_column,
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.filter(UsageLog.request_timestamp >= start_date)
.group_by(date_column)
@@ -101,30 +137,49 @@ async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
result = await db.execute(stmt)
return result.all()
async def get_hourly_usage_stats(db: AsyncSession):
"""Returns total requests aggregated by the hour of the day (UTC)."""
"""Returns total requests and tokens aggregated by the hour of the day (UTC)."""
# Use safe SQLAlchemy constructs only
hour_extract = func.strftime('%H', UsageLog.request_timestamp)
stmt = (
select(
hour_extract.label("hour"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.group_by("hour")
.order_by("hour")
)
result = await db.execute(stmt)
# Ensure all 24 hours are present
stats_dict = {row.hour: row.request_count for row in result.all()}
return [{"hour": f"{h:02d}:00", "request_count": stats_dict.get(f"{h:02d}", 0)} for h in range(24)]
stats_dict = {row.hour: {
"request_count": row.request_count,
"total_prompt_tokens": row.total_prompt_tokens,
"total_completion_tokens": row.total_completion_tokens,
"total_tokens": row.total_tokens,
} for row in result.all()}
return [{"hour": f"{h:02d}:00", **stats_dict.get(f"{h:02d}", {
"request_count": 0,
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
})} for h in range(24)]
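The zero-fill step matters because strftime('%H', ...) only yields rows for hours that actually saw traffic, and the chart expects a dense 24-hour axis. A stripped-down illustration of the merge, using plain dicts in place of SQLAlchemy rows (token fields trimmed for brevity):

rows = [{"hour": "09", "request_count": 12, "total_tokens": 3400}]   # only hours with traffic
defaults = {"request_count": 0, "total_tokens": 0}

by_hour = {r["hour"]: {k: r[k] for k in defaults} for r in rows}
hourly = [{"hour": f"{h:02d}:00", **by_hour.get(f"{h:02d}", defaults)} for h in range(24)]

print(hourly[9])   # {'hour': '09:00', 'request_count': 12, 'total_tokens': 3400}
print(hourly[10])  # {'hour': '10:00', 'request_count': 0, 'total_tokens': 0}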
async def get_server_load_stats(db: AsyncSession):
"""Returns total requests per backend server."""
"""Returns total requests and tokens per backend server."""
stmt = (
select(
OllamaServer.name.label("server_name"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.select_from(OllamaServer)
.outerjoin(UsageLog, OllamaServer.id == UsageLog.server_id)
@@ -134,12 +189,16 @@ async def get_server_load_stats(db: AsyncSession):
result = await db.execute(stmt)
return result.all()
async def get_model_usage_stats(db: AsyncSession):
"""Returns total requests per model."""
"""Returns total requests and tokens per model."""
stmt = (
select(
UsageLog.model.label("model_name"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.filter(UsageLog.model.isnot(None))
.group_by(UsageLog.model)
@@ -148,10 +207,11 @@ async def get_model_usage_stats(db: AsyncSession):
result = await db.execute(stmt)
return result.all()
# --- NEW USER-SPECIFIC STATISTICS FUNCTIONS ---
async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: int = 30):
"""Returns total requests per day for the last N days for a specific user."""
"""Returns total requests and tokens per day for the last N days for a specific user."""
# Validate inputs
try:
user_id = int(user_id)
@@ -175,7 +235,10 @@ async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: i
stmt = (
select(
date_column,
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.join(APIKey, UsageLog.api_key_id == APIKey.id)
.filter(APIKey.user_id == user_id)
@@ -186,8 +249,9 @@ async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: i
result = await db.execute(stmt)
return result.all()
async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
"""Returns total requests aggregated by the hour for a specific user."""
"""Returns total requests and tokens aggregated by the hour for a specific user."""
# Validate user_id
try:
user_id = int(user_id)
@@ -201,7 +265,10 @@ async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
stmt = (
select(
hour_extract.label("hour"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.join(APIKey, UsageLog.api_key_id == APIKey.id)
.filter(APIKey.user_id == user_id)
@@ -209,11 +276,23 @@ async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
.order_by("hour")
)
result = await db.execute(stmt)
stats_dict = {row.hour: row.request_count for row in result.all()}
return [{"hour": f"{h:02d}:00", "request_count": stats_dict.get(f"{h:02d}", 0)} for h in range(24)]
stats_dict = {row.hour: {
"request_count": row.request_count,
"total_prompt_tokens": row.total_prompt_tokens,
"total_completion_tokens": row.total_completion_tokens,
"total_tokens": row.total_tokens,
} for row in result.all()}
return [{"hour": f"{h:02d}:00", **stats_dict.get(f"{h:02d}", {
"request_count": 0,
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
})} for h in range(24)]
async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
"""Returns total requests per backend server for a specific user."""
"""Returns total requests and tokens per backend server for a specific user."""
# Validate user_id
try:
user_id = int(user_id)
@@ -225,7 +304,10 @@ async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
stmt = (
select(
OllamaServer.name.label("server_name"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.select_from(UsageLog)
.join(APIKey, UsageLog.api_key_id == APIKey.id)
@@ -237,8 +319,9 @@ async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
result = await db.execute(stmt)
return result.all()
async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
"""Returns total requests per model for a specific user."""
"""Returns total requests and tokens per model for a specific user."""
# Validate user_id
try:
user_id = int(user_id)
@@ -250,7 +333,10 @@ async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
stmt = (
select(
UsageLog.model.label("model_name"),
func.count(UsageLog.id).label("request_count")
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
)
.join(APIKey, UsageLog.api_key_id == APIKey.id)
.filter(APIKey.user_id == user_id)
@@ -260,3 +346,40 @@ async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
)
result = await db.execute(stmt)
return result.all()
async def update_usage_log_with_tokens(
db: AsyncSession,
log_id: int,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None
) -> Optional[UsageLog]:
"""Updates an existing usage log entry with token counts."""
try:
result = await db.execute(
select(UsageLog).filter(UsageLog.id == log_id)
)
log_entry = result.scalars().first()
if not log_entry:
logger.warning(f"Usage log entry {log_id} not found for token update")
return None
# Validate and update token counts
if prompt_tokens is not None:
log_entry.prompt_tokens = max(0, int(prompt_tokens))
if completion_tokens is not None:
log_entry.completion_tokens = max(0, int(completion_tokens))
if total_tokens is not None:
log_entry.total_tokens = max(0, int(total_tokens))
elif prompt_tokens is not None and completion_tokens is not None:
log_entry.total_tokens = log_entry.prompt_tokens + log_entry.completion_tokens
await db.commit()
await db.refresh(log_entry)
return log_entry
except Exception as e:
logger.error(f"Failed to update usage log {log_id} with tokens: {e}")
return None
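Callers open their own session for the update, which is how the proxy's fire-and-forget task uses it. A hedged usage sketch; the session import path mirrors the one used in proxy.py, and the token values are placeholders:

from app.database.session import AsyncSessionLocal  # same session factory proxy.py uses

async def record_tokens(log_id: int) -> None:
    """Illustrative caller: open a short-lived session just for the token update."""
    async with AsyncSessionLocal() as db:
        await update_usage_log_with_tokens(
            db, log_id, prompt_tokens=42, completion_tokens=128, total_tokens=170
        )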

View File

@@ -20,6 +20,7 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by: str = "username", sort_order: str = "asc") -> list:
"""
Retrieves a list of users along with their statistics, with sorting.
Includes token usage totals.
"""
# Subquery to find the last usage time for each user
last_used_subq = (
@@ -40,6 +41,7 @@ async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by:
User.is_admin,
func.count(func.distinct(APIKey.id)).label("key_count"),
func.count(UsageLog.id).label("request_count"),
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
last_used_subq.c.last_used
)
.outerjoin(APIKey, User.id == APIKey.user_id)
@@ -58,6 +60,7 @@ async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by:
"username": User.username,
"key_count": func.count(func.distinct(APIKey.id)),
"request_count": func.count(UsageLog.id),
"total_tokens": func.coalesce(func.sum(UsageLog.total_tokens), 0),
"last_used": last_used_subq.c.last_used
}
sort_column = sort_column_map.get(sort_by, User.username)
@@ -106,4 +109,4 @@ async def delete_user(db: AsyncSession, user_id: int) -> User | None:
if user:
await db.delete(user)
await db.commit()
return user
return user

View File

@@ -297,6 +297,26 @@ async def migrate_usage_logs_table(engine: AsyncEngine) -> None:
"VARCHAR"
)
# Add token usage columns if missing
await add_column_if_missing(
engine,
"usage_logs",
"prompt_tokens",
"INTEGER"
)
await add_column_if_missing(
engine,
"usage_logs",
"completion_tokens",
"INTEGER"
)
await add_column_if_missing(
engine,
"usage_logs",
"total_tokens",
"INTEGER"
)
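On SQLite these three calls amount to checking the existing columns and issuing an ALTER TABLE for each missing one. A rough standalone equivalent, assuming a SQLite backend (the helper's actual body lives elsewhere in this file and may differ):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine

async def add_token_columns(engine: AsyncEngine) -> None:
    """Rough equivalent of the three add_column_if_missing calls above (assumed behavior)."""
    async with engine.begin() as conn:
        existing = {row[1] for row in (await conn.execute(text("PRAGMA table_info(usage_logs)"))).fetchall()}
        for column in ("prompt_tokens", "completion_tokens", "total_tokens"):
            if column not in existing:
                await conn.execute(text(f"ALTER TABLE usage_logs ADD COLUMN {column} INTEGER"))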
# Create index on model column if it doesn't exist
# Note: SQLite will silently ignore if index already exists
async with engine.begin() as conn:
@@ -347,9 +367,9 @@ async def migrate_app_settings_data(engine: AsyncEngine) -> None:
# Default values for new retry settings
default_retry_settings = {
"max_retries": 5,
"retry_total_timeout_seconds": 2.0,
"retry_base_delay_ms": 50
"max_retries": 2,
"retry_total_timeout_seconds": 1.0,
"retry_base_delay_ms": 10
}
# Add missing fields
@@ -527,6 +547,9 @@ async def run_all_migrations(engine: AsyncEngine) -> None:
},
"usage_logs": {
"model": "VARCHAR",
"prompt_tokens": "INTEGER",
"completion_tokens": "INTEGER",
"total_tokens": "INTEGER",
"server_id": "INTEGER",
},
"model_metadata": {

View File

@@ -59,6 +59,11 @@ class UsageLog(Base):
request_timestamp = Column(DateTime, default=datetime.datetime.utcnow)
server_id = Column(Integer, ForeignKey("ollama_servers.id"), nullable=True)
model = Column(String, nullable=True, index=True)
# Token usage tracking
prompt_tokens = Column(Integer, nullable=True)
completion_tokens = Column(Integer, nullable=True)
total_tokens = Column(Integer, nullable=True)
api_key = relationship("APIKey", back_populates="usage_logs")
server = relationship("OllamaServer")
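With the columns on UsageLog, per-key token totals become a plain aggregate. A minimal sketch of the kind of statement the CRUD layer builds (illustrative, not the exact query used above):

from sqlalchemy import func, select

token_totals_stmt = (
    select(
        UsageLog.api_key_id,
        func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
        func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
        func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
    )
    .group_by(UsageLog.api_key_id)
)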

View File

@@ -19,6 +19,32 @@
{% block content %}
<div class="space-y-8">
<!-- Token Usage Summary Card -->
<div class="card-style">
<h2 class="card-header text-2xl font-bold mb-4 pb-2">Token Usage Summary</h2>
<div class="grid grid-cols-1 md:grid-cols-3 gap-6">
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Prompt Tokens</h3>
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
{{ "{:,}".format(total_prompt_tokens|default(0)) }}
</div>
</div>
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Completion Tokens</h3>
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
{{ "{:,}".format(total_completion_tokens|default(0)) }}
</div>
</div>
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Tokens</h3>
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
{{ "{:,}".format(total_tokens|default(0)) }}
</div>
</div>
</div>
</div>
<!-- Top Row: Line and Bar Charts -->
<div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
<div class="card-style">
@@ -67,6 +93,18 @@
</div>
</div>
<!-- Token Usage by Model Chart -->
<div class="card-style">
<h2 class="card-header flex justify-between items-center text-xl font-bold mb-4 pb-2">
<span>Token Usage by Model</span>
<div class="space-x-2">
<button onclick="exportChartToPNG('tokenUsageChart', 'token-usage-by-model.png')" class="text-sm text-[var(--color-primary-500)] hover:underline">PNG</button>
<button onclick="exportDataToCSV(['Model Name', 'Prompt Tokens', 'Completion Tokens', 'Total Tokens'], {{ model_labels | tojson }}, {{ model_prompt_tokens | tojson }}, {{ model_completion_tokens | tojson }}, {{ model_total_tokens | tojson }}, 'token-usage-by-model.csv')" class="text-sm text-[var(--color-primary-500)] hover:underline">CSV</button>
</div>
</h2>
{% if model_data and model_total_tokens %}<canvas id="tokenUsageChart"></canvas>{% else %}<div class="flex items-center justify-center h-64"><p>No token usage data available.</p></div>{% endif %}
</div>
<!-- Bottom Row: Key Usage Table -->
<div class="card-style">
<h2 class="card-header flex justify-between items-center text-xl font-bold mb-4 pb-2">
@@ -81,6 +119,9 @@
{{ sortable_header('key_name', 'Key Name') }}
{{ sortable_header('key_prefix', 'Key Prefix') }}
{{ sortable_header('request_count', 'Requests', 'text-right') }}
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Prompt Tokens</th>
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Completion Tokens</th>
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Total Tokens</th>
</tr>
</thead>
<tbody class="divide-y divide-white/10">
@@ -90,9 +131,12 @@
<td class="px-6 py-4 whitespace-nowrap text-sm font-medium text-current">{{ stat.key_name }}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono">{{ stat.key_prefix }}</td>
<td class="px-6 py-4 whitespace-nowrap text-right text-sm font-medium text-current">{{ stat.request_count }}</td>
<td class="px-6 py-4 whitespace-nowrap text-right text-sm">{{ "{:,}".format(stat.total_prompt_tokens|default(0)) }}</td>
<td class="px-6 py-4 whitespace-nowrap text-right text-sm">{{ "{:,}".format(stat.total_completion_tokens|default(0)) }}</td>
<td class="px-6 py-4 whitespace-nowrap text-right text-sm font-medium">{{ "{:,}".format(stat.total_tokens|default(0)) }}</td>
</tr>
{% else %}
<tr><td colspan="4" class="px-6 py-4 text-center">No usage data available.</td></tr>
<tr><td colspan="7" class="px-6 py-4 text-center">No usage data available.</td></tr>
{% endfor %}
</tbody>
</table>
@@ -231,6 +275,40 @@
options: { responsive: true }
});
}
// 5. Token Usage by Model Chart (Stacked Bar)
const tokenCtx = document.getElementById('tokenUsageChart')?.getContext('2d');
if (tokenCtx && {{ model_prompt_tokens | tojson }} && {{ model_completion_tokens | tojson }}) {
new Chart(tokenCtx, {
type: 'bar',
data: {
labels: {{ model_labels | tojson }},
datasets: [
{
label: 'Prompt Tokens',
data: {{ model_prompt_tokens | tojson }},
backgroundColor: primaryColorTransparent,
borderColor: primaryColor,
borderWidth: 1
},
{
label: 'Completion Tokens',
data: {{ model_completion_tokens | tojson }},
backgroundColor: '#f97316',
borderColor: '#ea580c',
borderWidth: 1
}
]
},
options: {
responsive: true,
scales: {
y: { beginAtZero: true, stacked: true },
x: { stacked: true }
}
}
});
}
});
</script>
{% endblock %}