diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4096f623..66fcc2e9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -551,6 +551,8 @@ jobs: # # now dump the contents of the file into a file called kaggle.json # echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json # chmod 600 /home/ubuntu/.kaggle/kaggle.json + - name: Simple hub demo + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_25_expects - name: Hashed DA tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_24_expects - name: Little transformer tutorial diff --git a/Cargo.lock b/Cargo.lock index 5a3d8ded..f5955596 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1826,6 +1826,7 @@ dependencies = [ "test-case", "thiserror", "tokio", + "tokio-util", "tract-onnx", "unzip-n", "wasm-bindgen", @@ -3611,9 +3612,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -4175,9 +4176,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ "base64 0.21.2", "bytes", @@ -4193,6 +4194,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -4200,12 +4202,15 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-native-tls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -4971,6 +4976,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tabbycat" version = "0.1.2" @@ -5262,9 +5288,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", @@ -5804,6 +5830,19 @@ dependencies = [ "quote", ] +[[package]] +name = "wasm-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.64" @@ -6018,11 +6057,12 @@ dependencies = [ [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index e9a36bc3..f6a0f8ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc indicatif = {version = "0.17.5", features = ["rayon"]} gag = { version = "1.0.0", default_features = false} instant = { version = "0.1" } -reqwest = { version = "0.11.14", default-features = false, features = ["default-tls"] } +reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] } openssl = { version = "0.10.55", features = ["vendored"] } postgres = "0.19.5" pg_bigdecimal = "0.1.5" @@ -48,6 +48,7 @@ colored_json = { version = "3.0.1", default_features = false, optional = true} plotters = { version = "0.3.0", default_features = false, optional = true } regex = { version = "1", default_features = false } tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] } +tokio-util = { version = "0.7.9", features = ["codec"] } pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true } pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true } pyo3-log = { version = "0.8.1", default_features = false, optional = true } diff --git a/examples/notebooks/simple_hub_demo.ipynb b/examples/notebooks/simple_hub_demo.ipynb new file mode 100644 index 00000000..9b634a51 --- /dev/null +++ b/examples/notebooks/simple_hub_demo.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67", + "metadata": {}, + "source": [ + "## EZKL HUB Jupyter Notebook Demo \n", + "\n", + "Here we demonstrate the use of the EZKL hub in a Jupyter notebook whereby all components of the circuit are public or pre-committed to. This is the simplest case of using EZKL (proof of computation).\n", + "\n", + "This will be accomplished in 3 steps. \n", + "\n", + "1. Train the model. \n", + "2. Define our visibility settings. \n", + "3. Upload the model to the hub. \n", + "\n", + "That's it !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95613ee9", + "metadata": {}, + "outputs": [], + "source": [ + "# check if notebook is in colab\n", + "try:\n", + " # install ezkl\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", + "\n", + "# rely on local installation of ezkl if the notebook is not in colab\n", + "except:\n", + " pass\n", + "\n", + "\n", + "# here we create and (potentially train a model)\n", + "\n", + "# make sure you have the dependencies required here already installed\n", + "from torch import nn\n", + "import ezkl\n", + "import os\n", + "import json\n", + "import torch\n", + "\n", + "\n", + "# Defines the model\n", + "# we got convs, we got relu, we got linear layers\n", + "# What else could one want ????\n", + "\n", + "class MyModel(nn.Module):\n", + " def __init__(self):\n", + " super(MyModel, self).__init__()\n", + "\n", + " self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n", + " self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n", + "\n", + " self.relu = nn.ReLU()\n", + "\n", + " self.d1 = nn.Linear(48, 48)\n", + " self.d2 = nn.Linear(48, 10)\n", + "\n", + " def forward(self, x):\n", + " # 32x1x28x28 => 32x32x26x26\n", + " x = self.conv1(x)\n", + " x = self.relu(x)\n", + " x = self.conv2(x)\n", + " x = self.relu(x)\n", + "\n", + " # flatten => 32 x (32*26*26)\n", + " x = x.flatten(start_dim = 1)\n", + "\n", + " # 32 x (32*26*26) => 32x128\n", + " x = self.d1(x)\n", + " x = self.relu(x)\n", + "\n", + " # logits => 32x10\n", + " logits = self.d2(x)\n", + "\n", + " return logits\n", + "\n", + "\n", + "circuit = MyModel()\n", + "\n", + "# Train the model as you like here (skipped for brevity)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b37637c4", + "metadata": {}, + "outputs": [], + "source": [ + "model_path = os.path.join('network.onnx')\n", + "data_path = os.path.join('input.json')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82db373a", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", + "x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n", + "\n", + "# Flips the neural net into inference mode\n", + "circuit.eval()\n", + "\n", + " # Export the model\n", + "torch.onnx.export(circuit, # model being run\n", + " x, # model input (or a tuple for multiple inputs)\n", + " model_path, # where to save the model (can be a file or file-like object)\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=10, # the ONNX version to export the model to\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names = ['input'], # the model's input names\n", + " output_names = ['output'], # the model's output names\n", + " dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n", + " 'output' : {0 : 'batch_size'}})\n", + "\n", + "data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + " # Serialize data into file:\n", + "json.dump( data, open(data_path, 'w' ))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5e374a2", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "test_hub_name = \"samtvlabs\" #we've set this up for you, but you can create your own hub name and use that instead\n", + "\n", + "py_run_args = ezkl.PyRunArgs()\n", + "py_run_args.input_visibility = \"public\"\n", + "py_run_args.output_visibility = \"public\"\n", + "py_run_args.param_visibility = \"private\" # private by default\n", + "\n", + "\n", + "organization = ezkl.get_hub_credentials(test_hub_name)['organizations'][0]\n", + "\n", + "print(\"organization: \" + str(organization))\n", + "\n", + "deployed_model = ezkl.create_hub_artifact(model_path, data_path, \"my first model\", organization['id'], target=\"resources\", py_run_args=py_run_args)\n", + "\n", + "print(\"deployed model: \" + str(deployed_model))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81201b32", + "metadata": {}, + "outputs": [], + "source": [ + "# sleep for a bit to make sure the model is deployed\n", + "time.sleep(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcc44717", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "proof_id = ezkl.prove_hub(deployed_model['id'], data_path)\n", + "print(\"proof id: \" + str(proof_id))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aa6a580", + "metadata": {}, + "outputs": [], + "source": [ + "# give prove_hub some time to finish\n", + "time.sleep(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9e2f32f", + "metadata": {}, + "outputs": [], + "source": [ + "proof = ezkl.get_hub_proof(proof_id['id'])\n", + "\n", + "print(\"proof: \" + str(proof))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/circuit/ops/chip.rs b/src/circuit/ops/chip.rs index cc410ba1..307d2422 100644 --- a/src/circuit/ops/chip.rs +++ b/src/circuit/ops/chip.rs @@ -326,7 +326,7 @@ impl BaseConfig { { cs.lookup("", |cs| { let mut res = vec![]; - let sel = cs.query_selector(multi_col_selector.clone()); + let sel = cs.query_selector(multi_col_selector); let synthetic_sel = match len { 1 => Expression::Constant(F::from(1)), @@ -377,12 +377,12 @@ impl BaseConfig { ( col_expr.clone() * input_query.clone() + not_expr.clone() * Expression::Constant(default_x), - input_col.clone(), + *input_col, ), ( col_expr.clone() * output_query.clone() + not_expr.clone() * Expression::Constant(default_y), - output_col.clone(), + *output_col, ), ]); diff --git a/src/commands.rs b/src/commands.rs index 6efb028d..a987239f 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -71,6 +71,20 @@ impl Default for CalibrationTarget { } } +impl ToString for CalibrationTarget { + fn to_string(&self) -> String { + match self { + CalibrationTarget::Resources { col_overflow: true } => { + "resources/col-overflow".to_string() + } + CalibrationTarget::Resources { + col_overflow: false, + } => "resources".to_string(), + CalibrationTarget::Accuracy => "accuracy".to_string(), + } + } +} + impl From<&str> for CalibrationTarget { fn from(s: &str) -> Self { match s { @@ -625,4 +639,70 @@ pub enum Commands { #[arg(long)] proof_path: PathBuf, }, + + /// Gets credentials from the hub + #[command(name = "get-hub-credentials", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + GetHubCredentials { + /// The path to the model file + #[arg(short = 'N', long)] + username: String, + /// The path to the input json file + #[arg(short = 'U', long)] + url: Option, + }, + + /// Create artifacts and deploys them on the hub + #[command(name = "create-hub-artifact", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + CreateHubArtifact { + /// The path to the model file + #[arg(short = 'M', long)] + uncompiled_circuit: PathBuf, + /// The path to the input json file + #[arg(short = 'D', long)] + data: PathBuf, + /// the hub's url + #[arg(short = 'O', long)] + organization_id: String, + ///artifact name + #[arg(short = 'A', long)] + artifact_name: String, + /// the hub's url + #[arg(short = 'U', long)] + url: Option, + /// proving arguments + #[clap(flatten)] + args: RunArgs, + /// calibration target + #[arg(long, default_value = "resources")] + target: CalibrationTarget, + }, + + /// Create artifacts and deploys them on the hub + #[command(name = "prove-hub", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + ProveHub { + /// The path to the model file + #[arg(short = 'A', long)] + artifact_id: String, + /// The path to the input json file + #[arg(short = 'D', long)] + data: PathBuf, + #[arg(short = 'U', long)] + url: Option, + #[arg(short = 'T', long)] + transcript_type: Option, + }, + + /// Create artifacts and deploys them on the hub + #[command(name = "get-hub-proof", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + GetHubProof { + /// The path to the model file + #[arg(short = 'A', long)] + artifact_id: String, + #[arg(short = 'U', long)] + url: Option, + }, } diff --git a/src/execute.rs b/src/execute.rs index 5b67f96c..a916e7a2 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -65,6 +65,8 @@ use std::sync::OnceLock; #[cfg(not(target_arch = "wasm32"))] use std::time::Duration; use thiserror::Error; +#[cfg(not(target_arch = "wasm32"))] +use tokio_util::codec::{BytesCodec, FramedRead}; #[cfg(not(target_arch = "wasm32"))] static _SOLC_REQUIREMENT: OnceLock = OnceLock::new(); @@ -335,6 +337,51 @@ pub async fn run(cli: Cli) -> Result<(), Box> { addr_da, } => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await, Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetHubCredentials { username, url } => { + get_hub_credentials(url.as_deref(), &username) + .await + .map(|_| ()) + } + + #[cfg(not(target_arch = "wasm32"))] + Commands::CreateHubArtifact { + uncompiled_circuit, + data, + organization_id, + artifact_name, + url, + args, + target, + } => deploy_model( + url.as_deref(), + &uncompiled_circuit, + &data, + &artifact_name, + &organization_id, + &args, + &target, + ) + .await + .map(|_| ()), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetHubProof { artifact_id, url } => get_hub_proof(url.as_deref(), &artifact_id) + .await + .map(|_| ()), + #[cfg(not(target_arch = "wasm32"))] + Commands::ProveHub { + artifact_id, + data, + transcript_type, + url, + } => prove_hub( + url.as_deref(), + &artifact_id, + &data, + transcript_type.as_deref(), + ) + .await + .map(|_| ()), } } @@ -1608,6 +1655,214 @@ pub(crate) fn verify_aggr( Ok(()) } +/// Retrieves the user's credentials from the hub +pub(crate) async fn get_hub_credentials( + url: Option<&str>, + username: &str, +) -> Result> { + let client = reqwest::Client::new(); + let request_body = serde_json::json!({ + "query": r#" + query GetOrganizationId($username: String!) { + organizations(name: $username) { + id + name + } + } + "#, + "variables": { + "username": username, + } + }); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + + let response = client.post(url).json(&request_body).send().await?; + let response_body = response.json::().await?; + + let organizations: crate::hub::Organizations = + serde_json::from_value(response_body["data"].clone())?; + + log::info!( + "Organization ID : {}", + organizations.as_json()?.to_colored_json_auto()? + ); + Ok(organizations) +} + +/// Deploy a model +pub(crate) async fn deploy_model( + url: Option<&str>, + model: &PathBuf, + input: &PathBuf, + name: &str, + organization_id: &str, + args: &RunArgs, + target: &CalibrationTarget, +) -> Result> { + let model_file = tokio::fs::File::open(model.canonicalize()?).await?; + // read file body stream + let stream = FramedRead::new(model_file, BytesCodec::new()); + let model_file_body = reqwest::Body::wrap_stream(stream); + + let model_file = reqwest::multipart::Part::stream(model_file_body).file_name("uncompiledModel"); + + let input_file = tokio::fs::File::open(input.canonicalize()?).await?; + // read file body stream + let stream = FramedRead::new(input_file, BytesCodec::new()); + let input_file_body = reqwest::Body::wrap_stream(stream); + + //make form part of file + let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input"); + + // the graphql request map + let map = r#"{ + "uncompiledModel": [ + "variables.uncompiledModel" + ], + "input": [ + "variables.input" + ] + }"#; + + let operations = serde_json::json!({ + "query": "mutation($uncompiledModel: Upload!, $input: Upload!, $organizationId: String!, $name: String!, $calibrationTarget: String!, $tolerance: Float!, $inputVisibility: String!, $outputVisibility: String!, $paramVisibility: String!) { + generateArtifact( + name: $name, + description: $name, + uncompiledModel: $uncompiledModel, + input: $input, + organizationId: $organizationId, + calibrationTarget: $calibrationTarget, + tolerance: $tolerance, + inputVisibility: $inputVisibility, + outputVisibility: $outputVisibility, + paramVisibility: $paramVisibility, + ) { + artifact { + id + } + } + }", + "variables": { + "name": name, + "uncompiledModel": null, + "input": null, + "organizationId": organization_id, + "calibrationTarget": target.to_string(), + "tolerance": args.tolerance.val, + "inputVisibility": args.input_visibility.to_string(), + "outputVisibility": args.output_visibility.to_string(), + "paramVisibility": args.param_visibility.to_string(), + } + }) + .to_string(); + + // now the form data + let mut form = reqwest::multipart::Form::new(); + form = form + .text("operations", operations) + .text("map", map) + .part("uncompiledModel", model_file) + .part("input", input_file); + + let client = reqwest::Client::new(); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + //send request + let response = client.post(url).multipart(form).send().await?; + let response_body = response.json::().await?; + println!("{}", response_body.to_string()); + let artifact_id: crate::hub::Artifact = + serde_json::from_value(response_body["data"]["generateArtifact"]["artifact"].clone())?; + log::info!( + "Artifact ID : {}", + artifact_id.as_json()?.to_colored_json_auto()? + ); + Ok(artifact_id) +} + +/// Generates proofs on the hub +pub async fn prove_hub( + url: Option<&str>, + id: &str, + input: &PathBuf, + transcript_type: Option<&str>, +) -> Result> { + let input_file = tokio::fs::File::open(input.canonicalize()?).await?; + let stream = FramedRead::new(input_file, BytesCodec::new()); + let input_file_body = reqwest::Body::wrap_stream(stream); + + let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input"); + + let map = r#"{ + "input": [ + "variables.input" + ] + }"#; + + let operations = serde_json::json!({ + "query": r#" + mutation($input: Upload!, $id: String!, $transcriptType: String) { + initiateProof(input: $input, id: $id, transcriptType: $transcriptType) { + id + } + } + "#, + "variables": { + "input": null, + "id": id, + "transcriptType": transcript_type.unwrap_or("evm"), + } + }) + .to_string(); + + let mut form = reqwest::multipart::Form::new(); + form = form + .text("operations", operations) + .text("map", map) + .part("input", input_file); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let client = reqwest::Client::new(); + let response = client.post(url).multipart(form).send().await?; + let response_body = response.json::().await?; + let proof_id: crate::hub::Proof = + serde_json::from_value(response_body["data"]["initiateProof"].clone())?; + log::info!("Proof ID : {}", proof_id.as_json()?.to_colored_json_auto()?); + Ok(proof_id) +} + +/// Fetches proofs from the hub +pub(crate) async fn get_hub_proof( + url: Option<&str>, + id: &str, +) -> Result> { + let client = reqwest::Client::new(); + let request_body = serde_json::json!({ + "query": format!(r#" + query {{ + getProof(id: "{}") {{ + id + artifact {{ id name }} + status + proof + instances + transcriptType + strategy + }} + }} + "#, id), + }); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + + let response = client.post(url).json(&request_body).send().await?; + let response_body = response.json::().await?; + + let proof: crate::hub::Proof = + serde_json::from_value(response_body["data"]["getProof"].clone())?; + + log::info!("Proof : {}", proof.as_json()?.to_colored_json_auto()?); + Ok(proof) +} + /// helper function for load_params pub(crate) fn load_params_cmd( srs_path: PathBuf, diff --git a/src/hub.rs b/src/hub.rs new file mode 100644 index 00000000..75efb40a --- /dev/null +++ b/src/hub.rs @@ -0,0 +1,143 @@ +use serde::{Deserialize, Serialize}; + +/// Stores users organizations +#[derive(Serialize, Deserialize, Debug)] +pub struct Organization { + /// The organization id + pub id: String, + /// The users username + pub name: String, +} + +impl Organization { + /// Export the organization as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores Organization +#[derive(Serialize, Deserialize, Debug)] +pub struct Organizations { + /// An Array of Organizations + pub organizations: Vec, +} + +impl Organizations { + /// Export the organizations as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores the Proof Response +#[derive(Debug, Deserialize, Serialize)] +pub struct Proof { + /// stores the artifact + pub artifact: Option, + /// stores the Proof Id + pub id: String, + /// stores the instances + pub instances: Option>, + /// stores the proofs + pub proof: Option, + /// stores the status + pub status: Option, + ///stores the strategy + pub strategy: Option, + /// stores the transcript type + #[serde(rename = "transcriptType")] + pub transcript_type: Option, +} + +impl Proof { + /// Export the proof as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores the Artifacts +#[derive(Debug, Deserialize, Serialize)] +pub struct Artifact { + ///stores the aritfact id + pub id: Option, + /// stores the name of the artifact + pub name: Option, +} + +impl Artifact { + /// Export the artifact as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Artifact { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("name", &self.name).unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Proof { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("artifact", &self.artifact).unwrap(); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("instances", &self.instances).unwrap(); + dict.set_item("proof", &self.proof).unwrap(); + dict.set_item("status", &self.status).unwrap(); + dict.set_item("strategy", &self.strategy).unwrap(); + dict.set_item("transcript_type", &self.transcript_type) + .unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Organizations { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("organizations", &self.organizations).unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Organization { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("name", &self.name).unwrap(); + dict.into() + } +} diff --git a/src/lib.rs b/src/lib.rs index 54721fa2..121d9091 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,9 @@ pub mod fieldutils; /// a Halo2 circuit. #[cfg(feature = "onnx")] pub mod graph; +/// Methods for deploying and interacting with the ezkl hub +#[cfg(not(target_arch = "wasm32"))] +pub mod hub; /// beautiful logging #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod logger; diff --git a/src/python.rs b/src/python.rs index 46ed7148..f752cfac 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1077,16 +1077,86 @@ fn print_proof_hex(proof_path: PathBuf) -> Result { let proof = Snark::load::>(&proof_path) .map_err(|_| PyIOError::new_err("Failed to load proof"))?; - // let mut return_string: String = ""; - // for instance in proof.instances { - // return_string.push_str(instance + "\n"); - // } - // return_string = hex::encode(proof.proof); - - // return proof for now Ok(hex::encode(proof.proof)) } +/// deploys a model to the hub +#[pyfunction(signature = (model, input, name, organization_id, target=None, py_run_args=None, url=None))] +fn create_hub_artifact( + model: PathBuf, + input: PathBuf, + name: String, + organization_id: String, + target: Option, + py_run_args: Option, + url: Option<&str>, +) -> PyResult { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let target = target.unwrap_or(CalibrationTarget::Resources { + col_overflow: false, + }); + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::deploy_model( + url, + &model, + &input, + &name, + &organization_id, + &run_args, + &target, + )) + .map_err(|e| { + let err_str = format!("Failed to deploy model to hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Generate a proof on the hub. +#[pyfunction(signature = (id, input, url=None, transcript_type=None))] +fn prove_hub( + id: &str, + input: PathBuf, + url: Option<&str>, + transcript_type: Option<&str>, +) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::prove_hub(url, id, &input, transcript_type)) + .map_err(|e| { + let err_str = format!("Failed to generate proof on hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Fetches proof from hub +#[pyfunction(signature = (id, url=None))] +fn get_hub_proof(id: &str, url: Option<&str>) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::get_hub_proof(url, id)) + .map_err(|e| { + let err_str = format!("Failed to get proof from hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Gets hub credentials +#[pyfunction(signature = (username, url=None))] +fn get_hub_credentials(username: &str, url: Option<&str>) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::get_hub_credentials(url, username)) + .map_err(|e| { + let err_str = format!("Failed to get hub credentials: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + // Python Module #[pymodule] fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -1130,6 +1200,10 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?; m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?; m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?; + m.add_function(wrap_pyfunction!(create_hub_artifact, m)?)?; + m.add_function(wrap_pyfunction!(prove_hub, m)?)?; + m.add_function(wrap_pyfunction!(get_hub_proof, m)?)?; + m.add_function(wrap_pyfunction!(get_hub_credentials, m)?)?; Ok(()) } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 26754550..4bfcb2a2 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -450,7 +450,7 @@ impl Tensor { pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result { let total_dims: usize = if !dims.is_empty() { dims.iter().product() - } else if let Some(_) = values { + } else if values.is_some() { 1 } else { 0 diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index cae8e91a..4f413ce9 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -115,7 +115,7 @@ mod py_tests { } } - const TESTS: [&str; 25] = [ + const TESTS: [&str; 26] = [ "mnist_gan.ipynb", // "mnist_vae.ipynb", "keras_simple_demo.ipynb", @@ -142,6 +142,7 @@ mod py_tests { "linear_regression.ipynb", "stacked_regression.ipynb", "data_attest_hashed.ipynb", + "simple_hub_demo.ipynb", ]; macro_rules! test_func { @@ -154,7 +155,7 @@ mod py_tests { use super::*; - seq!(N in 0..=24 { + seq!(N in 0..=25 { #(#[test_case(TESTS[N])])* fn run_notebook_(test: &str) {