Compare commits

...

3 Commits

Author SHA1 Message Date
github-actions[bot]
d16fe8e928 ci: update version string in docs 2025-04-29 14:58:55 +00:00
dante
839030ce10 chore: rm halo2proofs patches (#976) 2025-04-29 10:58:35 -04:00
dante
cfccc5460c refactor: rm postgres (#977) 2025-04-29 08:59:14 -04:00
10 changed files with 1554 additions and 2139 deletions

View File

@@ -798,8 +798,6 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli

2552
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2024"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -69,20 +69,18 @@ reqwest = { version = "0.12.4", default-features = false, features = [
"stream",
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.24.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
@@ -242,8 +240,6 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
@@ -275,12 +271,6 @@ no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -289,7 +279,7 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
#panic = "abort"
# panic = "abort"
[profile.test-runs]

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '22.0.3'
version = release

View File

@@ -1,462 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -98,8 +98,6 @@ pub enum GraphError {
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
@@ -141,7 +139,9 @@ pub enum GraphError {
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
///Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
#[error(
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
)]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]

View File

@@ -3,8 +3,6 @@ use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
@@ -23,9 +21,6 @@ use tract_onnx::tract_core::{
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
type Decimals = u8;
type Call = String;
type RPCUrl = String;
@@ -260,9 +255,6 @@ pub enum DataSource {
File(FileSource),
/// Data fetched from blockchain contracts
OnChain(OnChainSource),
/// Data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
impl Default for DataSource {
@@ -323,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = third_try {
return Ok(DataSource::DB(t));
}
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
@@ -517,11 +500,6 @@ impl GraphData {
"on-chain data cannot be split into batches".to_string(),
));
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
@@ -591,28 +569,6 @@ impl GraphData {
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
@@ -655,95 +611,6 @@ mod tests {
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
@@ -767,14 +634,6 @@ impl ToPyObject for DataSource {
.unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
dict.set_item("user", &source.user).unwrap();
dict.set_item("query", &source.query).unwrap();
dict.to_object(py)
}
}
}
}

View File

@@ -6,9 +6,6 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
}

View File

@@ -1,493 +0,0 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
///
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
///
/// If the user is not set, then this defaults to the user executing this process.
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
/// There must be either no hosts, or the same number of hosts as hostaddrs.
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(all(not(not(feature = "ezkl")), unix))]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
/// Adds a hostaddr to the configuration.
///
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
/// Sets the TCP user timeout.
///
/// This is ignored for Unix domain socket connections. It is only supported on systems where
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
/// on other systems, it has no effect.
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
/// Gets the TCP user timeout, if one has been set with the
/// `user_timeout` method.
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
/// Controls the use of TCP keepalive.
///
/// This is ignored for Unix domain socket connections. Defaults to `true`.
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
/// Reports whether TCP keepalives will be used.
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
/// Gets the configured amount of idle time before a keepalive packet will
/// be sent on the connection.
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
/// Sets the time interval between TCP keepalive probes.
/// On Windows, this sets the value of the tcp_keepalive structs keepaliveinterval field.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
/// Gets the time interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub async fn connect(&self) -> Result<Client, Error> {
let (client, connection) = self.config.connect(NoTls).await?;
let connection = Connection::new(connection);
Ok(Client::new(client, connection))
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
s.parse::<tokio_postgres::Config>().map(Config::from)
}
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
/// get dropped.
pub struct Connection {
/// The underlying connection stream.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
pub struct Client {
connection: JoinHandle<()>,
client: tokio_postgres::Client,
}
impl Drop for Client {
fn drop(&mut self) {
let _ = self.close_inner();
}
}
impl Client {
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
let thread = tokio::spawn(async move {
connection.start().await;
});
Client {
client,
connection: thread,
}
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
///
/// See the documentation for [`Config`] for information about the connection syntax.
///
/// [`Config`]: config/struct.Config.html
pub async fn connect(params: &str) -> Result<Client, Error> {
debug!("Connecting to database with params: {}", params);
params.parse::<Config>()?.connect().await
}
/// Returns a new `Config` object which can be used to configure and connect to a database.
pub fn configure() -> Config {
Config::new()
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
pub async fn execute<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.execute(query, params).await
}
/// Executes a statement, returning the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Examples
///
pub async fn query<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.query(query, params).await
}
/// Determines if the client's connection has already closed.
///
/// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool {
self.client.is_closed()
}
/// Closes the client's connection to the server.
///
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
/// caller.
pub fn close(mut self) -> Result<(), Error> {
self.close_inner()
}
fn close_inner(&mut self) -> Result<(), Error> {
self.client.__private_api_close();
Ok(())
}
}

View File

@@ -272,16 +272,6 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn postgres_notebook_() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
run_notebook(path, "mean_postgres.ipynb");
test_dir.close().unwrap();
}
#[test]
fn tictactoe_autoencoder_notebook_() {
crate::py_tests::init_binary();