from __future__ import annotations import base64 import enum import logging from collections import defaultdict from datetime import datetime, timezone from json import JSONDecodeError from typing import ( TYPE_CHECKING, Annotated, Any, Callable, ClassVar, Generic, Literal, Optional, TypeVar, cast, get_args, ) from uuid import uuid4 from prisma.enums import CreditTransactionType, OnboardingStep from pydantic import ( BaseModel, ConfigDict, Field, GetCoreSchemaHandler, SecretStr, field_serializer, ) from pydantic_core import ( CoreSchema, PydanticUndefined, PydanticUndefinedType, ValidationError, core_schema, ) from typing_extensions import TypedDict from backend.integrations.providers import ProviderName from backend.util.json import loads as json_loads from backend.util.request import parse_url from backend.util.settings import Secrets # Type alias for any provider name (including custom ones) AnyProviderName = str # Will be validated as ProviderName at runtime USER_TIMEZONE_NOT_SET = "not-set" class User(BaseModel): """Application-layer User model with snake_case convention.""" model_config = ConfigDict( extra="forbid", str_strip_whitespace=True, ) id: str = Field(..., description="User ID") email: str = Field(..., description="User email address") email_verified: bool = Field(default=True, description="Whether email is verified") name: Optional[str] = Field(None, description="User display name") created_at: datetime = Field(..., description="When user was created") updated_at: datetime = Field(..., description="When user was last updated") metadata: dict[str, Any] = Field( default_factory=dict, description="User metadata as dict" ) integrations: str = Field(default="", description="Encrypted integrations data") stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID") top_up_config: Optional["AutoTopUpConfig"] = Field( None, description="Top up configuration" ) # Notification preferences max_emails_per_day: int = Field(default=3, description="Maximum emails per day") notify_on_agent_run: bool = Field(default=True, description="Notify on agent run") notify_on_zero_balance: bool = Field( default=True, description="Notify on zero balance" ) notify_on_low_balance: bool = Field( default=True, description="Notify on low balance" ) notify_on_block_execution_failed: bool = Field( default=True, description="Notify on block execution failure" ) notify_on_continuous_agent_error: bool = Field( default=True, description="Notify on continuous agent error" ) notify_on_daily_summary: bool = Field( default=True, description="Notify on daily summary" ) notify_on_weekly_summary: bool = Field( default=True, description="Notify on weekly summary" ) notify_on_monthly_summary: bool = Field( default=True, description="Notify on monthly summary" ) # User timezone for scheduling and time display timezone: str = Field( default=USER_TIMEZONE_NOT_SET, description="User timezone (IANA timezone identifier or 'not-set')", ) @classmethod def from_db(cls, prisma_user: "PrismaUser") -> "User": """Convert a database User object to application User model.""" # Handle metadata field - convert from JSON string or dict to dict metadata = {} if prisma_user.metadata: if isinstance(prisma_user.metadata, str): try: metadata = json_loads(prisma_user.metadata) except (JSONDecodeError, TypeError): metadata = {} elif isinstance(prisma_user.metadata, dict): metadata = prisma_user.metadata # Handle topUpConfig field top_up_config = None if prisma_user.topUpConfig: if isinstance(prisma_user.topUpConfig, str): try: config_dict = json_loads(prisma_user.topUpConfig) top_up_config = AutoTopUpConfig.model_validate(config_dict) except (JSONDecodeError, TypeError, ValueError): top_up_config = None elif isinstance(prisma_user.topUpConfig, dict): try: top_up_config = AutoTopUpConfig.model_validate( prisma_user.topUpConfig ) except ValueError: top_up_config = None return cls( id=prisma_user.id, email=prisma_user.email, email_verified=prisma_user.emailVerified or True, name=prisma_user.name, created_at=prisma_user.createdAt, updated_at=prisma_user.updatedAt, metadata=metadata, integrations=prisma_user.integrations or "", stripe_customer_id=prisma_user.stripeCustomerId, top_up_config=top_up_config, max_emails_per_day=prisma_user.maxEmailsPerDay or 3, notify_on_agent_run=prisma_user.notifyOnAgentRun or True, notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True, notify_on_low_balance=prisma_user.notifyOnLowBalance or True, notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed or True, notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError or True, notify_on_daily_summary=prisma_user.notifyOnDailySummary or True, notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True, notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True, timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET, ) if TYPE_CHECKING: from prisma.models import User as PrismaUser T = TypeVar("T") logger = logging.getLogger(__name__) class BlockSecret: def __init__(self, key: Optional[str] = None, value: Optional[str] = None): if value is not None: trimmed_value = value.strip() if value != trimmed_value: logger.debug(BlockSecret.TRIMMING_VALUE_MSG) self._value = trimmed_value return self._value = self.__get_secret(key) if self._value is None: raise ValueError(f"Secret {key} not found.") trimmed_value = self._value.strip() if self._value != trimmed_value: logger.debug(BlockSecret.TRIMMING_VALUE_MSG) self._value = trimmed_value TRIMMING_VALUE_MSG: ClassVar[str] = "Provided secret value got trimmed." STR: ClassVar[str] = "" SECRETS: ClassVar[Secrets] = Secrets() def __repr__(self): return BlockSecret.STR def __str__(self): return BlockSecret.STR @staticmethod def __get_secret(key: str | None): if not key or not hasattr(BlockSecret.SECRETS, key): return None return getattr(BlockSecret.SECRETS, key) def get_secret_value(self): trimmed_value = str(self._value).strip() if self._value != trimmed_value: logger.info(BlockSecret.TRIMMING_VALUE_MSG) return trimmed_value @classmethod def parse_value(cls, value: Any) -> BlockSecret: if isinstance(value, BlockSecret): return value return BlockSecret(value=value) @classmethod def __get_pydantic_json_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> dict[str, Any]: return { "type": "string", } @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: validate_fun = core_schema.no_info_plain_validator_function(cls.parse_value) return core_schema.json_or_python_schema( json_schema=validate_fun, python_schema=validate_fun, serialization=core_schema.plain_serializer_function_ser_schema( lambda val: BlockSecret.STR ), ) def SecretField( value: Optional[str] = None, key: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, placeholder: Optional[str] = None, **kwargs, ) -> BlockSecret: return SchemaField( BlockSecret(key=key, value=value), title=title, description=description, placeholder=placeholder, secret=True, **kwargs, ) def SchemaField( default: T | PydanticUndefinedType = PydanticUndefined, *args, default_factory: Optional[Callable[[], T]] = None, title: Optional[str] = None, description: Optional[str] = None, placeholder: Optional[str] = None, advanced: Optional[bool] = None, secret: bool = False, exclude: bool = False, hidden: Optional[bool] = None, depends_on: Optional[list[str]] = None, ge: Optional[float] = None, le: Optional[float] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, discriminator: Optional[str] = None, format: Optional[str] = None, json_schema_extra: Optional[dict[str, Any]] = None, ) -> T: if default is PydanticUndefined and default_factory is None: advanced = False elif advanced is None: advanced = True json_schema_extra = { k: v for k, v in { "placeholder": placeholder, "secret": secret, "advanced": advanced, "hidden": hidden, "depends_on": depends_on, "format": format, **(json_schema_extra or {}), }.items() if v is not None } return Field( default, *args, default_factory=default_factory, title=title, description=description, exclude=exclude, ge=ge, le=le, min_length=min_length, max_length=max_length, discriminator=discriminator, json_schema_extra=json_schema_extra, ) # type: ignore class _BaseCredentials(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) provider: str title: Optional[str] = None @field_serializer("*") def dump_secret_strings(value: Any, _info): if isinstance(value, SecretStr): return value.get_secret_value() return value class OAuth2Credentials(_BaseCredentials): type: Literal["oauth2"] = "oauth2" username: Optional[str] = None """Username of the third-party service user that these credentials belong to""" access_token: SecretStr access_token_expires_at: Optional[int] = None """Unix timestamp (seconds) indicating when the access token expires (if at all)""" refresh_token: Optional[SecretStr] = None refresh_token_expires_at: Optional[int] = None """Unix timestamp (seconds) indicating when the refresh token expires (if at all)""" scopes: list[str] metadata: dict[str, Any] = Field(default_factory=dict) def auth_header(self) -> str: return f"Bearer {self.access_token.get_secret_value()}" class APIKeyCredentials(_BaseCredentials): type: Literal["api_key"] = "api_key" api_key: SecretStr expires_at: Optional[int] = Field( default=None, description="Unix timestamp (seconds) indicating when the API key expires (if at all)", ) """Unix timestamp (seconds) indicating when the API key expires (if at all)""" def auth_header(self) -> str: # Linear API keys should not have Bearer prefix if self.provider == "linear": return self.api_key.get_secret_value() return f"Bearer {self.api_key.get_secret_value()}" class UserPasswordCredentials(_BaseCredentials): type: Literal["user_password"] = "user_password" username: SecretStr password: SecretStr def auth_header(self) -> str: # Converting the string to bytes using encode() # Base64 encoding it with base64.b64encode() # Converting the resulting bytes back to a string with decode() return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}" class HostScopedCredentials(_BaseCredentials): type: Literal["host_scoped"] = "host_scoped" host: str = Field(description="The host/URI pattern to match against request URLs") headers: dict[str, SecretStr] = Field( description="Key-value header map to add to matching requests", default_factory=dict, ) def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]: """Helper to extract secret values from headers.""" return {key: value.get_secret_value() for key, value in headers.items()} @field_serializer("headers") def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]: """Serialize headers by extracting secret values.""" return self._extract_headers(headers) def get_headers_dict(self) -> dict[str, str]: """Get headers with secret values extracted.""" return self._extract_headers(self.headers) def auth_header(self) -> str: """Get authorization header for backward compatibility.""" auth_headers = self.get_headers_dict() if "Authorization" in auth_headers: return auth_headers["Authorization"] return "" def matches_url(self, url: str) -> bool: """Check if this credential should be applied to the given URL.""" request_host, request_port = _extract_host_from_url(url) cred_scope_host, cred_scope_port = _extract_host_from_url(self.host) if not request_host: return False # If a port is specified in credential host, the request host port must match if cred_scope_port is not None and request_port != cred_scope_port: return False # Non-standard ports are only allowed if explicitly specified in credential host elif cred_scope_port is None and request_port not in (80, 443, None): return False # Simple host matching if cred_scope_host == request_host: return True # Support wildcard matching (e.g., "*.example.com" matches "api.example.com") if cred_scope_host.startswith("*."): domain = cred_scope_host[2:] # Remove "*." return request_host.endswith(f".{domain}") or request_host == domain return False Credentials = Annotated[ OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials | HostScopedCredentials, Field(discriminator="type"), ] CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"] class OAuthState(BaseModel): token: str provider: str expires_at: int code_verifier: Optional[str] = None """Unix timestamp (seconds) indicating when this OAuth state expires""" scopes: list[str] # Fields for external API OAuth flows callback_url: Optional[str] = None """External app's callback URL for OAuth redirect""" state_metadata: dict[str, Any] = Field(default_factory=dict) """Metadata to echo back to external app on completion""" initiated_by_api_key_id: Optional[str] = None """ID of the API key that initiated this OAuth flow""" @property def is_external(self) -> bool: """Whether this OAuth flow was initiated via external API.""" return self.callback_url is not None class UserMetadata(BaseModel): integration_credentials: list[Credentials] = Field(default_factory=list) """⚠️ Deprecated; use `UserIntegrations.credentials` instead""" integration_oauth_states: list[OAuthState] = Field(default_factory=list) """⚠️ Deprecated; use `UserIntegrations.oauth_states` instead""" class UserMetadataRaw(TypedDict, total=False): integration_credentials: list[dict] """⚠️ Deprecated; use `UserIntegrations.credentials` instead""" integration_oauth_states: list[dict] """⚠️ Deprecated; use `UserIntegrations.oauth_states` instead""" class UserIntegrations(BaseModel): class ManagedCredentials(BaseModel): """Integration credentials managed by us, rather than by the user""" ayrshare_profile_key: Optional[SecretStr] = None @field_serializer("*") def dump_secret_strings(value: Any, _info): if isinstance(value, SecretStr): return value.get_secret_value() return value managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials) credentials: list[Credentials] = Field(default_factory=list) oauth_states: list[OAuthState] = Field(default_factory=list) CP = TypeVar("CP", bound=ProviderName) CT = TypeVar("CT", bound=CredentialsType) def is_credentials_field_name(field_name: str) -> bool: return field_name == "credentials" or field_name.endswith("_credentials") class CredentialsMetaInput(BaseModel, Generic[CP, CT]): id: str title: Optional[str] = None provider: CP type: CT @classmethod def allowed_providers(cls) -> tuple[ProviderName, ...] | None: return get_args(cls.model_fields["provider"].annotation) @classmethod def allowed_cred_types(cls) -> tuple[CredentialsType, ...]: return get_args(cls.model_fields["type"].annotation) @staticmethod def validate_credentials_field_schema( field_schema: dict[str, Any], field_name: str ): """Validates the schema of a credentials input field""" try: field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema) except ValidationError as e: if "Field required [type=missing" not in str(e): raise raise TypeError( "Field 'credentials' JSON schema lacks required extra items: " f"{field_schema}" ) from e providers = field_info.provider if ( providers is not None and len(providers) > 1 and not field_info.discriminator ): raise TypeError( f"Multi-provider CredentialsField '{field_name}' " "requires discriminator!" ) @staticmethod def _add_json_schema_extra(schema: dict, model_class: type): # Use model_class for allowed_providers/cred_types if hasattr(model_class, "allowed_providers") and hasattr( model_class, "allowed_cred_types" ): allowed_providers = model_class.allowed_providers() # If no specific providers (None), allow any string if allowed_providers is None: schema["credentials_provider"] = ["string"] # Allow any string provider else: schema["credentials_provider"] = allowed_providers schema["credentials_types"] = model_class.allowed_cred_types() # Do not return anything, just mutate schema in place model_config = ConfigDict( json_schema_extra=_add_json_schema_extra, # type: ignore ) def _extract_host_from_url(url: str) -> tuple[str, int | None]: """Extract host and port from URL for grouping host-scoped credentials.""" try: parsed = parse_url(url) return parsed.hostname or url, parsed.port except Exception: return "", None class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): # TODO: move discrimination mechanism out of CredentialsField (frontend + backend) provider: frozenset[CP] = Field(..., alias="credentials_provider") supported_types: frozenset[CT] = Field(..., alias="credentials_types") required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes") discriminator: Optional[str] = None discriminator_mapping: Optional[dict[str, CP]] = None discriminator_values: set[Any] = Field(default_factory=set) @classmethod def combine( cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T] ) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]: """ Combines multiple CredentialsFieldInfo objects into as few as possible. Rules: - Items can only be combined if they have the same supported credentials types and the same supported providers. - When combining items, the `required_scopes` of the result is a join of the `required_scopes` of the original items. Params: *fields: (CredentialsFieldInfo, key) objects to group and combine Returns: A sequence of tuples containing combined CredentialsFieldInfo objects and the set of keys of the respective original items that were grouped together. """ if not fields: return {} # Group fields by their provider and supported_types # For HTTP host-scoped credentials, also group by host grouped_fields: defaultdict[ tuple[frozenset[CP], frozenset[CT]], list[tuple[T, CredentialsFieldInfo[CP, CT]]], ] = defaultdict(list) for field, key in fields: if field.provider == frozenset([ProviderName.HTTP]): # HTTP host-scoped credentials can have different hosts that reqires different credential sets. # Group by host extracted from the URL providers = frozenset( [cast(CP, "http")] + [ cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values ] ) else: providers = frozenset(field.provider) group_key = (providers, frozenset(field.supported_types)) grouped_fields[group_key].append((key, field)) # Combine fields within each group result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {} for key, group in grouped_fields.items(): # Start with the first field in the group _, combined = group[0] # Track the keys that were combined combined_keys = {key for key, _ in group} # Combine required_scopes from all fields in the group all_scopes = set() for _, field in group: if field.required_scopes: all_scopes.update(field.required_scopes) # Combine discriminator_values from all fields in the group (removing duplicates) all_discriminator_values = [] for _, field in group: for value in field.discriminator_values: if value not in all_discriminator_values: all_discriminator_values.append(value) # Generate the key for the combined result providers_key, supported_types_key = key group_key = ( "-".join(sorted(providers_key)) + "_" + "-".join(sorted(supported_types_key)) + "_credentials" ) result[group_key] = ( CredentialsFieldInfo[CP, CT]( credentials_provider=combined.provider, credentials_types=combined.supported_types, credentials_scopes=frozenset(all_scopes) or None, discriminator=combined.discriminator, discriminator_mapping=combined.discriminator_mapping, discriminator_values=set(all_discriminator_values), ), combined_keys, ) return result def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo: if not (self.discriminator and self.discriminator_mapping): return self try: provider = self.discriminator_mapping[discriminator_value] except KeyError: raise ValueError( f"Model '{discriminator_value}' is not supported. " "It may have been deprecated. Please update your agent configuration." ) return CredentialsFieldInfo( credentials_provider=frozenset([provider]), credentials_types=self.supported_types, credentials_scopes=self.required_scopes, discriminator=self.discriminator, discriminator_mapping=self.discriminator_mapping, discriminator_values=self.discriminator_values, ) def CredentialsField( required_scopes: set[str] = set(), *, discriminator: Optional[str] = None, discriminator_mapping: Optional[dict[str, Any]] = None, discriminator_values: Optional[set[Any]] = None, title: Optional[str] = None, description: Optional[str] = None, **kwargs, ) -> CredentialsMetaInput: """ `CredentialsField` must and can only be used on fields named `credentials`. This is enforced by the `BlockSchema` base class. """ field_schema_extra = { k: v for k, v in { "credentials_scopes": list(required_scopes) or None, "discriminator": discriminator, "discriminator_mapping": discriminator_mapping, "discriminator_values": discriminator_values, }.items() if v is not None } # Merge any json_schema_extra passed in kwargs if "json_schema_extra" in kwargs: extra_schema = kwargs.pop("json_schema_extra") field_schema_extra.update(extra_schema) return Field( title=title, description=description, json_schema_extra=field_schema_extra, # validated on BlockSchema init **kwargs, ) class ContributorDetails(BaseModel): name: str = Field(title="Name", description="The name of the contributor.") class TopUpType(enum.Enum): AUTO = "AUTO" MANUAL = "MANUAL" UNCATEGORIZED = "UNCATEGORIZED" class AutoTopUpConfig(BaseModel): amount: int """Amount of credits to top up.""" threshold: int """Threshold to trigger auto top up.""" class UserTransaction(BaseModel): transaction_key: str = "" transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc) transaction_type: CreditTransactionType = CreditTransactionType.USAGE amount: int = 0 running_balance: int = 0 current_balance: int = 0 description: str | None = None usage_graph_id: str | None = None usage_execution_id: str | None = None usage_node_count: int = 0 usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc) user_id: str user_email: str | None = None reason: str | None = None admin_email: str | None = None extra_data: str | None = None class TransactionHistory(BaseModel): transactions: list[UserTransaction] next_transaction_time: datetime | None class RefundRequest(BaseModel): id: str user_id: str transaction_key: str amount: int reason: str result: str | None = None status: str created_at: datetime updated_at: datetime class NodeExecutionStats(BaseModel): """Execution statistics for a node execution.""" model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, ) error: Optional[BaseException | str] = None walltime: float = 0 cputime: float = 0 input_size: int = 0 output_size: int = 0 llm_call_count: int = 0 llm_retry_count: int = 0 input_token_count: int = 0 output_token_count: int = 0 extra_cost: int = 0 extra_steps: int = 0 # Moderation fields cleared_inputs: Optional[dict[str, list[str]]] = None cleared_outputs: Optional[dict[str, list[str]]] = None def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats": """Mutate this instance by adding another NodeExecutionStats.""" if not isinstance(other, NodeExecutionStats): return NotImplemented stats_dict = other.model_dump() current_stats = self.model_dump() for key, value in stats_dict.items(): if key not in current_stats: # Field doesn't exist yet, just set it setattr(self, key, value) elif isinstance(value, dict) and isinstance(current_stats[key], dict): current_stats[key].update(value) setattr(self, key, current_stats[key]) elif isinstance(value, (int, float)) and isinstance( current_stats[key], (int, float) ): setattr(self, key, current_stats[key] + value) elif isinstance(value, list) and isinstance(current_stats[key], list): current_stats[key].extend(value) setattr(self, key, current_stats[key]) else: setattr(self, key, value) return self class GraphExecutionStats(BaseModel): """Execution statistics for a graph execution.""" model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, ) error: Optional[Exception | str] = None walltime: float = Field( default=0, description="Time between start and end of run (seconds)" ) cputime: float = 0 nodes_walltime: float = Field( default=0, description="Total node execution time (seconds)" ) nodes_cputime: float = 0 node_count: int = Field(default=0, description="Total number of node executions") node_error_count: int = Field( default=0, description="Total number of errors generated" ) cost: int = Field(default=0, description="Total execution cost (cents)") activity_status: Optional[str] = Field( default=None, description="AI-generated summary of what the agent did" ) correctness_score: Optional[float] = Field( default=None, description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose", ) class UserExecutionSummaryStats(BaseModel): """Summary of user statistics for a specific user.""" model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, ) total_credits_used: float = Field(default=0) total_executions: int = Field(default=0) successful_runs: int = Field(default=0) failed_runs: int = Field(default=0) most_used_agent: str = Field(default="") total_execution_time: float = Field(default=0) average_execution_time: float = Field(default=0) cost_breakdown: dict[str, float] = Field(default_factory=dict) class UserOnboarding(BaseModel): userId: str completedSteps: list[OnboardingStep] walletShown: bool notified: list[OnboardingStep] rewardedFor: list[OnboardingStep] usageReason: Optional[str] integrations: list[str] otherIntegrations: Optional[str] selectedStoreListingVersionId: Optional[str] agentInput: Optional[dict[str, Any]] onboardingAgentExecutionId: Optional[str] agentRuns: int lastRunAt: Optional[datetime] consecutiveRunDays: int