feat: add rust hub bindings (#551)

---------

Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
This commit is contained in:
samtvlabs
2023-10-17 03:54:01 +04:00
committed by GitHub
parent 5073eb906f
commit efe65f06a3
12 changed files with 858 additions and 23 deletions

View File

@@ -551,6 +551,8 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: Simple hub demo
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_25_expects
- name: Hashed DA tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_24_expects
- name: Little transformer tutorial

58
Cargo.lock generated
View File

@@ -1826,6 +1826,7 @@ dependencies = [
"test-case",
"thiserror",
"tokio",
"tokio-util",
"tract-onnx",
"unzip-n",
"wasm-bindgen",
@@ -3611,9 +3612,9 @@ dependencies = [
[[package]]
name = "pin-project-lite"
version = "0.2.9"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "pin-utils"
@@ -4175,9 +4176,9 @@ dependencies = [
[[package]]
name = "reqwest"
version = "0.11.18"
version = "0.11.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b"
dependencies = [
"base64 0.21.2",
"bytes",
@@ -4193,6 +4194,7 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
@@ -4200,12 +4202,15 @@ dependencies = [
"serde",
"serde_json",
"serde_urlencoded",
"system-configuration",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"winreg",
]
@@ -4971,6 +4976,27 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tabbycat"
version = "0.1.2"
@@ -5262,9 +5288,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.8"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d"
checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d"
dependencies = [
"bytes",
"futures-core",
@@ -5804,6 +5830,19 @@ dependencies = [
"quote",
]
[[package]]
name = "wasm-streams"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.64"
@@ -6018,11 +6057,12 @@ dependencies = [
[[package]]
name = "winreg"
version = "0.10.1"
version = "0.50.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
dependencies = [
"winapi",
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]

View File

@@ -39,7 +39,7 @@ ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls"] }
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
pg_bigdecimal = "0.1.5"
@@ -48,6 +48,7 @@ colored_json = { version = "3.0.1", default_features = false, optional = true}
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }

View File

@@ -0,0 +1,236 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## EZKL HUB Jupyter Notebook Demo \n",
"\n",
"Here we demonstrate the use of the EZKL hub in a Jupyter notebook whereby all components of the circuit are public or pre-committed to. This is the simplest case of using EZKL (proof of computation).\n",
"\n",
"This will be accomplished in 3 steps. \n",
"\n",
"1. Train the model. \n",
"2. Define our visibility settings. \n",
"3. Upload the model to the hub. \n",
"\n",
"That's it!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"\n",
"\n",
"# Defines the model\n",
"# we got convs, we got relu, we got linear layers\n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
" self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
"\n",
" self.relu = nn.ReLU()\n",
"\n",
" self.d1 = nn.Linear(48, 48)\n",
" self.d2 = nn.Linear(48, 10)\n",
"\n",
" def forward(self, x):\n",
" # 32x1x28x28 => 32x32x26x26\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.relu(x)\n",
"\n",
" # flatten => 32 x (32*26*26)\n",
" x = x.flatten(start_dim = 1)\n",
"\n",
" # 32 x (32*26*26) => 32x128\n",
" x = self.d1(x)\n",
" x = self.relu(x)\n",
"\n",
" # logits => 32x10\n",
" logits = self.d2(x)\n",
"\n",
" return logits\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"test_hub_name = \"samtvlabs\" #we've set this up for you, but you can create your own hub name and use that instead\n",
"\n",
"py_run_args = ezkl.PyRunArgs()\n",
"py_run_args.input_visibility = \"public\"\n",
"py_run_args.output_visibility = \"public\"\n",
"py_run_args.param_visibility = \"private\" # private by default\n",
"\n",
"\n",
"organization = ezkl.get_hub_credentials(test_hub_name)['organizations'][0]\n",
"\n",
"print(\"organization: \" + str(organization))\n",
"\n",
"deployed_model = ezkl.create_hub_artifact(model_path, data_path, \"my first model\", organization['id'], target=\"resources\", py_run_args=py_run_args)\n",
"\n",
"print(\"deployed model: \" + str(deployed_model))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81201b32",
"metadata": {},
"outputs": [],
"source": [
"# sleep for a bit to make sure the model is deployed\n",
"time.sleep(20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcc44717",
"metadata": {},
"outputs": [],
"source": [
"\n",
"proof_id = ezkl.prove_hub(deployed_model['id'], data_path)\n",
"print(\"proof id: \" + str(proof_id))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5aa6a580",
"metadata": {},
"outputs": [],
"source": [
"# give prove_hub some time to finish\n",
"time.sleep(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9e2f32f",
"metadata": {},
"outputs": [],
"source": [
"proof = ezkl.get_hub_proof(proof_id['id'])\n",
"\n",
"print(\"proof: \" + str(proof))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -326,7 +326,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
{
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(multi_col_selector.clone());
let sel = cs.query_selector(multi_col_selector);
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
@@ -377,12 +377,12 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
(
col_expr.clone() * input_query.clone()
+ not_expr.clone() * Expression::Constant(default_x),
input_col.clone(),
*input_col,
),
(
col_expr.clone() * output_query.clone()
+ not_expr.clone() * Expression::Constant(default_y),
output_col.clone(),
*output_col,
),
]);

View File

@@ -71,6 +71,20 @@ impl Default for CalibrationTarget {
}
}
/// Human-readable form used for CLI arguments and the hub's
/// `calibrationTarget` field.
///
/// Implemented as `Display` rather than a direct `ToString` impl
/// (clippy: `to_string_trait_impl`): the blanket
/// `impl<T: Display> ToString for T` still provides `to_string()` with
/// identical output, and the type additionally becomes usable with
/// `format!`/`{}`.
impl std::fmt::Display for CalibrationTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // These string forms must stay in sync with `From<&str>` below.
        let s = match self {
            CalibrationTarget::Resources { col_overflow: true } => "resources/col-overflow",
            CalibrationTarget::Resources {
                col_overflow: false,
            } => "resources",
            CalibrationTarget::Accuracy => "accuracy",
        };
        write!(f, "{}", s)
    }
}
impl From<&str> for CalibrationTarget {
fn from(s: &str) -> Self {
match s {
@@ -625,4 +639,70 @@ pub enum Commands {
#[arg(long)]
proof_path: PathBuf,
},
/// Gets credentials from the hub
#[command(name = "get-hub-credentials", arg_required_else_help = true)]
#[cfg(not(target_arch = "wasm32"))]
GetHubCredentials {
/// The username to fetch hub credentials for
#[arg(short = 'N', long)]
username: String,
/// the hub's url
#[arg(short = 'U', long)]
url: Option<String>,
},
/// Create artifacts and deploys them on the hub
#[command(name = "create-hub-artifact", arg_required_else_help = true)]
#[cfg(not(target_arch = "wasm32"))]
CreateHubArtifact {
/// The path to the model file
#[arg(short = 'M', long)]
uncompiled_circuit: PathBuf,
/// The path to the input json file
#[arg(short = 'D', long)]
data: PathBuf,
/// The id of the organization to create the artifact under
#[arg(short = 'O', long)]
organization_id: String,
///artifact name
#[arg(short = 'A', long)]
artifact_name: String,
/// the hub's url
#[arg(short = 'U', long)]
url: Option<String>,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
/// calibration target
#[arg(long, default_value = "resources")]
target: CalibrationTarget,
},
/// Generates a proof on the hub
#[command(name = "prove-hub", arg_required_else_help = true)]
#[cfg(not(target_arch = "wasm32"))]
ProveHub {
/// The id of the artifact to generate a proof for
#[arg(short = 'A', long)]
artifact_id: String,
/// The path to the input json file
#[arg(short = 'D', long)]
data: PathBuf,
#[arg(short = 'U', long)]
url: Option<String>,
#[arg(short = 'T', long)]
transcript_type: Option<String>,
},
/// Fetches a proof from the hub
#[command(name = "get-hub-proof", arg_required_else_help = true)]
#[cfg(not(target_arch = "wasm32"))]
GetHubProof {
/// The id of the proof to fetch from the hub
#[arg(short = 'A', long)]
artifact_id: String,
#[arg(short = 'U', long)]
url: Option<String>,
},
}

View File

@@ -65,6 +65,8 @@ use std::sync::OnceLock;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
use thiserror::Error;
#[cfg(not(target_arch = "wasm32"))]
use tokio_util::codec::{BytesCodec, FramedRead};
#[cfg(not(target_arch = "wasm32"))]
static _SOLC_REQUIREMENT: OnceLock<bool> = OnceLock::new();
@@ -335,6 +337,51 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
addr_da,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetHubCredentials { username, url } => {
get_hub_credentials(url.as_deref(), &username)
.await
.map(|_| ())
}
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateHubArtifact {
uncompiled_circuit,
data,
organization_id,
artifact_name,
url,
args,
target,
} => deploy_model(
url.as_deref(),
&uncompiled_circuit,
&data,
&artifact_name,
&organization_id,
&args,
&target,
)
.await
.map(|_| ()),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetHubProof { artifact_id, url } => get_hub_proof(url.as_deref(), &artifact_id)
.await
.map(|_| ()),
#[cfg(not(target_arch = "wasm32"))]
Commands::ProveHub {
artifact_id,
data,
transcript_type,
url,
} => prove_hub(
url.as_deref(),
&artifact_id,
&data,
transcript_type.as_deref(),
)
.await
.map(|_| ()),
}
}
@@ -1608,6 +1655,214 @@ pub(crate) fn verify_aggr(
Ok(())
}
/// Retrieves the user's credentials from the hub
///
/// Resolves `username` to the organizations it belongs to via the hub's
/// GraphQL endpoint. When `url` is `None` the staging hub is used.
pub(crate) async fn get_hub_credentials(
    url: Option<&str>,
    username: &str,
) -> Result<crate::hub::Organizations, Box<dyn Error>> {
    // Default to the staging hub when no endpoint is supplied.
    let endpoint = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");

    // GraphQL query resolving a username to organization ids/names.
    let payload = serde_json::json!({
        "query": r#"
            query GetOrganizationId($username: String!) {
                organizations(name: $username) {
                    id
                    name
                }
            }
        "#,
        "variables": {
            "username": username,
        }
    });

    let http = reqwest::Client::new();
    let body: serde_json::Value = http.post(endpoint).json(&payload).send().await?.json().await?;

    // The hub nests the result under `data`; deserialize just that subtree.
    let organizations: crate::hub::Organizations = serde_json::from_value(body["data"].clone())?;

    log::info!(
        "Organization ID : {}",
        organizations.as_json()?.to_colored_json_auto()?
    );

    Ok(organizations)
}
/// Deploy a model
///
/// Streams the uncompiled model and its sample input to the hub as a
/// multipart GraphQL `generateArtifact` mutation (per the
/// graphql-multipart-request-spec) and returns the created artifact's id.
/// When `url` is `None` the staging hub is used.
pub(crate) async fn deploy_model(
    url: Option<&str>,
    model: &PathBuf,
    input: &PathBuf,
    name: &str,
    organization_id: &str,
    args: &RunArgs,
    target: &CalibrationTarget,
) -> Result<crate::hub::Artifact, Box<dyn Error>> {
    // Stream the model file instead of buffering it in memory.
    let model_file = tokio::fs::File::open(model.canonicalize()?).await?;
    let stream = FramedRead::new(model_file, BytesCodec::new());
    let model_file_body = reqwest::Body::wrap_stream(stream);
    let model_file = reqwest::multipart::Part::stream(model_file_body).file_name("uncompiledModel");

    // Same streaming treatment for the input json.
    let input_file = tokio::fs::File::open(input.canonicalize()?).await?;
    let stream = FramedRead::new(input_file, BytesCodec::new());
    let input_file_body = reqwest::Body::wrap_stream(stream);
    let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input");

    // graphql-multipart-request-spec `map`: ties form parts to mutation variables.
    let map = r#"{
        "uncompiledModel": [
            "variables.uncompiledModel"
        ],
        "input": [
            "variables.input"
        ]
    }"#;

    // File-upload variables are `null` here; the server substitutes them
    // from the form parts according to `map`.
    let operations = serde_json::json!({
        "query": "mutation($uncompiledModel: Upload!, $input: Upload!, $organizationId: String!, $name: String!, $calibrationTarget: String!, $tolerance: Float!, $inputVisibility: String!, $outputVisibility: String!, $paramVisibility: String!) {
            generateArtifact(
                name: $name,
                description: $name,
                uncompiledModel: $uncompiledModel,
                input: $input,
                organizationId: $organizationId,
                calibrationTarget: $calibrationTarget,
                tolerance: $tolerance,
                inputVisibility: $inputVisibility,
                outputVisibility: $outputVisibility,
                paramVisibility: $paramVisibility,
            ) {
                artifact {
                    id
                }
            }
        }",
        "variables": {
            "name": name,
            "uncompiledModel": null,
            "input": null,
            "organizationId": organization_id,
            "calibrationTarget": target.to_string(),
            "tolerance": args.tolerance.val,
            "inputVisibility": args.input_visibility.to_string(),
            "outputVisibility": args.output_visibility.to_string(),
            "paramVisibility": args.param_visibility.to_string(),
        }
    })
    .to_string();

    // Assemble the multipart form: operations + map + the two file parts.
    let form = reqwest::multipart::Form::new()
        .text("operations", operations)
        .text("map", map)
        .part("uncompiledModel", model_file)
        .part("input", input_file);

    let client = reqwest::Client::new();
    let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");

    let response = client.post(url).multipart(form).send().await?;
    let response_body = response.json::<serde_json::Value>().await?;
    // Log the raw response at debug level instead of printing it to stdout
    // (the previous `println!` was leftover debug output).
    log::debug!("hub response: {}", response_body);

    let artifact_id: crate::hub::Artifact =
        serde_json::from_value(response_body["data"]["generateArtifact"]["artifact"].clone())?;
    log::info!(
        "Artifact ID : {}",
        artifact_id.as_json()?.to_colored_json_auto()?
    );
    Ok(artifact_id)
}
/// Generates proofs on the hub
///
/// Streams the input file to the hub's `initiateProof` mutation for the
/// given artifact id and returns the newly created proof's id. When `url`
/// is `None` the staging hub is used.
pub async fn prove_hub(
    url: Option<&str>,
    id: &str,
    input: &PathBuf,
    transcript_type: Option<&str>,
) -> Result<crate::hub::Proof, Box<dyn std::error::Error>> {
    let endpoint = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");

    // Stream the input file as a multipart upload part.
    let file = tokio::fs::File::open(input.canonicalize()?).await?;
    let body = reqwest::Body::wrap_stream(FramedRead::new(file, BytesCodec::new()));
    let input_part = reqwest::multipart::Part::stream(body).file_name("input");

    // graphql-multipart-request-spec `map`: binds the form part to `$input`.
    let map = r#"{
        "input": [
            "variables.input"
        ]
    }"#;

    let operations = serde_json::json!({
        "query": r#"
            mutation($input: Upload!, $id: String!, $transcriptType: String) {
                initiateProof(input: $input, id: $id, transcriptType: $transcriptType) {
                    id
                }
            }
        "#,
        "variables": {
            "input": null,
            "id": id,
            // An EVM transcript is requested when none is specified.
            "transcriptType": transcript_type.unwrap_or("evm"),
        }
    })
    .to_string();

    let form = reqwest::multipart::Form::new()
        .text("operations", operations)
        .text("map", map)
        .part("input", input_part);

    let response_body: serde_json::Value = reqwest::Client::new()
        .post(endpoint)
        .multipart(form)
        .send()
        .await?
        .json()
        .await?;

    let proof_id: crate::hub::Proof =
        serde_json::from_value(response_body["data"]["initiateProof"].clone())?;
    log::info!("Proof ID : {}", proof_id.as_json()?.to_colored_json_auto()?);
    Ok(proof_id)
}
/// Fetches proofs from the hub
///
/// Looks up a proof by `id` via the hub's `getProof` query. When `url` is
/// `None` the staging hub is used.
pub(crate) async fn get_hub_proof(
    url: Option<&str>,
    id: &str,
) -> Result<crate::hub::Proof, Box<dyn Error>> {
    let client = reqwest::Client::new();
    // Pass `id` as a GraphQL variable instead of `format!`-interpolating it
    // into the query text: a quote or brace in `id` would otherwise corrupt
    // the query (and invite injection). This matches how every other query
    // in this module passes arguments.
    let request_body = serde_json::json!({
        "query": r#"
            query GetProof($id: String!) {
                getProof(id: $id) {
                    id
                    artifact { id name }
                    status
                    proof
                    instances
                    transcriptType
                    strategy
                }
            }
        "#,
        "variables": {
            "id": id,
        }
    });
    let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql");
    let response = client.post(url).json(&request_body).send().await?;
    let response_body = response.json::<serde_json::Value>().await?;
    let proof: crate::hub::Proof =
        serde_json::from_value(response_body["data"]["getProof"].clone())?;
    log::info!("Proof : {}", proof.as_json()?.to_colored_json_auto()?);
    Ok(proof)
}
/// helper function for load_params
pub(crate) fn load_params_cmd(
srs_path: PathBuf,

143
src/hub.rs Normal file
View File

@@ -0,0 +1,143 @@
use serde::{Deserialize, Serialize};
/// Stores users organizations
#[derive(Serialize, Deserialize, Debug)]
pub struct Organization {
    /// The organization id
    pub id: String,
    /// The organization name (the hub matches this against the username
    /// when resolving credentials)
    pub name: String,
}

impl Organization {
    /// Export the organization as json
    ///
    /// # Errors
    /// Returns any `serde_json` serialization error, boxed.
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
        // `?` converts `serde_json::Error` into `Box<dyn Error>` via the
        // blanket `From` impl, replacing the verbose match-and-box.
        Ok(serde_json::to_string(self)?)
    }
}
/// Stores Organization
#[derive(Serialize, Deserialize, Debug)]
pub struct Organizations {
    /// An Array of Organizations
    pub organizations: Vec<Organization>,
}

impl Organizations {
    /// Export the organizations as json
    ///
    /// # Errors
    /// Returns any `serde_json` serialization error, boxed.
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
        // `?` converts `serde_json::Error` into `Box<dyn Error>` via the
        // blanket `From` impl, replacing the verbose match-and-box.
        Ok(serde_json::to_string(self)?)
    }
}
/// Stores the Proof Response
///
/// Fields other than `id` are optional because the hub's GraphQL responses
/// only populate the fields each query selects.
#[derive(Debug, Deserialize, Serialize)]
pub struct Proof {
    /// stores the artifact
    pub artifact: Option<Artifact>,
    /// stores the Proof Id
    pub id: String,
    /// stores the instances
    pub instances: Option<Vec<String>>,
    /// stores the proofs
    pub proof: Option<String>,
    /// stores the status
    pub status: Option<String>,
    /// stores the strategy
    pub strategy: Option<String>,
    /// stores the transcript type
    #[serde(rename = "transcriptType")]
    pub transcript_type: Option<String>,
}

impl Proof {
    /// Export the proof as json
    ///
    /// # Errors
    /// Returns any `serde_json` serialization error, boxed.
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
        // `?` converts `serde_json::Error` into `Box<dyn Error>` via the
        // blanket `From` impl, replacing the verbose match-and-box.
        Ok(serde_json::to_string(self)?)
    }
}
/// Stores the Artifacts
#[derive(Debug, Deserialize, Serialize)]
pub struct Artifact {
    /// stores the artifact id
    pub id: Option<String>,
    /// stores the name of the artifact
    pub name: Option<String>,
}

impl Artifact {
    /// Export the artifact as json
    ///
    /// # Errors
    /// Returns any `serde_json` serialization error, boxed.
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
        // `?` converts `serde_json::Error` into `Box<dyn Error>` via the
        // blanket `From` impl, replacing the verbose match-and-box.
        Ok(serde_json::to_string(self)?)
    }
}
#[cfg(feature = "python-bindings")]
/// Converts an [`Artifact`] into a Python dict with `id` and `name` keys.
impl pyo3::ToPyObject for Artifact {
    fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
        let dict = pyo3::types::PyDict::new(py);
        // set_item is expected not to fail for these Option<String> fields;
        // a failure would indicate an interpreter-level error.
        dict.set_item("id", &self.id).unwrap();
        dict.set_item("name", &self.name).unwrap();
        dict.into()
    }
}
#[cfg(feature = "python-bindings")]
/// Converts a [`Proof`] into a Python dict mirroring its fields.
/// Note the key is the Rust-side `transcript_type`, not the GraphQL
/// `transcriptType`.
impl pyo3::ToPyObject for Proof {
    fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
        let dict = pyo3::types::PyDict::new(py);
        // set_item is expected not to fail for these field types; a failure
        // would indicate an interpreter-level error.
        dict.set_item("artifact", &self.artifact).unwrap();
        dict.set_item("id", &self.id).unwrap();
        dict.set_item("instances", &self.instances).unwrap();
        dict.set_item("proof", &self.proof).unwrap();
        dict.set_item("status", &self.status).unwrap();
        dict.set_item("strategy", &self.strategy).unwrap();
        dict.set_item("transcript_type", &self.transcript_type)
            .unwrap();
        dict.into()
    }
}
#[cfg(feature = "python-bindings")]
/// Converts [`Organizations`] into a Python dict with a single
/// `organizations` key holding the list of organizations.
impl pyo3::ToPyObject for Organizations {
    fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
        let dict = pyo3::types::PyDict::new(py);
        // set_item is expected not to fail here; a failure would indicate an
        // interpreter-level error.
        dict.set_item("organizations", &self.organizations).unwrap();
        dict.into()
    }
}
#[cfg(feature = "python-bindings")]
/// Converts an [`Organization`] into a Python dict with `id` and `name` keys.
impl pyo3::ToPyObject for Organization {
    fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
        let dict = pyo3::types::PyDict::new(py);
        // set_item is expected not to fail for these String fields; a failure
        // would indicate an interpreter-level error.
        dict.set_item("id", &self.id).unwrap();
        dict.set_item("name", &self.name).unwrap();
        dict.into()
    }
}

View File

@@ -51,6 +51,9 @@ pub mod fieldutils;
/// a Halo2 circuit.
#[cfg(feature = "onnx")]
pub mod graph;
/// Methods for deploying and interacting with the ezkl hub
#[cfg(not(target_arch = "wasm32"))]
pub mod hub;
/// beautiful logging
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub mod logger;

View File

@@ -1077,16 +1077,86 @@ fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
// let mut return_string: String = "";
// for instance in proof.instances {
// return_string.push_str(instance + "\n");
// }
// return_string = hex::encode(proof.proof);
// return proof for now
Ok(hex::encode(proof.proof))
}
/// deploys a model to the hub
///
/// Python binding over [`crate::execute::deploy_model`]: uploads the model
/// and input files and returns the created artifact as a Python dict.
/// Raises `RuntimeError` on failure.
#[pyfunction(signature = (model, input, name, organization_id, target=None, py_run_args=None, url=None))]
fn create_hub_artifact(
    model: PathBuf,
    input: PathBuf,
    name: String,
    organization_id: String,
    target: Option<CalibrationTarget>,
    py_run_args: Option<PyRunArgs>,
    url: Option<&str>,
) -> PyResult<PyObject> {
    let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
    // Default calibration target mirrors the CLI's `--target resources`.
    let target = target.unwrap_or(CalibrationTarget::Resources {
        col_overflow: false,
    });
    // Surface runtime-construction failures as Python exceptions instead of
    // panicking (`unwrap`) inside an extension module.
    let runtime = Runtime::new()
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e)))?;
    let output = runtime
        .block_on(crate::execute::deploy_model(
            url,
            &model,
            &input,
            &name,
            &organization_id,
            &run_args,
            &target,
        ))
        .map_err(|e| {
            let err_str = format!("Failed to deploy model to hub: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;
    Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Generate a proof on the hub.
///
/// Python binding over [`crate::execute::prove_hub`]: initiates proving for
/// artifact `id` with the given input file and returns the proof handle as
/// a Python dict. Raises `RuntimeError` on failure.
#[pyfunction(signature = (id, input, url=None, transcript_type=None))]
fn prove_hub(
    id: &str,
    input: PathBuf,
    url: Option<&str>,
    transcript_type: Option<&str>,
) -> PyResult<PyObject> {
    // Surface runtime-construction failures as Python exceptions instead of
    // panicking (`unwrap`) inside an extension module.
    let runtime = Runtime::new()
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e)))?;
    let output = runtime
        .block_on(crate::execute::prove_hub(url, id, &input, transcript_type))
        .map_err(|e| {
            let err_str = format!("Failed to generate proof on hub: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;
    Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Fetches proof from hub
///
/// Python binding over [`crate::execute::get_hub_proof`]: retrieves the
/// proof with the given id as a Python dict. Raises `RuntimeError` on
/// failure.
#[pyfunction(signature = (id, url=None))]
fn get_hub_proof(id: &str, url: Option<&str>) -> PyResult<PyObject> {
    // Surface runtime-construction failures as Python exceptions instead of
    // panicking (`unwrap`) inside an extension module.
    let runtime = Runtime::new()
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e)))?;
    let output = runtime
        .block_on(crate::execute::get_hub_proof(url, id))
        .map_err(|e| {
            let err_str = format!("Failed to get proof from hub: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;
    Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Gets hub credentials
///
/// Python binding over [`crate::execute::get_hub_credentials`]: resolves a
/// username to its organizations, returned as a Python dict. Raises
/// `RuntimeError` on failure.
#[pyfunction(signature = (username, url=None))]
fn get_hub_credentials(username: &str, url: Option<&str>) -> PyResult<PyObject> {
    // Surface runtime-construction failures as Python exceptions instead of
    // panicking (`unwrap`) inside an extension module.
    let runtime = Runtime::new()
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e)))?;
    let output = runtime
        .block_on(crate::execute::get_hub_credentials(url, username))
        .map_err(|e| {
            let err_str = format!("Failed to get hub credentials: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;
    Python::with_gil(|py| Ok(output.to_object(py)))
}
// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -1130,6 +1200,10 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
m.add_function(wrap_pyfunction!(create_hub_artifact, m)?)?;
m.add_function(wrap_pyfunction!(prove_hub, m)?)?;
m.add_function(wrap_pyfunction!(get_hub_proof, m)?)?;
m.add_function(wrap_pyfunction!(get_hub_credentials, m)?)?;
Ok(())
}

View File

@@ -450,7 +450,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
let total_dims: usize = if !dims.is_empty() {
dims.iter().product()
} else if let Some(_) = values {
} else if values.is_some() {
1
} else {
0

View File

@@ -115,7 +115,7 @@ mod py_tests {
}
}
const TESTS: [&str; 25] = [
const TESTS: [&str; 26] = [
"mnist_gan.ipynb",
// "mnist_vae.ipynb",
"keras_simple_demo.ipynb",
@@ -142,6 +142,7 @@ mod py_tests {
"linear_regression.ipynb",
"stacked_regression.ipynb",
"data_attest_hashed.ipynb",
"simple_hub_demo.ipynb",
];
macro_rules! test_func {
@@ -154,7 +155,7 @@ mod py_tests {
use super::*;
seq!(N in 0..=24 {
seq!(N in 0..=25 {
#(#[test_case(TESTS[N])])*
fn run_notebook_(test: &str) {