mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
chore: keras/tf example notebook (#370)
This commit is contained in:
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -465,7 +465,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7"
|
||||
python-version: "3.9"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
|
||||
@@ -51,7 +51,7 @@ source .env/bin/activate
|
||||
pip install -r requirements.txt
|
||||
maturin develop --release --features python-bindings
|
||||
# dependencies specific to tutorials
|
||||
pip install torch pandas numpy seaborn jupyter onnx kaggle py-solc-x web3 librosa
|
||||
pip install torch pandas numpy seaborn jupyter onnx kaggle py-solc-x web3 librosa tensorflow keras tf2onnx
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = 0.1*torch.rand(1,*[3, 8, 8], requires_grad=True)\n",
|
||||
"x = torch.rand(1,*[3, 8, 8], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
@@ -250,7 +250,7 @@
|
||||
"source": [
|
||||
"# generate a bunch of dummy calibration data\n",
|
||||
"cal_data = {\n",
|
||||
" \"input_data\": [(0.1*torch.rand(40, *[3, 8, 8])).flatten().tolist()],\n",
|
||||
" \"input_data\": [torch.cat((x, torch.rand(10, *[3, 8, 8]))).flatten().tolist()],\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"cal_path = os.path.join('val_data.json')\n",
|
||||
|
||||
File diff suppressed because one or more lines are too long
249
examples/notebooks/keras_simple_demo.ipynb
Normal file
249
examples/notebooks/keras_simple_demo.ipynb
Normal file
@@ -0,0 +1,249 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## EZKL Jupyter Notebook Demo with Keras\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e4e073ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install keras tensorflow tf2onnx numpy==1.23"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a27b0cd9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import ezkl\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"from keras.models import Sequential\n",
|
||||
"from keras.layers import Dense, Dropout, Activation, Flatten\n",
|
||||
"from keras.layers import Convolution2D, MaxPooling2D\n",
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"# uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# Defines the model\n",
|
||||
"# we got convs, we got relu, we got linear layers, max pooling layers etc... \n",
|
||||
"# What else could one want ????\n",
|
||||
"\n",
|
||||
"model = Sequential()\n",
|
||||
"model.add(Convolution2D(32, (3,3), activation='relu', input_shape=(28,28,1)))\n",
|
||||
"model.add(Convolution2D(32, (3,3), activation='relu'))\n",
|
||||
"model.add(MaxPooling2D(pool_size=(2,2)))\n",
|
||||
"model.add(Dropout(0.25))\n",
|
||||
"model.add(Flatten())\n",
|
||||
"model.add(Dense(128, activation='relu'))\n",
|
||||
"model.add(Dropout(0.5))\n",
|
||||
"model.add(Dense(10, activation='softmax'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Train the model as you like here (skipped for brevity)\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"srs_path = os.path.join('kzg.srs')\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import tf2onnx\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
|
||||
"x = 0.1*np.random.rand(1,*[1, 28, 28])\n",
|
||||
"\n",
|
||||
"spec = tf.TensorSpec([1, 28, 28, 1], tf.float32, name='input_0')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tf2onnx.convert.from_keras(model, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n",
|
||||
"\n",
|
||||
"data_array = x.reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( data, open(data_path, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.gen_srs(srs_path, 17)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, model_path, witness_path, settings_path = settings_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" srs_path,\n",
|
||||
" settings_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" srs_path,\n",
|
||||
" \"evm\",\n",
|
||||
" \"single\",\n",
|
||||
" settings_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" srs_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1309,6 +1309,17 @@ pub fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
/// Dummy (no contraints) move_axis layout
|
||||
pub fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
source: usize,
|
||||
destination: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let mut t = values[0].clone();
|
||||
t.move_axis(source, destination)?;
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
/// resize layout
|
||||
pub fn resize<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
|
||||
@@ -44,6 +44,10 @@ pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
|
||||
},
|
||||
Identity,
|
||||
Reshape(Vec<usize>),
|
||||
MoveAxis {
|
||||
source: usize,
|
||||
destination: usize,
|
||||
},
|
||||
Gather {
|
||||
dim: usize,
|
||||
index: Tensor<usize>,
|
||||
@@ -78,6 +82,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
}
|
||||
fn as_string(&self) -> String {
|
||||
let name = match &self {
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS",
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE",
|
||||
PolyOp::Resize { .. } => "RESIZE",
|
||||
PolyOp::Iff => "IFF",
|
||||
@@ -122,6 +127,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
t.reshape(new_dims);
|
||||
Ok(t)
|
||||
}
|
||||
PolyOp::MoveAxis {
|
||||
source,
|
||||
destination,
|
||||
} => inputs[0].move_axis(*source, *destination),
|
||||
PolyOp::Flatten(new_dims) => {
|
||||
let mut t = inputs[0].clone();
|
||||
t.reshape(new_dims);
|
||||
@@ -225,6 +234,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
let mut values = values.to_vec();
|
||||
|
||||
Ok(Some(match self {
|
||||
PolyOp::MoveAxis {
|
||||
source,
|
||||
destination,
|
||||
} => layouts::move_axis(values[..].try_into()?, *source, *destination)?,
|
||||
PolyOp::Downsample {
|
||||
axis,
|
||||
stride,
|
||||
@@ -328,6 +341,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<u32>, _g: u32) -> u32 {
|
||||
match self {
|
||||
PolyOp::MoveAxis { .. } => in_scales[0],
|
||||
PolyOp::Downsample { .. } => in_scales[0],
|
||||
PolyOp::Resize { .. } => in_scales[0],
|
||||
PolyOp::Iff => in_scales[1],
|
||||
|
||||
@@ -696,7 +696,10 @@ pub(crate) async fn calibrate(
|
||||
res.push(task);
|
||||
}
|
||||
}
|
||||
if let Some(best) = res.into_iter().max_by_key(|p| p.run_args.logrows) {
|
||||
if let Some(best) = res
|
||||
.into_iter()
|
||||
.max_by_key(|p| (p.run_args.bits, p.run_args.scale))
|
||||
{
|
||||
// pick the one with the largest logrows
|
||||
found_params.push(best);
|
||||
}
|
||||
|
||||
@@ -11,13 +11,12 @@ use log::{debug, warn};
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
use tract_onnx::tract_core::ops::array::Gather;
|
||||
use tract_onnx::tract_core::ops::array::Slice;
|
||||
use tract_onnx::tract_core::ops::change_axes::AxisOp;
|
||||
use tract_onnx::tract_core::ops::cnn::DeconvUnary;
|
||||
use tract_onnx::tract_core::ops::einsum::EinSum;
|
||||
use tract_onnx::tract_core::ops::Downsample;
|
||||
|
||||
use tract_onnx::tract_core::ops::element_wise::ElementWiseOp;
|
||||
|
||||
use tract_onnx::tract_core::ops::nn::{LeakyRelu, Reduce, Softmax};
|
||||
use tract_onnx::tract_core::ops::Downsample;
|
||||
use tract_onnx::tract_hir::internal::DimLike;
|
||||
use tract_onnx::tract_hir::ops::cnn::ConvUnary;
|
||||
use tract_onnx::tract_hir::ops::konst::Const;
|
||||
@@ -136,6 +135,23 @@ fn load_gather_op(
|
||||
Ok(op.clone())
|
||||
}
|
||||
|
||||
///
|
||||
fn load_axis_op(
|
||||
op: &dyn tract_onnx::prelude::Op,
|
||||
idx: usize,
|
||||
name: String,
|
||||
) -> Result<AxisOp, Box<dyn std::error::Error>> {
|
||||
// Extract the slope layer hyperparams
|
||||
let op: &AxisOp = match op.downcast_ref::<AxisOp>() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Err(Box::new(GraphError::OpMismatch(idx, name)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(op.clone())
|
||||
}
|
||||
|
||||
/// Extracts an axis op from an onnx node.
|
||||
fn load_const(
|
||||
op: &dyn tract_onnx::prelude::Op,
|
||||
@@ -256,6 +272,20 @@ pub fn new_op_from_onnx(
|
||||
|
||||
Box::new(crate::circuit::ops::poly::PolyOp::Gather { dim: axis, index })
|
||||
}
|
||||
"MoveAxis" => {
|
||||
let op = load_axis_op(node.op(), idx, node.op().name().to_string())?;
|
||||
match op {
|
||||
AxisOp::Move(from, to) => {
|
||||
let source = from.to_usize()?;
|
||||
let destination = to.to_usize()?;
|
||||
Box::new(crate::circuit::ops::poly::PolyOp::MoveAxis {
|
||||
source,
|
||||
destination,
|
||||
})
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
"Concat" | "InferenceConcat" => {
|
||||
let op = load_concat_op(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
@@ -746,7 +776,7 @@ pub fn new_op_from_onnx(
|
||||
);
|
||||
Box::new(PolyOp::Pad(padding_h, padding_w))
|
||||
}
|
||||
"RmAxis" | "Reshape" => {
|
||||
"RmAxis" | "Reshape" | "AddAxis" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let shapes = node_output_shapes(&node)?;
|
||||
let output_shape = shapes[0].as_ref().unwrap().clone();
|
||||
|
||||
@@ -630,6 +630,62 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
self.dims = Vec::from(new_dims);
|
||||
}
|
||||
|
||||
/// Move axis of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b.dims(), &[3, 3, 3]);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(1, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
|
||||
assert!(source < self.dims.len());
|
||||
assert!(destination < self.dims.len());
|
||||
let mut new_dims = self.dims.clone();
|
||||
new_dims.remove(source);
|
||||
new_dims.insert(destination, self.dims[source]);
|
||||
|
||||
// now reconfigure the elements appropriately in the new array
|
||||
// eg. if we have a 3x3x3 array and we want to move the 0th axis to the 2nd position
|
||||
// we need to move the elements at 0, 1, 2, 3, 4, 5, 6, 7, 8 to 0, 3, 6, 1, 4, 7, 2, 5, 8
|
||||
// so we need to move the elements at 0, 1, 2 to 0, 3, 6
|
||||
// and the elements at 3, 4, 5 to 1, 4, 7
|
||||
// and the elements at 6, 7, 8 to 2, 5, 8
|
||||
let cartesian_coords = new_dims
|
||||
.iter()
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<Vec<usize>>>();
|
||||
|
||||
let mut output = Tensor::new(None, &new_dims)?;
|
||||
|
||||
for coord in cartesian_coords {
|
||||
let mut old_coord = vec![0; self.dims.len()];
|
||||
// now fetch the old index
|
||||
for (i, c) in coord.iter().enumerate() {
|
||||
if i == destination {
|
||||
old_coord[source] = *c;
|
||||
} else if i < source {
|
||||
old_coord[i] = *c;
|
||||
} else if i >= source {
|
||||
old_coord[i + 1] = *c;
|
||||
}
|
||||
}
|
||||
output.set(&coord, self.get(&old_coord));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Broadcasts the tensor to a given shape
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
|
||||
@@ -312,6 +312,22 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `move_axis` on the inner tensor.
|
||||
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), Box<dyn Error>> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.move_axis(source, destination)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the [ValTensor]'s shape.
|
||||
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
|
||||
match self {
|
||||
|
||||
@@ -69,9 +69,18 @@ mod py_tests {
|
||||
"py-solc-x",
|
||||
"web3",
|
||||
"librosa",
|
||||
"keras",
|
||||
"tensorflow",
|
||||
"tf2onnx",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
let status = Command::new("pip")
|
||||
.args(["install", "numpy==1.23"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
assert!(status.success());
|
||||
});
|
||||
}
|
||||
@@ -99,7 +108,8 @@ mod py_tests {
|
||||
}
|
||||
}
|
||||
|
||||
const TESTS: [&str; 5] = [
|
||||
const TESTS: [&str; 6] = [
|
||||
"keras_simple_demo.ipynb",
|
||||
"encrypted_vis.ipynb",
|
||||
"hashed_vis.ipynb",
|
||||
"simple_demo.ipynb",
|
||||
@@ -117,7 +127,7 @@ mod py_tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
seq!(N in 0..=4 {
|
||||
seq!(N in 0..=5 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn run_notebook_(test: &str) {
|
||||
|
||||
Reference in New Issue
Block a user