mirror of
https://github.com/ParisNeo/lollms_hub.git
synced 2026-05-04 03:01:01 -04:00
feat(statistics): add token usage tracking to admin dashboard
- Add token_usage columns to usage_logs table via migration - Implement token aggregation queries in log_crud.py - Update admin stats endpoint and statistics template - Clean up proxy.py imports and formatting
This commit is contained in:
@@ -335,6 +335,8 @@ async def get_system_and_ollama_info(
|
||||
"queue_status": rate_limits
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.get("/stats", response_class=HTMLResponse, name="admin_stats")
|
||||
async def admin_stats(
|
||||
request: Request,
|
||||
@@ -344,7 +346,7 @@ async def admin_stats(
|
||||
sort_order: str = Query("desc"),
|
||||
):
|
||||
# Whitelist allowed sort values
|
||||
allowed_sort = ["username", "key_name", "key_prefix", "request_count"]
|
||||
allowed_sort = ["username", "key_name", "key_prefix", "request_count", "total_tokens", "total_prompt_tokens", "total_completion_tokens"]
|
||||
if sort_by not in allowed_sort:
|
||||
sort_by = "request_count"
|
||||
if sort_order not in ["asc", "desc"]:
|
||||
@@ -356,9 +358,25 @@ async def admin_stats(
|
||||
hourly_stats = await log_crud.get_hourly_usage_stats(db)
|
||||
server_stats = await log_crud.get_server_load_stats(db)
|
||||
model_stats = await log_crud.get_model_usage_stats(db)
|
||||
|
||||
# Calculate token totals for summary
|
||||
total_prompt_tokens = sum(row.total_prompt_tokens for row in key_usage_stats if hasattr(row, 'total_prompt_tokens'))
|
||||
total_completion_tokens = sum(row.total_completion_tokens for row in key_usage_stats if hasattr(row, 'total_completion_tokens'))
|
||||
total_tokens = sum(row.total_tokens for row in key_usage_stats if hasattr(row, 'total_tokens'))
|
||||
|
||||
# Prepare token data by model
|
||||
model_prompt_tokens = [row.total_prompt_tokens for row in model_stats]
|
||||
model_completion_tokens = [row.total_completion_tokens for row in model_stats]
|
||||
model_total_tokens = [row.total_tokens for row in model_stats]
|
||||
|
||||
context.update({
|
||||
"key_usage_stats": key_usage_stats,
|
||||
"daily_labels": [row.date.strftime('%Y-%m-%d') for row in daily_stats],
|
||||
"daily_labels": [
|
||||
row.date if isinstance(row.date, str)
|
||||
else row.date.strftime('%Y-%m-%d') if hasattr(row.date, 'strftime')
|
||||
else str(row.date)
|
||||
for row in daily_stats
|
||||
],
|
||||
"daily_data": [row.request_count for row in daily_stats],
|
||||
"hourly_labels": [row['hour'] for row in hourly_stats],
|
||||
"hourly_data": [row['request_count'] for row in hourly_stats],
|
||||
@@ -366,11 +384,18 @@ async def admin_stats(
|
||||
"server_data": [row.request_count for row in server_stats],
|
||||
"model_labels": [row.model_name for row in model_stats],
|
||||
"model_data": [row.request_count for row in model_stats],
|
||||
"model_prompt_tokens": model_prompt_tokens,
|
||||
"model_completion_tokens": model_completion_tokens,
|
||||
"model_total_tokens": model_total_tokens,
|
||||
"total_prompt_tokens": total_prompt_tokens,
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"sort_by": sort_by,
|
||||
"sort_order": sort_order,
|
||||
})
|
||||
return templates.TemplateResponse("admin/statistics.html", context)
|
||||
|
||||
|
||||
@router.get("/help", response_class=HTMLResponse, name="admin_help")
async def admin_help_page(request: Request, admin_user: User = Depends(require_admin_user)):
    """Render the ``admin/help.html`` template.

    ``admin_user`` is resolved through the ``require_admin_user`` dependency;
    the handler itself does not read it.
    """
    context = get_template_context(request)
    return templates.TemplateResponse("admin/help.html", context)
|
||||
@@ -613,7 +638,7 @@ async def admin_pull_model(
|
||||
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
# Sanitize model name - only allow alphanumeric, dashes, dots, colons
|
||||
if not re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if not re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
flash(request, "Model name contains invalid characters", "error")
|
||||
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
@@ -658,7 +683,7 @@ async def admin_delete_model(
|
||||
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
# Sanitize model name
|
||||
if not re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if not re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
flash(request, "Model name contains invalid characters", "error")
|
||||
return RedirectResponse(url=request.url_for("admin_manage_server_models", server_id=server_id), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
@@ -700,7 +725,7 @@ async def admin_load_model(
|
||||
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
# Sanitize model name
|
||||
if not re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if not re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
flash(request, "Model name contains invalid characters", "error")
|
||||
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
@@ -737,7 +762,7 @@ async def admin_unload_model(
|
||||
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
# Sanitize model name
|
||||
if not re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if not re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
flash(request, "Model name contains invalid characters", "error")
|
||||
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
@@ -762,7 +787,7 @@ async def admin_unload_model_dashboard(
|
||||
server_name: str = Form(...)
|
||||
):
|
||||
# Validate and sanitize inputs
|
||||
if not model_name or len(model_name) > 256 or not re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if not model_name or len(model_name) > 256 or not re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
flash(request, "Invalid model name", "error")
|
||||
return RedirectResponse(url=request.url_for("admin_dashboard"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
@@ -797,7 +822,7 @@ async def admin_models_manager_page(
|
||||
all_model_names = await server_crud.get_all_available_model_names(db)
|
||||
for model_name in all_model_names:
|
||||
# Validate model name before processing
|
||||
if model_name and len(model_name) <= 256 and re.match(r'^[\w\.\-:]+$', model_name):
|
||||
if model_name and len(model_name) <= 256 and re.match(r'^[\w\.\-:@]+$', model_name):
|
||||
await model_metadata_crud.get_or_create_metadata(db, model_name=model_name)
|
||||
|
||||
context["metadata_list"] = await model_metadata_crud.get_all_metadata(db)
|
||||
@@ -1311,7 +1336,12 @@ async def admin_user_stats(
|
||||
|
||||
context.update({
|
||||
"user": user,
|
||||
"daily_labels": [row.date.strftime('%Y-%m-%d') for row in daily_stats],
|
||||
"daily_labels": [
|
||||
row.date if isinstance(row.date, str)
|
||||
else row.date.strftime('%Y-%m-%d') if hasattr(row.date, 'strftime')
|
||||
else str(row.date)
|
||||
for row in daily_stats
|
||||
],
|
||||
"daily_data": [row.request_count for row in daily_stats],
|
||||
"hourly_labels": [row['hour'] for row in hourly_stats],
|
||||
"hourly_data": [row['request_count'] for row in hourly_stats],
|
||||
@@ -1449,3 +1479,4 @@ async def delete_user_account(request: Request, user_id: int, db: AsyncSession =
|
||||
await user_crud.delete_user(db, user_id=user_id)
|
||||
flash(request, f"User '{user.username}' has been deleted.", "success")
|
||||
return RedirectResponse(url=request.url_for("admin_users"), status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
|
||||
@@ -26,9 +26,8 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(dependencies=[Depends(ip_filter), Depends(rate_limiter)])
|
||||
|
||||
# --- Connection Pool Cache ---
|
||||
# Cache for server health status to avoid repeated health checks
|
||||
_server_health_cache: Dict[int, Dict[str, Any]] = {}
|
||||
_health_cache_ttl_seconds = 5 # Cache health checks for 5 seconds
|
||||
_health_cache_ttl_seconds = 5
|
||||
|
||||
|
||||
def _is_server_healthy_cached(server_id: int) -> bool:
|
||||
@@ -38,7 +37,7 @@ def _is_server_healthy_cached(server_id: int) -> bool:
|
||||
if cache_entry:
|
||||
if time.time() - cache_entry["timestamp"] < _health_cache_ttl_seconds:
|
||||
return cache_entry["healthy"]
|
||||
return True # Default to allowing if no cache (will be checked on request)
|
||||
return True
|
||||
|
||||
|
||||
def _update_health_cache(server_id: int, healthy: bool):
|
||||
@@ -50,7 +49,6 @@ def _update_health_cache(server_id: int, healthy: bool):
|
||||
}
|
||||
|
||||
|
||||
# --- Dependency to get active servers ---
|
||||
async def get_active_servers(db: AsyncSession = Depends(get_db)) -> List[OllamaServer]:
|
||||
servers = await server_crud.get_servers(db)
|
||||
active_servers = [s for s in servers if s.is_active]
|
||||
@@ -66,27 +64,16 @@ async def get_active_servers(db: AsyncSession = Depends(get_db)) -> List[OllamaS
|
||||
async def extract_model_from_request(request: Request) -> Optional[str]:
    """
    Best-effort lookup of the ``model`` field in a JSON request body.

    Returns the value stored under ``"model"`` when the body parses to a JSON
    object that contains that key. Returns None when the body is empty, has
    any other shape, or cannot be read or parsed (the failure is logged at
    debug level and never propagated).
    """
    try:
        raw_body = await request.body()
        if raw_body:
            payload = json.loads(raw_body)
            if isinstance(payload, dict) and "model" in payload:
                return payload["model"]
    except Exception as e:
        # Deliberately broad: model extraction must never fail the request.
        logger.debug(f"Could not extract model from request body: {e}")

    return None
|
||||
|
||||
|
||||
@@ -101,30 +88,24 @@ async def _send_backend_request(
|
||||
):
|
||||
"""
|
||||
Internal function to send a single request to a backend server.
|
||||
This function is wrapped by retry logic.
|
||||
"""
|
||||
normalized_url = server.url.rstrip('/')
|
||||
backend_url = f"{normalized_url}/api/{path}"
|
||||
|
||||
request_headers = {}
|
||||
|
||||
# Copy headers from original request, but fix critical ones
|
||||
for k, v in headers.items():
|
||||
k_lower = k.lower()
|
||||
# Skip hop-by-hop headers that should not be forwarded
|
||||
if k_lower in ('host', 'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade'):
|
||||
continue
|
||||
# Skip content-length - we'll set it correctly based on actual body
|
||||
if k_lower == 'content-length':
|
||||
continue
|
||||
request_headers[k] = v
|
||||
|
||||
# Set correct content-length based on actual body bytes
|
||||
if body_bytes:
|
||||
request_headers['content-length'] = str(len(body_bytes))
|
||||
|
||||
# Add API key authentication if configured
|
||||
if server.encrypted_api_key:
|
||||
api_key = decrypt_data(server.encrypted_api_key)
|
||||
if api_key:
|
||||
@@ -141,9 +122,8 @@ async def _send_backend_request(
|
||||
try:
|
||||
backend_response = await http_client.send(backend_request, stream=True)
|
||||
|
||||
# Consider 5xx errors as failures that should be retried
|
||||
if backend_response.status_code >= 500:
|
||||
await backend_response.aclose() # Clean up the response
|
||||
await backend_response.aclose()
|
||||
raise Exception(
|
||||
f"Backend server returned {backend_response.status_code}: "
|
||||
f"{backend_response.reason_phrase}"
|
||||
@@ -152,38 +132,90 @@ async def _send_backend_request(
|
||||
return backend_response
|
||||
|
||||
except Exception as e:
|
||||
# Log and re-raise for retry logic
|
||||
logger.debug(f"Request to {server.url} failed: {type(e).__name__}: {str(e)[:200]}")
|
||||
raise
|
||||
|
||||
|
||||
async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer], body_bytes: bytes = b"") -> Tuple[Response, OllamaServer]:
|
||||
def _extract_tokens_from_chunk(chunk_data: Dict[str, Any]) -> Dict[str, Optional[int]]:
|
||||
"""Extract token counts from an Ollama response chunk."""
|
||||
tokens = {
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
|
||||
# Ollama format - check various field names
|
||||
if "prompt_eval_count" in chunk_data:
|
||||
tokens["prompt_tokens"] = chunk_data.get("prompt_eval_count")
|
||||
if "prompt_count" in chunk_data:
|
||||
tokens["prompt_tokens"] = chunk_data.get("prompt_count")
|
||||
if "eval_count" in chunk_data:
|
||||
tokens["completion_tokens"] = chunk_data.get("eval_count")
|
||||
|
||||
# Calculate total if we have both
|
||||
if tokens["prompt_tokens"] is not None and tokens["completion_tokens"] is not None:
|
||||
tokens["total_tokens"] = tokens["prompt_tokens"] + tokens["completion_tokens"]
|
||||
|
||||
# vLLM/OpenAI format (translated)
|
||||
if "usage" in chunk_data and chunk_data["usage"]:
|
||||
usage = chunk_data["usage"]
|
||||
if isinstance(usage, dict):
|
||||
tokens["prompt_tokens"] = usage.get("prompt_tokens")
|
||||
tokens["completion_tokens"] = usage.get("completion_tokens")
|
||||
tokens["total_tokens"] = usage.get("total_tokens")
|
||||
|
||||
# Final chunk with done=True often has the complete stats
|
||||
if chunk_data.get("done"):
|
||||
if "prompt_eval_count" in chunk_data:
|
||||
tokens["prompt_tokens"] = chunk_data.get("prompt_eval_count")
|
||||
if "prompt_count" in chunk_data:
|
||||
tokens["prompt_tokens"] = chunk_data.get("prompt_count")
|
||||
if "eval_count" in chunk_data:
|
||||
tokens["completion_tokens"] = chunk_data.get("eval_count")
|
||||
|
||||
if tokens["prompt_tokens"] is not None and tokens["completion_tokens"] is not None:
|
||||
tokens["total_tokens"] = tokens["prompt_tokens"] + tokens["completion_tokens"]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
async def _update_log_with_tokens_async(
    log_id: int,
    prompt_tokens: Optional[int],
    completion_tokens: Optional[int],
    total_tokens: Optional[int]
):
    """Store token counts on a usage-log row using a dedicated DB session.

    Intended to be scheduled as a background task: every failure is logged
    at debug level and swallowed so it can never affect the request that
    spawned it.
    """
    counts = (prompt_tokens, completion_tokens, total_tokens)
    try:
        # Imported lazily, inside the guarded region, so an import problem is
        # swallowed like any other failure.
        from app.database.session import AsyncSessionLocal

        async with AsyncSessionLocal() as session:
            await log_crud.update_usage_log_with_tokens(session, log_id, *counts)
    except Exception as e:
        logger.debug(f"Failed to update tokens for log {log_id}: {e}")
|
||||
|
||||
|
||||
async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer], body_bytes: bytes = "",
|
||||
api_key_id: Optional[int] = None, log_id: Optional[int] = None) -> Tuple[Response, OllamaServer]:
|
||||
"""
|
||||
Core reverse proxy logic with retry support. Forwards the request to a backend
|
||||
Ollama server and streams the response back. Returns the response and the chosen server.
|
||||
Core reverse proxy logic with retry support and token tracking.
|
||||
"""
|
||||
http_client: AsyncClient = request.app.state.http_client
|
||||
app_settings: AppSettingsModel = request.app.state.settings
|
||||
|
||||
# Use retry configuration directly from database settings
|
||||
# No hardcoded overrides - admins can configure these in the settings UI
|
||||
retry_config = RetryConfig(
|
||||
max_retries=app_settings.max_retries,
|
||||
total_timeout_seconds=app_settings.retry_total_timeout_seconds,
|
||||
base_delay_ms=app_settings.retry_base_delay_ms
|
||||
)
|
||||
|
||||
# Prepare request headers (exclude 'host' and other hop-by-hop headers)
|
||||
headers = {k: v for k, v in request.headers.items() if k.lower() not in
|
||||
('host', 'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade', 'content-length')}
|
||||
|
||||
# DEFENSIVE: Log what we received
|
||||
logger.info(f"_reverse_proxy called with {len(servers)} total server(s), filtering to active...")
|
||||
|
||||
# Use a local copy of servers to avoid race conditions with the global list
|
||||
# Filter to only active servers at request time
|
||||
# ALSO filter by health cache to skip known-unhealthy servers
|
||||
candidate_servers = [
|
||||
s for s in servers
|
||||
if s.is_active and _is_server_healthy_cached(s.id)
|
||||
@@ -192,7 +224,6 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
logger.info(f"After filtering: {len(candidate_servers)} active server(s): {[s.name for s in candidate_servers]}")
|
||||
|
||||
if not candidate_servers:
|
||||
# Fallback: try all active servers if health cache filtered everything out
|
||||
candidate_servers = [s for s in servers if s.is_active]
|
||||
if not candidate_servers:
|
||||
logger.error("All candidate servers became inactive during request processing")
|
||||
@@ -201,13 +232,10 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
detail="No active backend servers available."
|
||||
)
|
||||
|
||||
# OPTIMIZATION: Use a faster selection strategy
|
||||
# Prefer servers with the lowest recent error rate, then round-robin
|
||||
if not hasattr(request.app.state, 'backend_server_index'):
|
||||
request.app.state.backend_server_index = 0
|
||||
logger.info("Initialized backend_server_index to 0")
|
||||
|
||||
# Get current index and increment for next request
|
||||
current_index = request.app.state.backend_server_index % max(1, len(candidate_servers))
|
||||
request.app.state.backend_server_index = (current_index + 1) % max(1, len(candidate_servers))
|
||||
|
||||
@@ -216,7 +244,6 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
servers_tried = []
|
||||
|
||||
for server_attempt in range(len(candidate_servers)):
|
||||
# Calculate safe index based on current candidate_servers list size
|
||||
safe_index = (current_index + server_attempt) % len(candidate_servers)
|
||||
chosen_server = candidate_servers[safe_index]
|
||||
|
||||
@@ -224,37 +251,30 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
|
||||
servers_tried.append(chosen_server.name)
|
||||
|
||||
# --- BRANCH: Handle vLLM servers differently ---
|
||||
if chosen_server.server_type == 'vllm':
|
||||
logger.info(f"Using vLLM branch for server '{chosen_server.name}'")
|
||||
try:
|
||||
# vLLM translation doesn't use the retry logic wrapper in the same way
|
||||
response = await _proxy_to_vllm(request, chosen_server, path, body_bytes)
|
||||
_update_health_cache(chosen_server.id, True) # Mark as healthy
|
||||
response = await _proxy_to_vllm(request, chosen_server, path, body_bytes, api_key_id, log_id)
|
||||
_update_health_cache(chosen_server.id, True)
|
||||
return response, chosen_server
|
||||
except HTTPException:
|
||||
_update_health_cache(chosen_server.id, False) # Mark as unhealthy
|
||||
raise # Re-raise HTTP exceptions from the vLLM proxy
|
||||
_update_health_cache(chosen_server.id, False)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"vLLM server '{chosen_server.name}' failed: {e}. Trying next server.")
|
||||
_update_health_cache(chosen_server.id, False) # Mark as unhealthy
|
||||
# Remove failed server from candidates and continue
|
||||
_update_health_cache(chosen_server.id, False)
|
||||
candidate_servers = [s for s in candidate_servers if s.id != chosen_server.id]
|
||||
if not candidate_servers:
|
||||
logger.error("No more candidate servers after vLLM failure")
|
||||
break
|
||||
# Recalculate current_index to stay in bounds with new list size
|
||||
current_index = safe_index % max(1, len(candidate_servers))
|
||||
continue
|
||||
|
||||
# --- Ollama server logic (with retries) ---
|
||||
logger.info(f"Using Ollama branch with retry logic for server '{chosen_server.name}'")
|
||||
|
||||
# Try direct request first for speed - if it succeeds quickly, no need for retry overhead
|
||||
first_attempt_start = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
# Try direct request first without retry wrapper for speed
|
||||
backend_response = await _send_backend_request(
|
||||
http_client=http_client,
|
||||
server=chosen_server,
|
||||
@@ -265,29 +285,47 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
body_bytes=body_bytes
|
||||
)
|
||||
|
||||
# If we got here quickly, use the response directly
|
||||
first_attempt_duration = asyncio.get_event_loop().time() - first_attempt_start
|
||||
|
||||
if first_attempt_duration < 0.5: # If fast success, skip retry overhead
|
||||
_update_health_cache(chosen_server.id, True)
|
||||
response = StreamingResponse(
|
||||
backend_response.aiter_raw(),
|
||||
status_code=backend_response.status_code,
|
||||
headers=backend_response.headers,
|
||||
)
|
||||
return response, chosen_server
|
||||
|
||||
# If slow but successful, still use it
|
||||
_update_health_cache(chosen_server.id, True)
|
||||
response = StreamingResponse(
|
||||
backend_response.aiter_raw(),
|
||||
status_code=backend_response.status_code,
|
||||
headers=backend_response.headers,
|
||||
)
|
||||
return response, chosen_server
|
||||
|
||||
# Check if this is a streaming response
|
||||
is_streaming = _is_streaming_response(backend_response)
|
||||
|
||||
if is_streaming and log_id:
|
||||
# Wrap for token tracking
|
||||
wrapped_response = _wrap_response_for_token_tracking(
|
||||
backend_response, chosen_server, api_key_id, log_id, path
|
||||
)
|
||||
return wrapped_response, chosen_server
|
||||
else:
|
||||
# Non-streaming, return as-is (tokens will be extracted if possible)
|
||||
if log_id and backend_response.status_code == 200:
|
||||
# Try to extract tokens from non-streaming response
|
||||
try:
|
||||
body = await backend_response.aread()
|
||||
if body:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
tokens = _extract_tokens_from_chunk(data)
|
||||
if tokens.get("total_tokens") is not None or tokens.get("prompt_tokens") is not None:
|
||||
asyncio.create_task(_update_log_with_tokens_async(
|
||||
log_id,
|
||||
tokens["prompt_tokens"],
|
||||
tokens["completion_tokens"],
|
||||
tokens["total_tokens"]
|
||||
))
|
||||
# Need to create a new response since we consumed the body
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=backend_response.status_code,
|
||||
headers=dict(backend_response.headers)
|
||||
), chosen_server
|
||||
except Exception:
|
||||
pass
|
||||
# Return original response if we couldn't extract tokens
|
||||
return backend_response, chosen_server
|
||||
|
||||
except Exception as first_error:
|
||||
# First attempt failed, use retry logic with configured settings
|
||||
_update_health_cache(chosen_server.id, False)
|
||||
logger.debug(f"Direct attempt failed for '{chosen_server.name}', using retry logic: {first_error}")
|
||||
|
||||
@@ -315,12 +353,36 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
f"in {retry_result.total_duration_ms:.1f}ms"
|
||||
)
|
||||
|
||||
response = StreamingResponse(
|
||||
backend_response.aiter_raw(),
|
||||
status_code=backend_response.status_code,
|
||||
headers=backend_response.headers,
|
||||
)
|
||||
return response, chosen_server
|
||||
# Check if streaming
|
||||
is_streaming = _is_streaming_response(backend_response)
|
||||
|
||||
if is_streaming and log_id:
|
||||
wrapped_response = _wrap_response_for_token_tracking(
|
||||
backend_response, chosen_server, api_key_id, log_id, path
|
||||
)
|
||||
return wrapped_response, chosen_server
|
||||
else:
|
||||
if log_id and backend_response.status_code == 200:
|
||||
try:
|
||||
body = await backend_response.aread()
|
||||
if body:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
tokens = _extract_tokens_from_chunk(data)
|
||||
if tokens.get("total_tokens") is not None or tokens.get("prompt_tokens") is not None:
|
||||
asyncio.create_task(_update_log_with_tokens_async(
|
||||
log_id,
|
||||
tokens["prompt_tokens"],
|
||||
tokens["completion_tokens"],
|
||||
tokens["total_tokens"]
|
||||
))
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=backend_response.status_code,
|
||||
headers=dict(backend_response.headers)
|
||||
), chosen_server
|
||||
except Exception:
|
||||
pass
|
||||
return backend_response, chosen_server
|
||||
else:
|
||||
_update_health_cache(chosen_server.id, False)
|
||||
logger.warning(
|
||||
@@ -328,15 +390,12 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
f"attempts. Trying next server if available."
|
||||
)
|
||||
|
||||
# Remove failed server from candidates
|
||||
candidate_servers = [s for s in candidate_servers if s.id != chosen_server.id]
|
||||
if not candidate_servers:
|
||||
logger.error("No more candidate servers after Ollama failure")
|
||||
break
|
||||
# Recalculate current_index to stay in bounds with new list size
|
||||
current_index = safe_index % max(1, len(candidate_servers))
|
||||
|
||||
# All servers exhausted
|
||||
logger.error(
|
||||
f"All {len(servers_tried)} backend server(s) failed after retries. "
|
||||
f"Servers tried: {', '.join(servers_tried)}"
|
||||
@@ -347,14 +406,145 @@ async def _reverse_proxy(request: Request, path: str, servers: List[OllamaServer
|
||||
)
|
||||
|
||||
|
||||
def _wrap_response_for_token_tracking(
    backend_response: Response,
    server: OllamaServer,
    api_key_id: Optional[int] = None,
    log_id: Optional[int] = None,
    path: str = ""
) -> StreamingResponse:
    """Wrap a streaming backend response so token usage is captured in passing.

    Every raw chunk is forwarded to the client unchanged and before any
    inspection. The bytes are additionally split into newline-delimited
    records (optionally prefixed with ``data: ``); each record that parses to
    a JSON object is fed to ``_extract_tokens_from_chunk``. When a record
    with a truthy ``done`` field is seen and ``log_id`` is set, a background
    task persists the accumulated counts once.

    Args:
        backend_response: Streamed backend response exposing ``aiter_raw()``,
            ``aclose()``, ``headers`` and ``status_code``.
        server: Accepted for interface compatibility; not used here.
        api_key_id: Accepted for interface compatibility; not used here.
        log_id: Usage-log row to update; tracking results are discarded when falsy.
        path: Accepted for interface compatibility; not used here.

    Returns:
        A ``StreamingResponse`` mirroring the backend status and headers,
        minus ``content-length``.
    """

    async def token_tracking_stream():
        # Buffer raw bytes rather than decoded text: a multi-byte UTF-8
        # character split across two chunks must not drop either chunk from
        # tracking. Splitting on b'\n' is safe because 0x0A never occurs
        # inside a multi-byte UTF-8 sequence.
        pending = b""
        accumulated_tokens = {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        }
        tokens_finalized = False

        def track_line(raw_line: bytes) -> None:
            """Inspect one record; must never raise into the stream loop."""
            nonlocal tokens_finalized
            try:
                line = raw_line.decode('utf-8')
                if not line.strip():
                    return

                data_str = line[6:] if line.startswith('data: ') else line
                if data_str == '[DONE]':
                    return

                data = json.loads(data_str)
                if not isinstance(data, dict):
                    # Valid JSON but not an object: nothing to extract.
                    return

                chunk_tokens = _extract_tokens_from_chunk(data)
                for key in accumulated_tokens:
                    if chunk_tokens.get(key) is not None:
                        accumulated_tokens[key] = chunk_tokens[key]

                if data.get("done") and log_id and not tokens_finalized:
                    tokens_finalized = True
                    # Fire-and-forget token update
                    asyncio.create_task(_update_log_with_tokens_async(
                        log_id,
                        accumulated_tokens["prompt_tokens"],
                        accumulated_tokens["completion_tokens"],
                        accumulated_tokens["total_tokens"]
                    ))
            except (UnicodeDecodeError, json.JSONDecodeError):
                pass  # Not a JSON text record, skip token extraction
            except Exception as e:
                # Tracking is best-effort; a bad record must not cut the stream.
                logger.debug(f"Token tracking skipped a record: {e}")

        try:
            async for chunk in backend_response.aiter_raw():
                # CRITICAL: forward the chunk before doing any tracking work.
                yield chunk

                pending += chunk
                *complete_lines, pending = pending.split(b'\n')
                for raw_line in complete_lines:
                    track_line(raw_line)

            # The last record may arrive without a trailing newline.
            track_line(pending)

        except Exception as e:
            logger.error(f"Error in token tracking stream: {e}")
            # Don't re-raise; the stream simply ends here.
        finally:
            # Release the backend connection even when the client goes away
            # before the stream is exhausted.
            try:
                await backend_response.aclose()
            except Exception as e:
                logger.debug(f"Error closing backend response: {e}")

    # Return StreamingResponse with proper headers
    response_headers = dict(backend_response.headers)
    # Remove content-length since we're streaming
    response_headers.pop('content-length', None)

    return StreamingResponse(
        token_tracking_stream(),
        status_code=backend_response.status_code,
        headers=response_headers,
        media_type=backend_response.headers.get('content-type', 'application/x-ndjson')
    )
|
||||
|
||||
|
||||
def _is_streaming_response(response: Response) -> bool:
    """Check if a response is streaming based on its headers.

    A response counts as streaming when its ``Content-Type`` mentions
    ``text/event-stream`` or ``application/x-ndjson``, or when its
    ``Transfer-Encoding`` mentions ``chunked``. Both header values are
    compared case-insensitively.
    """
    content_type = response.headers.get('content-type', '').lower()
    transfer_encoding = response.headers.get('transfer-encoding', '').lower()

    if 'chunked' in transfer_encoding:
        return True

    streaming_media_types = ('text/event-stream', 'application/x-ndjson')
    return any(media_type in content_type for media_type in streaming_media_types)
|
||||
|
||||
|
||||
async def _proxy_to_vllm(
|
||||
request: Request,
|
||||
server: OllamaServer,
|
||||
path: str,
|
||||
body_bytes: bytes
|
||||
body_bytes: bytes,
|
||||
api_key_id: Optional[int] = None,
|
||||
log_id: Optional[int] = None
|
||||
) -> Response:
|
||||
"""
|
||||
Handles proxying a request to a vLLM server, including payload and response translation.
|
||||
Handles proxying a request to a vLLM server with token tracking.
|
||||
"""
|
||||
http_client: AsyncClient = request.app.state.http_client
|
||||
|
||||
@@ -371,7 +561,6 @@ async def _proxy_to_vllm(
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Translate path and payload based on the endpoint
|
||||
if path == "chat":
|
||||
vllm_path = "v1/chat/completions"
|
||||
vllm_payload = translate_ollama_to_vllm_chat(ollama_payload)
|
||||
@@ -387,28 +576,98 @@ async def _proxy_to_vllm(
|
||||
try:
|
||||
if is_streaming:
|
||||
async def stream_generator():
|
||||
accumulated_tokens = {
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
tokens_finalized = False
|
||||
|
||||
async with http_client.stream("POST", backend_url, json=vllm_payload, timeout=600.0, headers=headers) as vllm_response:
|
||||
if vllm_response.status_code != 200:
|
||||
error_body = await vllm_response.aread()
|
||||
logger.error(f"vLLM server error ({vllm_response.status_code}): {error_body.decode()}")
|
||||
# Yield a single error chunk in Ollama format
|
||||
error_chunk = {"error": f"vLLM server error: {error_body.decode()}"}
|
||||
yield (json.dumps(error_chunk) + '\n').encode('utf-8')
|
||||
return
|
||||
|
||||
async for chunk in vllm_stream_to_ollama_stream(vllm_response.aiter_text(), model_name):
|
||||
buffer = ""
|
||||
async for chunk in vllm_response.aiter_raw():
|
||||
try:
|
||||
chunk_text = chunk.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
# CRITICAL: Yield immediately to prevent hanging
|
||||
yield chunk
|
||||
|
||||
# Process for token tracking
|
||||
buffer += chunk_text
|
||||
lines = buffer.split('\n')
|
||||
buffer = lines.pop() if buffer and not chunk_text.endswith('\n') else ""
|
||||
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Check for SSE data prefix
|
||||
data_content = line
|
||||
if line.startswith('data: '):
|
||||
data_content = line[6:]
|
||||
|
||||
if data_content == '[DONE]':
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_content)
|
||||
|
||||
# Extract usage info if present
|
||||
if "usage" in data and data["usage"]:
|
||||
usage = data["usage"]
|
||||
accumulated_tokens["prompt_tokens"] = usage.get("prompt_tokens")
|
||||
accumulated_tokens["completion_tokens"] = usage.get("completion_tokens")
|
||||
accumulated_tokens["total_tokens"] = usage.get("total_tokens")
|
||||
|
||||
# Check for done signal
|
||||
choices = data.get("choices", [])
|
||||
if choices and choices[0].get("finish_reason"):
|
||||
if log_id and not tokens_finalized:
|
||||
tokens_finalized = True
|
||||
asyncio.create_task(_update_log_with_tokens_async(
|
||||
log_id,
|
||||
accumulated_tokens["prompt_tokens"],
|
||||
accumulated_tokens["completion_tokens"],
|
||||
accumulated_tokens["total_tokens"]
|
||||
))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={'content-type': 'application/x-ndjson'}
|
||||
)
|
||||
else: # Non-streaming
|
||||
response = await http_client.post(backend_url, json=vllm_payload, timeout=600.0, headers=headers)
|
||||
response.raise_for_status()
|
||||
vllm_data = response.json()
|
||||
|
||||
# Extract and log tokens for non-streaming response
|
||||
if log_id:
|
||||
usage = vllm_data.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens")
|
||||
completion_tokens = usage.get("completion_tokens")
|
||||
total_tokens = usage.get("total_tokens")
|
||||
|
||||
asyncio.create_task(_update_log_with_tokens_async(
|
||||
log_id,
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
))
|
||||
|
||||
if path == "embeddings":
|
||||
ollama_data = translate_vllm_to_ollama_embeddings(vllm_data)
|
||||
return JSONResponse(content=ollama_data)
|
||||
# Add non-streaming chat translation if needed
|
||||
raise NotImplementedError("Non-streaming chat for vLLM not yet implemented.")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -427,8 +686,7 @@ async def federate_models(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Aggregates models from all configured backends (Ollama and vLLM)
|
||||
using the cached model data from the database for efficiency.
|
||||
Aggregates models from all configured backends.
|
||||
"""
|
||||
logger.info("--- /tags endpoint: Starting model federation ---")
|
||||
all_servers = await server_crud.get_servers(db)
|
||||
@@ -460,10 +718,8 @@ async def federate_models(
|
||||
model_count_on_server = 0
|
||||
for model in models_list:
|
||||
if isinstance(model, dict) and "name" in model:
|
||||
# --- FIX: Ensure 'model' key exists for compatibility with Ollama clients ---
|
||||
if "model" not in model:
|
||||
model["model"] = model["name"]
|
||||
# --- END FIX ---
|
||||
all_models[model['name']] = model
|
||||
model_count_on_server += 1
|
||||
else:
|
||||
@@ -473,7 +729,6 @@ async def federate_models(
|
||||
|
||||
logger.info(f"/tags: Total unique models before adding 'auto': {len(all_models)}")
|
||||
|
||||
# Add the 'auto' model to the list for clients to see, with details for compatibility
|
||||
all_models["auto"] = {
|
||||
"name": "auto",
|
||||
"model": "auto",
|
||||
@@ -490,9 +745,8 @@ async def federate_models(
|
||||
}
|
||||
}
|
||||
|
||||
# OPTIMIZATION: Fire-and-forget logging to avoid blocking response
|
||||
try:
|
||||
asyncio.create_task(_async_log_usage(db, api_key.id, "/api/tags", 200, None))
|
||||
asyncio.create_task(_async_log_usage(db, api_key.id, "/api/tags", 200, None, None))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to queue usage log: {e}")
|
||||
|
||||
@@ -502,45 +756,60 @@ async def federate_models(
|
||||
return {"models": final_model_list}
|
||||
|
||||
|
||||
async def _async_log_usage(
    db: AsyncSession,
    api_key_id: int,
    endpoint: str,
    status_code: int,
    server_id: Optional[int],
    model: Optional[str] = None,
    prompt_tokens: Optional[int] = None,
    completion_tokens: Optional[int] = None,
    total_tokens: Optional[int] = None
) -> Optional[int]:
    """
    Fire-and-forget usage logging to avoid blocking responses.

    Writes a single usage-log row through ``log_crud.create_usage_log`` on a
    freshly opened ``AsyncSessionLocal`` session. Any failure is logged at
    debug level and swallowed, so this coroutine never raises.

    Args:
        db: Accepted for call-site compatibility only; the body does not use
            it and always opens its own session instead.
        api_key_id: ID of the API key the request was made with.
        endpoint: Request path being logged (e.g. ``"/api/tags"``).
        status_code: HTTP status code to record.
        server_id: ID of the backend server that handled the request, or
            ``None`` when it is not known yet.
        model: Model name associated with the request, if any.
        prompt_tokens: Prompt token count, if already known.
        completion_tokens: Completion token count, if already known.
        total_tokens: Total token count, if already known.

    Returns:
        The ID of the created log entry, or ``None`` if logging failed.
    """
    try:
        # Create a new session for async logging. The import is local to the
        # function, exactly as in the original code.
        from app.database.session import AsyncSessionLocal
        async with AsyncSessionLocal() as async_db:
            log_entry = await log_crud.create_usage_log(
                db=async_db,
                api_key_id=api_key_id,
                endpoint=endpoint,
                status_code=status_code,
                server_id=server_id,
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens
            )
            return log_entry.id
    except Exception as e:
        # Deliberate best-effort: a logging failure must not break the request.
        logger.debug(f"Async usage logging failed: {e}")
        return None
|
||||
|
||||
|
||||
async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional[str]:
|
||||
"""Selects the best model based on metadata and request content."""
|
||||
|
||||
# 1. Determine request characteristics
|
||||
has_images = "images" in body and body["images"]
|
||||
|
||||
prompt_content = ""
|
||||
if "prompt" in body: # generate endpoint
|
||||
if "prompt" in body:
|
||||
prompt_content = body["prompt"]
|
||||
elif "messages" in body: # chat endpoint
|
||||
elif "messages" in body:
|
||||
last_message = body["messages"][-1] if body["messages"] else {}
|
||||
if isinstance(last_message.get("content"), str):
|
||||
prompt_content = last_message["content"]
|
||||
elif isinstance(last_message.get("content"), list): # multimodal chat
|
||||
elif isinstance(last_message.get("content"), list):
|
||||
text_part = next((p.get("text", "") for p in last_message["content"] if p.get("type") == "text"), "")
|
||||
prompt_content = text_part
|
||||
|
||||
code_keywords = ["def ", "class ", "import ", "const ", "let ", "var ", "function ", "public static void", "int main("]
|
||||
contains_code = any(kw.lower() in prompt_content.lower() for kw in code_keywords)
|
||||
|
||||
# 2. Get all model metadata and available models
|
||||
all_metadata = await model_metadata_crud.get_all_metadata(db)
|
||||
all_available_models = await server_crud.get_all_available_model_names(db)
|
||||
available_metadata = [m for m in all_metadata if m.model_name in all_available_models]
|
||||
@@ -549,7 +818,6 @@ async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional
|
||||
logger.warning("Auto-routing failed: No models with metadata are available on active servers.")
|
||||
return None
|
||||
|
||||
# 3. Filter models based on characteristics
|
||||
candidate_models = available_metadata
|
||||
|
||||
if has_images:
|
||||
@@ -575,7 +843,6 @@ async def _select_auto_model(db: AsyncSession, body: Dict[str, Any]) -> Optional
|
||||
if not candidate_models:
|
||||
return None
|
||||
|
||||
# 4. The list is already sorted by priority from the CRUD function.
|
||||
best_model = candidate_models[0]
|
||||
logger.info(f"Auto-routing selected model '{best_model.model_name}' with priority {best_model.priority}.")
|
||||
|
||||
@@ -592,10 +859,8 @@ async def proxy_ollama(
|
||||
servers: List[OllamaServer] = Depends(get_active_servers),
|
||||
):
|
||||
"""
|
||||
A catch-all route that proxies all other requests to the backend.
|
||||
Uses smart routing and translates requests for vLLM servers.
|
||||
A catch-all route that proxies all other requests to the backend with token tracking.
|
||||
"""
|
||||
# --- Endpoint Security Check ---
|
||||
blocked_paths = {p.strip().lstrip('/') for p in settings.blocked_ollama_endpoints.split(',') if p.strip()}
|
||||
request_path = path.strip().lstrip('/')
|
||||
|
||||
@@ -608,7 +873,6 @@ async def proxy_ollama(
|
||||
detail=f"Access to the endpoint '/api/{request_path}' is disabled by the proxy administrator."
|
||||
)
|
||||
|
||||
# Try to extract model name from request body
|
||||
body_bytes = await request.body()
|
||||
model_name = None
|
||||
body = {}
|
||||
@@ -621,7 +885,7 @@ async def proxy_ollama(
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# Handle 'think' parameter based on model support
|
||||
# Handle 'think' parameter
|
||||
if model_name and isinstance(body, dict) and "think" in body:
|
||||
model_name_lower = model_name.lower()
|
||||
supported_think_models = ["qwen", "gpt-oss", "deepseek"]
|
||||
@@ -629,18 +893,16 @@ async def proxy_ollama(
|
||||
is_supported = any(keyword in model_name_lower for keyword in supported_think_models)
|
||||
|
||||
if is_supported:
|
||||
# Handle special case for gpt-oss which requires string values if boolean `true` is passed
|
||||
if "gpt-oss" in model_name_lower and body.get("think") is True:
|
||||
logger.info(f"Translating 'think: true' to 'think: \"medium\"' for GPT-OSS model '{model_name}'")
|
||||
body["think"] = "medium"
|
||||
body_bytes = json.dumps(body).encode('utf-8')
|
||||
else:
|
||||
# If the model is not supported, remove the 'think' parameter to avoid errors.
|
||||
logger.warning(f"Model '{model_name}' is not in the known list for 'think' support. Removing 'think' parameter from request to avoid errors.")
|
||||
logger.warning(f"Model '{model_name}' is not in the known list for 'think' support. Removing 'think' parameter.")
|
||||
del body["think"]
|
||||
body_bytes = json.dumps(body).encode('utf-8')
|
||||
|
||||
# --- NEW: Handle 'auto' model routing ---
|
||||
# Handle 'auto' model routing
|
||||
if model_name == "auto":
|
||||
chosen_model_name = await _select_auto_model(db, body)
|
||||
if not chosen_model_name:
|
||||
@@ -648,14 +910,10 @@ async def proxy_ollama(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Auto-routing could not find an available and suitable model."
|
||||
)
|
||||
|
||||
# Override the model in the request and continue
|
||||
model_name = chosen_model_name
|
||||
body["model"] = model_name
|
||||
body_bytes = json.dumps(body).encode('utf-8')
|
||||
|
||||
# Smart routing: filter servers by model availability
|
||||
# DEFENSIVE: Log the servers we received from dependency
|
||||
logger.info(f"proxy_ollama: Received {len(servers)} server(s) from get_active_servers dependency: {[s.name for s in servers]}")
|
||||
|
||||
candidate_servers = servers
|
||||
@@ -667,15 +925,11 @@ async def proxy_ollama(
|
||||
candidate_servers = servers_with_model
|
||||
logger.info(f"Smart routing: Found {len(servers_with_model)} server(s) with model '{model_name}': {[s.name for s in servers_with_model]}")
|
||||
else:
|
||||
# Model not found in any server's catalog, or catalogs not fetched yet
|
||||
# Fall back to all active servers
|
||||
logger.warning(
|
||||
f"Model '{model_name}' not found in any server's catalog. "
|
||||
f"Falling back to round-robin across all {len(servers)} active server(s). "
|
||||
f"Make sure to refresh model lists for accurate routing."
|
||||
f"Falling back to round-robin across all {len(servers)} active server(s)."
|
||||
)
|
||||
|
||||
# DEFENSIVE: Double-check we have servers before calling _reverse_proxy
|
||||
if not candidate_servers:
|
||||
logger.error(f"proxy_ollama: No candidate servers available for model '{model_name}'")
|
||||
raise HTTPException(
|
||||
@@ -683,15 +937,43 @@ async def proxy_ollama(
|
||||
detail=f"No servers available for model '{model_name}'. Please check server status and model availability."
|
||||
)
|
||||
|
||||
# Create initial usage log entry (without tokens - will be updated later for streaming)
|
||||
is_token_trackable_endpoint = path in ("generate", "chat", "embeddings")
|
||||
|
||||
log_id = None
|
||||
if is_token_trackable_endpoint:
|
||||
log_id = await _async_log_usage(
|
||||
db, api_key.id, f"/api/{path}", 200, None, model_name,
|
||||
None, None, None
|
||||
)
|
||||
|
||||
# Proxy to one of the candidate servers
|
||||
response, chosen_server = await _reverse_proxy(request, path, candidate_servers, body_bytes)
|
||||
response, chosen_server = await _reverse_proxy(
|
||||
request, path, candidate_servers, body_bytes,
|
||||
api_key_id=api_key.id, log_id=log_id
|
||||
)
|
||||
|
||||
# OPTIMIZATION: Fire-and-forget logging to avoid blocking
|
||||
try:
|
||||
asyncio.create_task(_async_log_usage(
|
||||
db, api_key.id, f"/api/{path}", response.status_code, chosen_server.id, model_name
|
||||
))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to queue usage log: {e}")
|
||||
# Update log with server_id if we have a log entry
|
||||
if log_id and chosen_server:
|
||||
try:
|
||||
from app.database.session import AsyncSessionLocal
|
||||
async with AsyncSessionLocal() as async_db:
|
||||
from sqlalchemy import update
|
||||
from app.database.models import UsageLog
|
||||
await async_db.execute(
|
||||
update(UsageLog).where(UsageLog.id == log_id).values(server_id=chosen_server.id)
|
||||
)
|
||||
await async_db.commit()
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to update server_id for log {log_id}: {e}")
|
||||
|
||||
# For non-streaming, non-tracked endpoints, log without tokens
|
||||
if not is_token_trackable_endpoint:
|
||||
try:
|
||||
asyncio.create_task(_async_log_usage(
|
||||
db, api_key.id, f"/api/{path}", response.status_code, chosen_server.id, model_name
|
||||
))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to queue usage log: {e}")
|
||||
|
||||
return response
|
||||
|
||||
@@ -3,9 +3,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import func, select, text, Date
|
||||
from app.database.models import UsageLog, APIKey, User, OllamaServer
|
||||
import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
async def create_usage_log(
|
||||
db: AsyncSession, *, api_key_id: int, endpoint: str, status_code: int, server_id: int | None, model: str | None = None
|
||||
db: AsyncSession, *,
|
||||
api_key_id: int,
|
||||
endpoint: str,
|
||||
status_code: int,
|
||||
server_id: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
prompt_tokens: Optional[int] = None,
|
||||
completion_tokens: Optional[int] = None,
|
||||
total_tokens: Optional[int] = None
|
||||
) -> UsageLog:
|
||||
# Validate inputs to prevent injection
|
||||
if not isinstance(api_key_id, int) or api_key_id <= 0:
|
||||
@@ -16,22 +25,41 @@ async def create_usage_log(
|
||||
raise ValueError("Invalid status_code")
|
||||
if model is not None and (not isinstance(model, str) or len(model) > 256):
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
# Validate token counts
|
||||
if prompt_tokens is not None:
|
||||
prompt_tokens = max(0, int(prompt_tokens))
|
||||
if completion_tokens is not None:
|
||||
completion_tokens = max(0, int(completion_tokens))
|
||||
if total_tokens is not None:
|
||||
total_tokens = max(0, int(total_tokens))
|
||||
elif prompt_tokens is not None and completion_tokens is not None:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
db_log = UsageLog(
|
||||
api_key_id=api_key_id,
|
||||
endpoint=endpoint[:512], # Limit length
|
||||
status_code=status_code,
|
||||
server_id=server_id,
|
||||
model=model[:256] if model else None
|
||||
model=model[:256] if model else None,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens
|
||||
)
|
||||
db.add(db_log)
|
||||
await db.commit()
|
||||
await db.refresh(db_log)
|
||||
return db_log
|
||||
|
||||
async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count", sort_order: str = "desc"):
|
||||
|
||||
async def get_usage_statistics(
|
||||
db: AsyncSession,
|
||||
sort_by: str = "request_count",
|
||||
sort_order: str = "desc"
|
||||
):
|
||||
"""
|
||||
Returns aggregated usage statistics for all API keys, with sorting.
|
||||
Includes token usage totals.
|
||||
"""
|
||||
# Whitelist allowed sort columns to prevent injection
|
||||
allowed_sort_columns = {
|
||||
@@ -39,6 +67,7 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
|
||||
"key_name": APIKey.key_name,
|
||||
"key_prefix": APIKey.key_prefix,
|
||||
"request_count": func.count(UsageLog.id),
|
||||
"total_tokens": func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
}
|
||||
|
||||
# Default to request_count if invalid column provided
|
||||
@@ -61,6 +90,9 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
|
||||
APIKey.key_prefix,
|
||||
APIKey.is_revoked,
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.select_from(APIKey)
|
||||
.join(User, APIKey.user_id == User.id)
|
||||
@@ -71,10 +103,11 @@ async def get_usage_statistics(db: AsyncSession, sort_by: str = "request_count",
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
# --- NEW STATISTICS FUNCTIONS ---
|
||||
|
||||
async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
|
||||
"""Returns total requests per day for the last N days."""
|
||||
"""Returns total requests and tokens per day for the last N days."""
|
||||
# Validate days parameter
|
||||
try:
|
||||
days = int(days)
|
||||
@@ -92,7 +125,10 @@ async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
|
||||
stmt = (
|
||||
select(
|
||||
date_column,
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.filter(UsageLog.request_timestamp >= start_date)
|
||||
.group_by(date_column)
|
||||
@@ -101,30 +137,49 @@ async def get_daily_usage_stats(db: AsyncSession, days: int = 30):
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
async def get_hourly_usage_stats(db: AsyncSession):
    """Returns total requests and tokens aggregated by the hour of the day (UTC).

    Args:
        db: Async session used to execute the aggregation query.

    Returns:
        A list of exactly 24 dicts, one per hour in order ``"00:00"`` ..
        ``"23:00"``. Each dict has the keys ``hour``, ``request_count``,
        ``total_prompt_tokens``, ``total_completion_tokens`` and
        ``total_tokens``. Hours with no logged requests are filled with zeros.
    """
    # Use safe SQLAlchemy constructs only
    # NOTE(review): func.strftime renders a strftime(...) SQL call; this
    # presumably relies on the SQLite backend - confirm before switching
    # databases.
    hour_extract = func.strftime('%H', UsageLog.request_timestamp)

    stmt = (
        select(
            hour_extract.label("hour"),
            func.count(UsageLog.id).label("request_count"),
            # coalesce turns the NULL that SUM yields for all-NULL groups into 0.
            func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
            func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
            func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
        )
        .group_by("hour")
        .order_by("hour")
    )
    result = await db.execute(stmt)

    # Index the returned rows by their two-digit hour string ("00".."23").
    stats_dict = {
        row.hour: {
            "request_count": row.request_count,
            "total_prompt_tokens": row.total_prompt_tokens,
            "total_completion_tokens": row.total_completion_tokens,
            "total_tokens": row.total_tokens,
        }
        for row in result.all()
    }

    # Ensure all 24 hours are present. The zero row is only ever unpacked
    # into a new dict, so sharing one instance across hours is safe.
    empty_hour = {
        "request_count": 0,
        "total_prompt_tokens": 0,
        "total_completion_tokens": 0,
        "total_tokens": 0,
    }
    return [
        {"hour": f"{h:02d}:00", **stats_dict.get(f"{h:02d}", empty_hour)}
        for h in range(24)
    ]
|
||||
|
||||
|
||||
async def get_server_load_stats(db: AsyncSession):
|
||||
"""Returns total requests per backend server."""
|
||||
"""Returns total requests and tokens per backend server."""
|
||||
stmt = (
|
||||
select(
|
||||
OllamaServer.name.label("server_name"),
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.select_from(OllamaServer)
|
||||
.outerjoin(UsageLog, OllamaServer.id == UsageLog.server_id)
|
||||
@@ -134,12 +189,16 @@ async def get_server_load_stats(db: AsyncSession):
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
async def get_model_usage_stats(db: AsyncSession):
|
||||
"""Returns total requests per model."""
|
||||
"""Returns total requests and tokens per model."""
|
||||
stmt = (
|
||||
select(
|
||||
UsageLog.model.label("model_name"),
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.filter(UsageLog.model.isnot(None))
|
||||
.group_by(UsageLog.model)
|
||||
@@ -148,10 +207,11 @@ async def get_model_usage_stats(db: AsyncSession):
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
# --- NEW USER-SPECIFIC STATISTICS FUNCTIONS ---
|
||||
|
||||
async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: int = 30):
|
||||
"""Returns total requests per day for the last N days for a specific user."""
|
||||
"""Returns total requests and tokens per day for the last N days for a specific user."""
|
||||
# Validate inputs
|
||||
try:
|
||||
user_id = int(user_id)
|
||||
@@ -175,7 +235,10 @@ async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: i
|
||||
stmt = (
|
||||
select(
|
||||
date_column,
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.join(APIKey, UsageLog.api_key_id == APIKey.id)
|
||||
.filter(APIKey.user_id == user_id)
|
||||
@@ -186,8 +249,9 @@ async def get_daily_usage_stats_for_user(db: AsyncSession, user_id: int, days: i
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
"""Returns total requests aggregated by the hour for a specific user."""
|
||||
"""Returns total requests and tokens aggregated by the hour for a specific user."""
|
||||
# Validate user_id
|
||||
try:
|
||||
user_id = int(user_id)
|
||||
@@ -201,7 +265,10 @@ async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
stmt = (
|
||||
select(
|
||||
hour_extract.label("hour"),
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.join(APIKey, UsageLog.api_key_id == APIKey.id)
|
||||
.filter(APIKey.user_id == user_id)
|
||||
@@ -209,11 +276,23 @@ async def get_hourly_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
.order_by("hour")
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
stats_dict = {row.hour: row.request_count for row in result.all()}
|
||||
return [{"hour": f"{h:02d}:00", "request_count": stats_dict.get(f"{h:02d}", 0)} for h in range(24)]
|
||||
stats_dict = {row.hour: {
|
||||
"request_count": row.request_count,
|
||||
"total_prompt_tokens": row.total_prompt_tokens,
|
||||
"total_completion_tokens": row.total_completion_tokens,
|
||||
"total_tokens": row.total_tokens,
|
||||
} for row in result.all()}
|
||||
|
||||
return [{"hour": f"{h:02d}:00", **stats_dict.get(f"{h:02d}", {
|
||||
"request_count": 0,
|
||||
"total_prompt_tokens": 0,
|
||||
"total_completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
})} for h in range(24)]
|
||||
|
||||
|
||||
async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
|
||||
"""Returns total requests per backend server for a specific user."""
|
||||
"""Returns total requests and tokens per backend server for a specific user."""
|
||||
# Validate user_id
|
||||
try:
|
||||
user_id = int(user_id)
|
||||
@@ -225,7 +304,10 @@ async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
|
||||
stmt = (
|
||||
select(
|
||||
OllamaServer.name.label("server_name"),
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.select_from(UsageLog)
|
||||
.join(APIKey, UsageLog.api_key_id == APIKey.id)
|
||||
@@ -237,8 +319,9 @@ async def get_server_load_stats_for_user(db: AsyncSession, user_id: int):
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
"""Returns total requests per model for a specific user."""
|
||||
"""Returns total requests and tokens per model for a specific user."""
|
||||
# Validate user_id
|
||||
try:
|
||||
user_id = int(user_id)
|
||||
@@ -250,7 +333,10 @@ async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
stmt = (
|
||||
select(
|
||||
UsageLog.model.label("model_name"),
|
||||
func.count(UsageLog.id).label("request_count")
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("total_prompt_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("total_completion_tokens"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
)
|
||||
.join(APIKey, UsageLog.api_key_id == APIKey.id)
|
||||
.filter(APIKey.user_id == user_id)
|
||||
@@ -260,3 +346,40 @@ async def get_model_usage_stats_for_user(db: AsyncSession, user_id: int):
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.all()
|
||||
|
||||
|
||||
async def update_usage_log_with_tokens(
    db: AsyncSession,
    log_id: int,
    prompt_tokens: Optional[int] = None,
    completion_tokens: Optional[int] = None,
    total_tokens: Optional[int] = None
) -> Optional[UsageLog]:
    """Updates an existing usage log entry with token counts.

    Only the counts that are not ``None`` are written; each is coerced with
    ``int()`` and clamped to be non-negative. When ``total_tokens`` is not
    given but both ``prompt_tokens`` and ``completion_tokens`` are, the total
    is stored as their sum.

    Args:
        db: Async session used to load and persist the log entry.
        log_id: Primary key of the ``UsageLog`` row to update.
        prompt_tokens: Prompt token count, or ``None`` to leave unchanged.
        completion_tokens: Completion token count, or ``None`` to leave
            unchanged.
        total_tokens: Total token count, or ``None`` to derive it (see above)
            or leave it unchanged.

    Returns:
        The refreshed ``UsageLog`` row, or ``None`` if the row does not exist
        or the update failed. This function does not raise.
    """
    # NOTE(review): `logger` must exist at module level; the visible import
    # section of this module does not show it - confirm it is defined.
    try:
        result = await db.execute(
            select(UsageLog).filter(UsageLog.id == log_id)
        )
        log_entry = result.scalars().first()

        if not log_entry:
            logger.warning(f"Usage log entry {log_id} not found for token update")
            return None

        # Validate and update token counts
        if prompt_tokens is not None:
            log_entry.prompt_tokens = max(0, int(prompt_tokens))
        if completion_tokens is not None:
            log_entry.completion_tokens = max(0, int(completion_tokens))
        if total_tokens is not None:
            log_entry.total_tokens = max(0, int(total_tokens))
        elif prompt_tokens is not None and completion_tokens is not None:
            # Sum the already-sanitized values just stored on the entry.
            log_entry.total_tokens = log_entry.prompt_tokens + log_entry.completion_tokens

        await db.commit()
        await db.refresh(log_entry)
        return log_entry

    except Exception as e:
        logger.error(f"Failed to update usage log {log_id} with tokens: {e}")
        # Discard the partial changes and reset the transaction so a caller
        # that keeps using this session is not left with a failed one.
        # Rollback is best-effort: this function still must not raise.
        try:
            await db.rollback()
        except Exception as rollback_error:
            logger.debug(f"Rollback after failed token update for log {log_id} also failed: {rollback_error}")
        return None
|
||||
|
||||
@@ -20,6 +20,7 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||
async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by: str = "username", sort_order: str = "asc") -> list:
|
||||
"""
|
||||
Retrieves a list of users along with their statistics, with sorting.
|
||||
Includes token usage totals.
|
||||
"""
|
||||
# Subquery to find the last usage time for each user
|
||||
last_used_subq = (
|
||||
@@ -40,6 +41,7 @@ async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by:
|
||||
User.is_admin,
|
||||
func.count(func.distinct(APIKey.id)).label("key_count"),
|
||||
func.count(UsageLog.id).label("request_count"),
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total_tokens"),
|
||||
last_used_subq.c.last_used
|
||||
)
|
||||
.outerjoin(APIKey, User.id == APIKey.user_id)
|
||||
@@ -58,6 +60,7 @@ async def get_users(db: AsyncSession, skip: int = 0, limit: int = 100, sort_by:
|
||||
"username": User.username,
|
||||
"key_count": func.count(func.distinct(APIKey.id)),
|
||||
"request_count": func.count(UsageLog.id),
|
||||
"total_tokens": func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
"last_used": last_used_subq.c.last_used
|
||||
}
|
||||
sort_column = sort_column_map.get(sort_by, User.username)
|
||||
@@ -106,4 +109,4 @@ async def delete_user(db: AsyncSession, user_id: int) -> User | None:
|
||||
if user:
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
return user
|
||||
return user
|
||||
|
||||
@@ -297,6 +297,26 @@ async def migrate_usage_logs_table(engine: AsyncEngine) -> None:
|
||||
"VARCHAR"
|
||||
)
|
||||
|
||||
# Add token usage columns if missing
|
||||
await add_column_if_missing(
|
||||
engine,
|
||||
"usage_logs",
|
||||
"prompt_tokens",
|
||||
"INTEGER"
|
||||
)
|
||||
await add_column_if_missing(
|
||||
engine,
|
||||
"usage_logs",
|
||||
"completion_tokens",
|
||||
"INTEGER"
|
||||
)
|
||||
await add_column_if_missing(
|
||||
engine,
|
||||
"usage_logs",
|
||||
"total_tokens",
|
||||
"INTEGER"
|
||||
)
|
||||
|
||||
# Create index on model column if it doesn't exist
|
||||
# Note: SQLite will silently ignore if index already exists
|
||||
async with engine.begin() as conn:
|
||||
@@ -347,9 +367,9 @@ async def migrate_app_settings_data(engine: AsyncEngine) -> None:
|
||||
|
||||
# Default values for new retry settings
|
||||
default_retry_settings = {
|
||||
"max_retries": 5,
|
||||
"retry_total_timeout_seconds": 2.0,
|
||||
"retry_base_delay_ms": 50
|
||||
"max_retries": 2,
|
||||
"retry_total_timeout_seconds": 1.0,
|
||||
"retry_base_delay_ms": 10
|
||||
}
|
||||
|
||||
# Add missing fields
|
||||
@@ -527,6 +547,9 @@ async def run_all_migrations(engine: AsyncEngine) -> None:
|
||||
},
|
||||
"usage_logs": {
|
||||
"model": "VARCHAR",
|
||||
"prompt_tokens": "INTEGER",
|
||||
"completion_tokens": "INTEGER",
|
||||
"total_tokens": "INTEGER",
|
||||
"server_id": "INTEGER",
|
||||
},
|
||||
"model_metadata": {
|
||||
|
||||
@@ -59,6 +59,11 @@ class UsageLog(Base):
|
||||
request_timestamp = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
server_id = Column(Integer, ForeignKey("ollama_servers.id"), nullable=True)
|
||||
model = Column(String, nullable=True, index=True)
|
||||
|
||||
# Token usage tracking
|
||||
prompt_tokens = Column(Integer, nullable=True)
|
||||
completion_tokens = Column(Integer, nullable=True)
|
||||
total_tokens = Column(Integer, nullable=True)
|
||||
|
||||
api_key = relationship("APIKey", back_populates="usage_logs")
|
||||
server = relationship("OllamaServer")
|
||||
|
||||
@@ -19,6 +19,32 @@
|
||||
|
||||
{% block content %}
|
||||
<div class="space-y-8">
|
||||
|
||||
<!-- Token Usage Summary Card -->
|
||||
<div class="card-style">
|
||||
<h2 class="card-header text-2xl font-bold mb-4 pb-2">Token Usage Summary</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-3 gap-6">
|
||||
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
|
||||
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Prompt Tokens</h3>
|
||||
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
|
||||
{{ "{:,}".format(total_prompt_tokens|default(0)) }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
|
||||
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Completion Tokens</h3>
|
||||
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
|
||||
{{ "{:,}".format(total_completion_tokens|default(0)) }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="p-4 rounded-lg text-center" style="background-color: rgba(128,128,128,0.1);">
|
||||
<h3 class="text-sm font-medium uppercase text-gray-400 mb-2">Total Tokens</h3>
|
||||
<div class="text-3xl font-bold text-[var(--color-primary-500)]">
|
||||
{{ "{:,}".format(total_tokens|default(0)) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Top Row: Line and Bar Charts -->
|
||||
<div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
|
||||
<div class="card-style">
|
||||
@@ -67,6 +93,18 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token Usage by Model Chart -->
|
||||
<div class="card-style">
|
||||
<h2 class="card-header flex justify-between items-center text-xl font-bold mb-4 pb-2">
|
||||
<span>Token Usage by Model</span>
|
||||
<div class="space-x-2">
|
||||
<button onclick="exportChartToPNG('tokenUsageChart', 'token-usage-by-model.png')" class="text-sm text-[var(--color-primary-500)] hover:underline">PNG</button>
|
||||
{# The onclick attribute is single-quoted on purpose: the tojson filter emits raw double quotes for string values (it escapes single quotes, <, > and &), so a double-quoted attribute would be cut off at the first model label. #}
<button onclick='exportDataToCSV(["Model Name", "Prompt Tokens", "Completion Tokens", "Total Tokens"], {{ model_labels | tojson }}, {{ model_prompt_tokens | tojson }}, {{ model_completion_tokens | tojson }}, {{ model_total_tokens | tojson }}, "token-usage-by-model.csv")' class="text-sm text-[var(--color-primary-500)] hover:underline">CSV</button>
|
||||
</div>
|
||||
</h2>
|
||||
{# Show the canvas only when per-model data exists. The guard uses model_total_tokens, the same name the CSV export in this card passes; an undefined name is falsy in Jinja and would permanently hide the chart. #}
{% if model_data and model_total_tokens %}<canvas id="tokenUsageChart"></canvas>{% else %}<div class="flex items-center justify-center h-64"><p>No token usage data available.</p></div>{% endif %}
|
||||
</div>
|
||||
|
||||
<!-- Bottom Row: Key Usage Table -->
|
||||
<div class="card-style">
|
||||
<h2 class="card-header flex justify-between items-center text-xl font-bold mb-4 pb-2">
|
||||
@@ -81,6 +119,9 @@
|
||||
{{ sortable_header('key_name', 'Key Name') }}
|
||||
{{ sortable_header('key_prefix', 'Key Prefix') }}
|
||||
{{ sortable_header('request_count', 'Requests', 'text-right') }}
|
||||
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Prompt Tokens</th>
|
||||
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Completion Tokens</th>
|
||||
<th scope="col" class="px-6 py-3 text-right text-xs font-medium text-gray-400 uppercase tracking-wider">Total Tokens</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y divide-white/10">
|
||||
@@ -90,9 +131,12 @@
|
||||
<td class="px-6 py-4 whitespace-nowrap text-sm font-medium text-current">{{ stat.key_name }}</td>
|
||||
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono">{{ stat.key_prefix }}</td>
|
||||
<td class="px-6 py-4 whitespace-nowrap text-right text-sm font-medium text-current">{{ stat.request_count }}</td>
|
||||
<td class="px-6 py-4 whitespace-nowrap text-right text-sm">{{ "{:,}".format(stat.total_prompt_tokens|default(0)) }}</td>
|
||||
<td class="px-6 py-4 whitespace-nowrap text-right text-sm">{{ "{:,}".format(stat.total_completion_tokens|default(0)) }}</td>
|
||||
<td class="px-6 py-4 whitespace-nowrap text-right text-sm font-medium">{{ "{:,}".format(stat.total_tokens|default(0)) }}</td>
|
||||
</tr>
|
||||
{% else %}
|
||||
<tr><td colspan="4" class="px-6 py-4 text-center">No usage data available.</td></tr>
|
||||
<tr><td colspan="7" class="px-6 py-4 text-center">No usage data available.</td></tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -231,6 +275,40 @@
|
||||
options: { responsive: true }
|
||||
});
|
||||
}
|
||||
|
||||
// 5. Token Usage by Model Chart (Stacked Bar)
|
||||
const tokenCtx = document.getElementById('tokenUsageChart')?.getContext('2d');
|
||||
if (tokenCtx && {{ model_prompt_tokens | tojson }} && {{ model_completion_tokens | tojson }}) {
|
||||
new Chart(tokenCtx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: {{ model_labels | tojson }},
|
||||
datasets: [
|
||||
{
|
||||
label: 'Prompt Tokens',
|
||||
data: {{ model_prompt_tokens | tojson }},
|
||||
backgroundColor: primaryColorTransparent,
|
||||
borderColor: primaryColor,
|
||||
borderWidth: 1
|
||||
},
|
||||
{
|
||||
label: 'Completion Tokens',
|
||||
data: {{ model_completion_tokens | tojson }},
|
||||
backgroundColor: '#f97316',
|
||||
borderColor: '#ea580c',
|
||||
borderWidth: 1
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
scales: {
|
||||
y: { beginAtZero: true, stacked: true },
|
||||
x: { stacked: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
Reference in New Issue
Block a user