Compare commits

...

2 Commits

Author SHA1 Message Date
github-actions[bot]
530a504fa4 ci: update version string in docs 2025-02-07 21:06:20 +00:00
dante
f7f04415fa chore!: add model input/output types to settings (#933)
BREAKING CHANGE: compiled model serialization is not backwards compatible
2025-02-07 16:05:59 -05:00
8 changed files with 91 additions and 12 deletions

View File

@@ -23,8 +23,6 @@ use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
@@ -40,7 +38,7 @@ impl Circuit<Fr> for MyCircuit {
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
}
fn synthesize(
@@ -48,7 +46,7 @@ impl Circuit<Fr> for MyCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
@@ -59,7 +57,7 @@ fn runposeidon(c: &mut Criterion) {
let mut group = c.benchmark_group("poseidon");
for size in [64, 784, 2352, 12288].iter() {
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
as f32)
.log2()
.ceil() as u32;
@@ -67,7 +65,7 @@ fn runposeidon(c: &mut Criterion) {
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
.unwrap();
let mut image = Tensor::from(message.into_iter().map(Value::known));

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '20.0.0'
version = release

View File

@@ -337,6 +337,8 @@ enum PyInputType {
Int,
///
TDim,
///
Unknown,
}
impl From<InputType> for PyInputType {
@@ -348,6 +350,7 @@ impl From<InputType> for PyInputType {
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
InputType::Unknown => PyInputType::Unknown,
}
}
}
@@ -361,6 +364,7 @@ impl From<PyInputType> for InputType {
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
PyInputType::Unknown => InputType::Unknown,
}
}
}
@@ -375,6 +379,7 @@ impl FromStr for PyInputType {
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
"unknown" => Ok(PyInputType::Unknown),
_ => Err("Invalid value for InputType".to_string()),
}
}

View File

@@ -1,6 +1,8 @@
use std::any::Any;
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;
use crate::{
graph::quantize_tensor,
@@ -96,6 +98,8 @@ pub enum InputType {
Int,
///
TDim,
///
Unknown,
}
impl InputType {
@@ -132,6 +136,7 @@ impl InputType {
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
InputType::Unknown => {}
}
}
}
@@ -152,6 +157,28 @@ impl std::str::FromStr for InputType {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
fn from(datum_type: DatumType) -> Self {
match datum_type {
DatumType::Bool => InputType::Bool,
DatumType::F16 => InputType::F16,
DatumType::F32 => InputType::F32,
DatumType::F64 => InputType::F64,
DatumType::I8 => InputType::Int,
DatumType::I16 => InputType::Int,
DatumType::I32 => InputType::Int,
DatumType::I64 => InputType::Int,
DatumType::U8 => InputType::Int,
DatumType::U16 => InputType::Int,
DatumType::U32 => InputType::Int,
DatumType::U64 => InputType::Int,
DatumType::TDim => InputType::TDim,
_ => unimplemented!(),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {

View File

@@ -455,6 +455,10 @@ pub struct GraphSettings {
pub num_blinding_factors: Option<usize>,
/// unix time timestamp
pub timestamp: Option<u128>,
/// Model inputs types (if any)
pub input_types: Option<Vec<InputType>>,
/// Model outputs types (if any)
pub output_types: Option<Vec<InputType>>,
}
impl GraphSettings {

View File

@@ -379,9 +379,15 @@ pub struct ParsedNodes {
pub nodes: BTreeMap<usize, NodeType>,
inputs: Vec<usize>,
outputs: Vec<Outlet>,
output_types: Vec<InputType>,
}
impl ParsedNodes {
/// Returns the output types of the computational graph.
pub fn get_output_types(&self) -> Vec<InputType> {
self.output_types.clone()
}
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
self.inputs.len()
@@ -491,6 +497,16 @@ impl Model {
Ok(om)
}
/// Gets the input types from the parsed nodes
pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
self.graph.get_input_types()
}
/// Gets the output types from the parsed nodes
pub fn get_output_types(&self) -> Vec<InputType> {
self.graph.get_output_types()
}
///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
@@ -574,6 +590,11 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
@@ -704,6 +725,11 @@ impl Model {
nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
.collect::<Result<Vec<_>, GraphError>>()?,
};
let duration = start_time.elapsed();
@@ -862,6 +888,15 @@ impl Model {
nodes: subgraph_nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| {
Ok::<InputType, GraphError>(
model.outlet_fact(*o)?.datum_type.into(),
)
})
.collect::<Result<Vec<_>, GraphError>>()?,
};
let om = Model {
@@ -1579,4 +1614,16 @@ impl Model {
}
Ok(instance_shapes)
}
/// Input types of the computational graph's public inputs (if any)
pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
let mut instance_types = vec![];
if self.visibility.input.is_public() {
instance_types.extend(self.graph.get_input_types()?);
}
if self.visibility.output.is_public() {
instance_types.extend(self.graph.get_output_types());
}
Ok(instance_types)
}
}

View File

@@ -387,7 +387,7 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("add".to_string()));
}
@@ -441,7 +441,7 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("sub".to_string()));
}
// calculate value of output
@@ -492,7 +492,7 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("mult".to_string()));
}
// calculate value of output
@@ -1326,7 +1326,6 @@ pub fn pad<T: TensorType>(
///
/// # Errors
/// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
pub fn concat<T: TensorType + Send + Sync>(
inputs: &[&Tensor<T>],
axis: usize,
@@ -2102,7 +2101,6 @@ pub mod nonlinearities {
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;

Binary file not shown.