Files
ezkl/src/graph/postgres.rs
2024-05-13 12:47:36 +09:00

494 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(unix)]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
///
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
///
/// If the user is not set, then this defaults to the user executing this process.
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
/// There must be either no hosts, or the same number of hosts as hostaddrs.
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(unix)]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
/// Adds a hostaddr to the configuration.
///
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
/// Sets the TCP user timeout.
///
/// This is ignored for Unix domain socket connections. It is only supported on systems where
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
/// on other systems, it has no effect.
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
/// Gets the TCP user timeout, if one has been set with the
/// `user_timeout` method.
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
/// Controls the use of TCP keepalive.
///
/// This is ignored for Unix domain socket connections. Defaults to `true`.
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
/// Reports whether TCP keepalives will be used.
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
/// Gets the configured amount of idle time before a keepalive packet will
/// be sent on the connection.
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
/// Sets the time interval between TCP keepalive probes.
/// On Windows, this sets the value of the tcp_keepalive structs keepaliveinterval field.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
/// Gets the time interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub async fn connect(&self) -> Result<Client, Error> {
let (client, connection) = self.config.connect(NoTls).await?;
let connection = Connection::new(connection);
Ok(Client::new(client, connection))
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
s.parse::<tokio_postgres::Config>().map(Config::from)
}
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
/// get dropped.
pub struct Connection {
/// The underlying connection stream.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
pub struct Client {
connection: JoinHandle<()>,
client: tokio_postgres::Client,
}
impl Drop for Client {
fn drop(&mut self) {
let _ = self.close_inner();
}
}
impl Client {
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
let thread = tokio::spawn(async move {
connection.start().await;
});
Client {
client,
connection: thread,
}
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
///
/// See the documentation for [`Config`] for information about the connection syntax.
///
/// [`Config`]: config/struct.Config.html
pub async fn connect(params: &str) -> Result<Client, Error> {
debug!("Connecting to database with params: {}", params);
params.parse::<Config>()?.connect().await
}
/// Returns a new `Config` object which can be used to configure and connect to a database.
pub fn configure() -> Config {
Config::new()
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
pub async fn execute<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.execute(query, params).await
}
/// Executes a statement, returning the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Examples
///
pub async fn query<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.query(query, params).await
}
/// Determines if the client's connection has already closed.
///
/// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool {
self.client.is_closed()
}
/// Closes the client's connection to the server.
///
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
/// caller.
pub fn close(mut self) -> Result<(), Error> {
self.close_inner()
}
fn close_inner(&mut self) -> Result<(), Error> {
self.client.__private_api_close();
Ok(())
}
}