mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 23:08:03 -05:00
feat: adds rust hub bindings (#551)
--------- Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
This commit is contained in:
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -551,6 +551,8 @@ jobs:
|
||||
# # now dump the contents of the file into a file called kaggle.json
|
||||
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: Simple hub demo
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_25_expects
|
||||
- name: Hashed DA tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_24_expects
|
||||
- name: Little transformer tutorial
|
||||
|
||||
58
Cargo.lock
generated
58
Cargo.lock
generated
@@ -1826,6 +1826,7 @@ dependencies = [
|
||||
"test-case",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tract-onnx",
|
||||
"unzip-n",
|
||||
"wasm-bindgen",
|
||||
@@ -3611,9 +3612,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
version = "0.2.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
|
||||
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
@@ -4175,9 +4176,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.11.18"
|
||||
version = "0.11.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
|
||||
checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"bytes",
|
||||
@@ -4193,6 +4194,7 @@ dependencies = [
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
@@ -4200,12 +4202,15 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"winreg",
|
||||
]
|
||||
@@ -4971,6 +4976,27 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration-sys"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabbycat"
|
||||
version = "0.1.2"
|
||||
@@ -5262,9 +5288,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.8"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d"
|
||||
checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
@@ -5804,6 +5830,19 @@ dependencies = [
|
||||
"quote",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.64"
|
||||
@@ -6018,11 +6057,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
||||
version = "0.50.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
|
||||
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
"cfg-if",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -39,7 +39,7 @@ ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc
|
||||
indicatif = {version = "0.17.5", features = ["rayon"]}
|
||||
gag = { version = "1.0.0", default_features = false}
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls"] }
|
||||
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
|
||||
openssl = { version = "0.10.55", features = ["vendored"] }
|
||||
postgres = "0.19.5"
|
||||
pg_bigdecimal = "0.1.5"
|
||||
@@ -48,6 +48,7 @@ colored_json = { version = "3.0.1", default_features = false, optional = true}
|
||||
plotters = { version = "0.3.0", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
|
||||
tokio-util = { version = "0.7.9", features = ["codec"] }
|
||||
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
|
||||
|
||||
236
examples/notebooks/simple_hub_demo.ipynb
Normal file
236
examples/notebooks/simple_hub_demo.ipynb
Normal file
@@ -0,0 +1,236 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## EZKL HUB Jupyter Notebook Demo \n",
|
||||
"\n",
|
||||
"Here we demonstrate the use of the EZKL hub in a Jupyter notebook whereby all components of the circuit are public or pre-committed to. This is the simplest case of using EZKL (proof of computation).\n",
|
||||
"\n",
|
||||
"This will be accomplished in 3 steps. \n",
|
||||
"\n",
|
||||
"1. Train the model. \n",
|
||||
"2. Define our visibility settings. \n",
|
||||
"3. Upload the model to the hub. \n",
|
||||
"\n",
|
||||
"That's it !"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"from torch import nn\n",
|
||||
"import ezkl\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Defines the model\n",
|
||||
"# we got convs, we got relu, we got linear layers\n",
|
||||
"# What else could one want ????\n",
|
||||
"\n",
|
||||
"class MyModel(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MyModel, self).__init__()\n",
|
||||
"\n",
|
||||
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
|
||||
" self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
|
||||
"\n",
|
||||
" self.relu = nn.ReLU()\n",
|
||||
"\n",
|
||||
" self.d1 = nn.Linear(48, 48)\n",
|
||||
" self.d2 = nn.Linear(48, 10)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" # 32x1x28x28 => 32x32x26x26\n",
|
||||
" x = self.conv1(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
" x = self.conv2(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
"\n",
|
||||
" # flatten => 32 x (32*26*26)\n",
|
||||
" x = x.flatten(start_dim = 1)\n",
|
||||
"\n",
|
||||
" # 32 x (32*26*26) => 32x128\n",
|
||||
" x = self.d1(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
"\n",
|
||||
" # logits => 32x10\n",
|
||||
" logits = self.d2(x)\n",
|
||||
"\n",
|
||||
" return logits\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = MyModel()\n",
|
||||
"\n",
|
||||
"# Train the model as you like here (skipped for brevity)\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
|
||||
"x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" model_path, # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( data, open(data_path, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"test_hub_name = \"samtvlabs\" #we've set this up for you, but you can create your own hub name and use that instead\n",
|
||||
"\n",
|
||||
"py_run_args = ezkl.PyRunArgs()\n",
|
||||
"py_run_args.input_visibility = \"public\"\n",
|
||||
"py_run_args.output_visibility = \"public\"\n",
|
||||
"py_run_args.param_visibility = \"private\" # private by default\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"organization = ezkl.get_hub_credentials(test_hub_name)['organizations'][0]\n",
|
||||
"\n",
|
||||
"print(\"organization: \" + str(organization))\n",
|
||||
"\n",
|
||||
"deployed_model = ezkl.create_hub_artifact(model_path, data_path, \"my first model\", organization['id'], target=\"resources\", py_run_args=py_run_args)\n",
|
||||
"\n",
|
||||
"print(\"deployed model: \" + str(deployed_model))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "81201b32",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# sleep for a bit to make sure the model is deployed\n",
|
||||
"time.sleep(20)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fcc44717",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"proof_id = ezkl.prove_hub(deployed_model['id'], data_path)\n",
|
||||
"print(\"proof id: \" + str(proof_id))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5aa6a580",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# give prove_hub some time to finish\n",
|
||||
"time.sleep(5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b9e2f32f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"proof = ezkl.get_hub_proof(proof_id['id'])\n",
|
||||
"\n",
|
||||
"print(\"proof: \" + str(proof))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -326,7 +326,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
{
|
||||
cs.lookup("", |cs| {
|
||||
let mut res = vec![];
|
||||
let sel = cs.query_selector(multi_col_selector.clone());
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
@@ -377,12 +377,12 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
(
|
||||
col_expr.clone() * input_query.clone()
|
||||
+ not_expr.clone() * Expression::Constant(default_x),
|
||||
input_col.clone(),
|
||||
*input_col,
|
||||
),
|
||||
(
|
||||
col_expr.clone() * output_query.clone()
|
||||
+ not_expr.clone() * Expression::Constant(default_y),
|
||||
output_col.clone(),
|
||||
*output_col,
|
||||
),
|
||||
]);
|
||||
|
||||
|
||||
@@ -71,6 +71,20 @@ impl Default for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for CalibrationTarget {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_string()
|
||||
}
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_string(),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for CalibrationTarget {
|
||||
fn from(s: &str) -> Self {
|
||||
match s {
|
||||
@@ -625,4 +639,70 @@ pub enum Commands {
|
||||
#[arg(long)]
|
||||
proof_path: PathBuf,
|
||||
},
|
||||
|
||||
/// Gets credentials from the hub
|
||||
#[command(name = "get-hub-credentials", arg_required_else_help = true)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
GetHubCredentials {
|
||||
/// The path to the model file
|
||||
#[arg(short = 'N', long)]
|
||||
username: String,
|
||||
/// The path to the input json file
|
||||
#[arg(short = 'U', long)]
|
||||
url: Option<String>,
|
||||
},
|
||||
|
||||
/// Create artifacts and deploys them on the hub
|
||||
#[command(name = "create-hub-artifact", arg_required_else_help = true)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
CreateHubArtifact {
|
||||
/// The path to the model file
|
||||
#[arg(short = 'M', long)]
|
||||
uncompiled_circuit: PathBuf,
|
||||
/// The path to the input json file
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
/// the hub's url
|
||||
#[arg(short = 'O', long)]
|
||||
organization_id: String,
|
||||
///artifact name
|
||||
#[arg(short = 'A', long)]
|
||||
artifact_name: String,
|
||||
/// the hub's url
|
||||
#[arg(short = 'U', long)]
|
||||
url: Option<String>,
|
||||
/// proving arguments
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
/// calibration target
|
||||
#[arg(long, default_value = "resources")]
|
||||
target: CalibrationTarget,
|
||||
},
|
||||
|
||||
/// Create artifacts and deploys them on the hub
|
||||
#[command(name = "prove-hub", arg_required_else_help = true)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
ProveHub {
|
||||
/// The path to the model file
|
||||
#[arg(short = 'A', long)]
|
||||
artifact_id: String,
|
||||
/// The path to the input json file
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
#[arg(short = 'U', long)]
|
||||
url: Option<String>,
|
||||
#[arg(short = 'T', long)]
|
||||
transcript_type: Option<String>,
|
||||
},
|
||||
|
||||
/// Create artifacts and deploys them on the hub
|
||||
#[command(name = "get-hub-proof", arg_required_else_help = true)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
GetHubProof {
|
||||
/// The path to the model file
|
||||
#[arg(short = 'A', long)]
|
||||
artifact_id: String,
|
||||
#[arg(short = 'U', long)]
|
||||
url: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
255
src/execute.rs
255
src/execute.rs
@@ -65,6 +65,8 @@ use std::sync::OnceLock;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
static _SOLC_REQUIREMENT: OnceLock<bool> = OnceLock::new();
|
||||
@@ -335,6 +337,51 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
addr_da,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
|
||||
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetHubCredentials { username, url } => {
|
||||
get_hub_credentials(url.as_deref(), &username)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateHubArtifact {
|
||||
uncompiled_circuit,
|
||||
data,
|
||||
organization_id,
|
||||
artifact_name,
|
||||
url,
|
||||
args,
|
||||
target,
|
||||
} => deploy_model(
|
||||
url.as_deref(),
|
||||
&uncompiled_circuit,
|
||||
&data,
|
||||
&artifact_name,
|
||||
&organization_id,
|
||||
&args,
|
||||
&target,
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetHubProof { artifact_id, url } => get_hub_proof(url.as_deref(), &artifact_id)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::ProveHub {
|
||||
artifact_id,
|
||||
data,
|
||||
transcript_type,
|
||||
url,
|
||||
} => prove_hub(
|
||||
url.as_deref(),
|
||||
&artifact_id,
|
||||
&data,
|
||||
transcript_type.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1608,6 +1655,214 @@ pub(crate) fn verify_aggr(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieves the user's credentials from the hub
|
||||
pub(crate) async fn get_hub_credentials(
|
||||
url: Option<&str>,
|
||||
username: &str,
|
||||
) -> Result<crate::hub::Organizations, Box<dyn Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
let request_body = serde_json::json!({
|
||||
"query": r#"
|
||||
query GetOrganizationId($username: String!) {
|
||||
organizations(name: $username) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
"#,
|
||||
"variables": {
|
||||
"username": username,
|
||||
}
|
||||
});
|
||||
let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");
|
||||
|
||||
let response = client.post(url).json(&request_body).send().await?;
|
||||
let response_body = response.json::<serde_json::Value>().await?;
|
||||
|
||||
let organizations: crate::hub::Organizations =
|
||||
serde_json::from_value(response_body["data"].clone())?;
|
||||
|
||||
log::info!(
|
||||
"Organization ID : {}",
|
||||
organizations.as_json()?.to_colored_json_auto()?
|
||||
);
|
||||
Ok(organizations)
|
||||
}
|
||||
|
||||
/// Deploy a model
|
||||
pub(crate) async fn deploy_model(
|
||||
url: Option<&str>,
|
||||
model: &PathBuf,
|
||||
input: &PathBuf,
|
||||
name: &str,
|
||||
organization_id: &str,
|
||||
args: &RunArgs,
|
||||
target: &CalibrationTarget,
|
||||
) -> Result<crate::hub::Artifact, Box<dyn Error>> {
|
||||
let model_file = tokio::fs::File::open(model.canonicalize()?).await?;
|
||||
// read file body stream
|
||||
let stream = FramedRead::new(model_file, BytesCodec::new());
|
||||
let model_file_body = reqwest::Body::wrap_stream(stream);
|
||||
|
||||
let model_file = reqwest::multipart::Part::stream(model_file_body).file_name("uncompiledModel");
|
||||
|
||||
let input_file = tokio::fs::File::open(input.canonicalize()?).await?;
|
||||
// read file body stream
|
||||
let stream = FramedRead::new(input_file, BytesCodec::new());
|
||||
let input_file_body = reqwest::Body::wrap_stream(stream);
|
||||
|
||||
//make form part of file
|
||||
let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input");
|
||||
|
||||
// the graphql request map
|
||||
let map = r#"{
|
||||
"uncompiledModel": [
|
||||
"variables.uncompiledModel"
|
||||
],
|
||||
"input": [
|
||||
"variables.input"
|
||||
]
|
||||
}"#;
|
||||
|
||||
let operations = serde_json::json!({
|
||||
"query": "mutation($uncompiledModel: Upload!, $input: Upload!, $organizationId: String!, $name: String!, $calibrationTarget: String!, $tolerance: Float!, $inputVisibility: String!, $outputVisibility: String!, $paramVisibility: String!) {
|
||||
generateArtifact(
|
||||
name: $name,
|
||||
description: $name,
|
||||
uncompiledModel: $uncompiledModel,
|
||||
input: $input,
|
||||
organizationId: $organizationId,
|
||||
calibrationTarget: $calibrationTarget,
|
||||
tolerance: $tolerance,
|
||||
inputVisibility: $inputVisibility,
|
||||
outputVisibility: $outputVisibility,
|
||||
paramVisibility: $paramVisibility,
|
||||
) {
|
||||
artifact {
|
||||
id
|
||||
}
|
||||
}
|
||||
}",
|
||||
"variables": {
|
||||
"name": name,
|
||||
"uncompiledModel": null,
|
||||
"input": null,
|
||||
"organizationId": organization_id,
|
||||
"calibrationTarget": target.to_string(),
|
||||
"tolerance": args.tolerance.val,
|
||||
"inputVisibility": args.input_visibility.to_string(),
|
||||
"outputVisibility": args.output_visibility.to_string(),
|
||||
"paramVisibility": args.param_visibility.to_string(),
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
// now the form data
|
||||
let mut form = reqwest::multipart::Form::new();
|
||||
form = form
|
||||
.text("operations", operations)
|
||||
.text("map", map)
|
||||
.part("uncompiledModel", model_file)
|
||||
.part("input", input_file);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");
|
||||
//send request
|
||||
let response = client.post(url).multipart(form).send().await?;
|
||||
let response_body = response.json::<serde_json::Value>().await?;
|
||||
println!("{}", response_body.to_string());
|
||||
let artifact_id: crate::hub::Artifact =
|
||||
serde_json::from_value(response_body["data"]["generateArtifact"]["artifact"].clone())?;
|
||||
log::info!(
|
||||
"Artifact ID : {}",
|
||||
artifact_id.as_json()?.to_colored_json_auto()?
|
||||
);
|
||||
Ok(artifact_id)
|
||||
}
|
||||
|
||||
/// Generates proofs on the hub
|
||||
pub async fn prove_hub(
|
||||
url: Option<&str>,
|
||||
id: &str,
|
||||
input: &PathBuf,
|
||||
transcript_type: Option<&str>,
|
||||
) -> Result<crate::hub::Proof, Box<dyn std::error::Error>> {
|
||||
let input_file = tokio::fs::File::open(input.canonicalize()?).await?;
|
||||
let stream = FramedRead::new(input_file, BytesCodec::new());
|
||||
let input_file_body = reqwest::Body::wrap_stream(stream);
|
||||
|
||||
let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input");
|
||||
|
||||
let map = r#"{
|
||||
"input": [
|
||||
"variables.input"
|
||||
]
|
||||
}"#;
|
||||
|
||||
let operations = serde_json::json!({
|
||||
"query": r#"
|
||||
mutation($input: Upload!, $id: String!, $transcriptType: String) {
|
||||
initiateProof(input: $input, id: $id, transcriptType: $transcriptType) {
|
||||
id
|
||||
}
|
||||
}
|
||||
"#,
|
||||
"variables": {
|
||||
"input": null,
|
||||
"id": id,
|
||||
"transcriptType": transcript_type.unwrap_or("evm"),
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let mut form = reqwest::multipart::Form::new();
|
||||
form = form
|
||||
.text("operations", operations)
|
||||
.text("map", map)
|
||||
.part("input", input_file);
|
||||
let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");
|
||||
let client = reqwest::Client::new();
|
||||
let response = client.post(url).multipart(form).send().await?;
|
||||
let response_body = response.json::<serde_json::Value>().await?;
|
||||
let proof_id: crate::hub::Proof =
|
||||
serde_json::from_value(response_body["data"]["initiateProof"].clone())?;
|
||||
log::info!("Proof ID : {}", proof_id.as_json()?.to_colored_json_auto()?);
|
||||
Ok(proof_id)
|
||||
}
|
||||
|
||||
/// Fetches proofs from the hub
|
||||
pub(crate) async fn get_hub_proof(
|
||||
url: Option<&str>,
|
||||
id: &str,
|
||||
) -> Result<crate::hub::Proof, Box<dyn Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
let request_body = serde_json::json!({
|
||||
"query": format!(r#"
|
||||
query {{
|
||||
getProof(id: "{}") {{
|
||||
id
|
||||
artifact {{ id name }}
|
||||
status
|
||||
proof
|
||||
instances
|
||||
transcriptType
|
||||
strategy
|
||||
}}
|
||||
}}
|
||||
"#, id),
|
||||
});
|
||||
let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");
|
||||
|
||||
let response = client.post(url).json(&request_body).send().await?;
|
||||
let response_body = response.json::<serde_json::Value>().await?;
|
||||
|
||||
let proof: crate::hub::Proof =
|
||||
serde_json::from_value(response_body["data"]["getProof"].clone())?;
|
||||
|
||||
log::info!("Proof : {}", proof.as_json()?.to_colored_json_auto()?);
|
||||
Ok(proof)
|
||||
}
|
||||
|
||||
/// helper function for load_params
|
||||
pub(crate) fn load_params_cmd(
|
||||
srs_path: PathBuf,
|
||||
|
||||
143
src/hub.rs
Normal file
143
src/hub.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Stores users organizations
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Organization {
|
||||
/// The organization id
|
||||
pub id: String,
|
||||
/// The users username
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl Organization {
|
||||
/// Export the organization as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores Organization
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Organizations {
|
||||
/// An Array of Organizations
|
||||
pub organizations: Vec<Organization>,
|
||||
}
|
||||
|
||||
impl Organizations {
|
||||
/// Export the organizations as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the Proof Response
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Proof {
|
||||
/// stores the artifact
|
||||
pub artifact: Option<Artifact>,
|
||||
/// stores the Proof Id
|
||||
pub id: String,
|
||||
/// stores the instances
|
||||
pub instances: Option<Vec<String>>,
|
||||
/// stores the proofs
|
||||
pub proof: Option<String>,
|
||||
/// stores the status
|
||||
pub status: Option<String>,
|
||||
///stores the strategy
|
||||
pub strategy: Option<String>,
|
||||
/// stores the transcript type
|
||||
#[serde(rename = "transcriptType")]
|
||||
pub transcript_type: Option<String>,
|
||||
}
|
||||
|
||||
impl Proof {
|
||||
/// Export the proof as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the Artifacts
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Artifact {
|
||||
///stores the aritfact id
|
||||
pub id: Option<String>,
|
||||
/// stores the name of the artifact
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl Artifact {
|
||||
/// Export the artifact as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl pyo3::ToPyObject for Artifact {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let dict = pyo3::types::PyDict::new(py);
|
||||
dict.set_item("id", &self.id).unwrap();
|
||||
dict.set_item("name", &self.name).unwrap();
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl pyo3::ToPyObject for Proof {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let dict = pyo3::types::PyDict::new(py);
|
||||
dict.set_item("artifact", &self.artifact).unwrap();
|
||||
dict.set_item("id", &self.id).unwrap();
|
||||
dict.set_item("instances", &self.instances).unwrap();
|
||||
dict.set_item("proof", &self.proof).unwrap();
|
||||
dict.set_item("status", &self.status).unwrap();
|
||||
dict.set_item("strategy", &self.strategy).unwrap();
|
||||
dict.set_item("transcript_type", &self.transcript_type)
|
||||
.unwrap();
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl pyo3::ToPyObject for Organizations {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let dict = pyo3::types::PyDict::new(py);
|
||||
dict.set_item("organizations", &self.organizations).unwrap();
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl pyo3::ToPyObject for Organization {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let dict = pyo3::types::PyDict::new(py);
|
||||
dict.set_item("id", &self.id).unwrap();
|
||||
dict.set_item("name", &self.name).unwrap();
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
@@ -51,6 +51,9 @@ pub mod fieldutils;
|
||||
/// a Halo2 circuit.
|
||||
#[cfg(feature = "onnx")]
|
||||
pub mod graph;
|
||||
/// Methods for deploying and interacting with the ezkl hub
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod hub;
|
||||
/// beautiful logging
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
pub mod logger;
|
||||
|
||||
@@ -1077,16 +1077,86 @@ fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
|
||||
|
||||
// let mut return_string: String = "";
|
||||
// for instance in proof.instances {
|
||||
// return_string.push_str(instance + "\n");
|
||||
// }
|
||||
// return_string = hex::encode(proof.proof);
|
||||
|
||||
// return proof for now
|
||||
Ok(hex::encode(proof.proof))
|
||||
}
|
||||
|
||||
/// deploys a model to the hub
|
||||
#[pyfunction(signature = (model, input, name, organization_id, target=None, py_run_args=None, url=None))]
|
||||
fn create_hub_artifact(
|
||||
model: PathBuf,
|
||||
input: PathBuf,
|
||||
name: String,
|
||||
organization_id: String,
|
||||
target: Option<CalibrationTarget>,
|
||||
py_run_args: Option<PyRunArgs>,
|
||||
url: Option<&str>,
|
||||
) -> PyResult<PyObject> {
|
||||
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
|
||||
let target = target.unwrap_or(CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
});
|
||||
let output = Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_model(
|
||||
url,
|
||||
&model,
|
||||
&input,
|
||||
&name,
|
||||
&organization_id,
|
||||
&run_args,
|
||||
&target,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to deploy model to hub: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
}
|
||||
|
||||
/// Generate a proof on the hub.
|
||||
#[pyfunction(signature = (id, input, url=None, transcript_type=None))]
|
||||
fn prove_hub(
|
||||
id: &str,
|
||||
input: PathBuf,
|
||||
url: Option<&str>,
|
||||
transcript_type: Option<&str>,
|
||||
) -> PyResult<PyObject> {
|
||||
let output = Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::prove_hub(url, id, &input, transcript_type))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to generate proof on hub: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
}
|
||||
|
||||
/// Fetches proof from hub
|
||||
#[pyfunction(signature = (id, url=None))]
|
||||
fn get_hub_proof(id: &str, url: Option<&str>) -> PyResult<PyObject> {
|
||||
let output = Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::get_hub_proof(url, id))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get proof from hub: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
}
|
||||
|
||||
/// Gets hub credentials
|
||||
#[pyfunction(signature = (username, url=None))]
|
||||
fn get_hub_credentials(username: &str, url: Option<&str>) -> PyResult<PyObject> {
|
||||
let output = Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::get_hub_credentials(url, username))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get hub credentials: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
}
|
||||
|
||||
// Python Module
|
||||
#[pymodule]
|
||||
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
@@ -1130,6 +1200,10 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_hub_artifact, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(prove_hub, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(get_hub_proof, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(get_hub_credentials, m)?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
|
||||
let total_dims: usize = if !dims.is_empty() {
|
||||
dims.iter().product()
|
||||
} else if let Some(_) = values {
|
||||
} else if values.is_some() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
|
||||
@@ -115,7 +115,7 @@ mod py_tests {
|
||||
}
|
||||
}
|
||||
|
||||
const TESTS: [&str; 25] = [
|
||||
const TESTS: [&str; 26] = [
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
@@ -142,6 +142,7 @@ mod py_tests {
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"simple_hub_demo.ipynb",
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
@@ -154,7 +155,7 @@ mod py_tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
seq!(N in 0..=24 {
|
||||
seq!(N in 0..=25 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn run_notebook_(test: &str) {
|
||||
|
||||
Reference in New Issue
Block a user