Compare commits

...

5 Commits

Author SHA1 Message Date
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
dante
345fb5672a chore: cleanup unused args (#732) 2024-03-05 13:43:29 +00:00
dante
70daaff2e4 chore: cleanup calibrate (#731) 2024-03-04 17:52:11 +00:00
9 changed files with 422 additions and 674 deletions

View File

@@ -17,7 +17,6 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsZero,
IsBoolean,
}
@@ -35,7 +34,6 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -76,7 +74,6 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,7 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -110,7 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -127,7 +122,6 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)

File diff suppressed because it is too large Load Diff

View File

@@ -60,8 +60,6 @@ pub enum PolyOp {
len_prod: usize,
},
Pow(u32),
Pack(u32, u32),
GlobalSumPool,
Concat {
axis: usize,
},
@@ -110,8 +108,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Sum { .. } => "SUM".into(),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Pack(_, _) => "PACK".into(),
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -181,13 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pack(base, scale) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pack inputs".to_string()));
}
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
}
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
@@ -206,7 +195,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
@@ -288,7 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -334,10 +322,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
PolyOp::Pack(base, scale) => {
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
PolyOp::Slice { axis, start, end } => {
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?

View File

@@ -2243,75 +2243,6 @@ mod pow {
}
}
#[cfg(test)]
mod pack {
use super::*;
const K: usize = 8;
const LEN: usize = 4;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 1],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&self.inputs.clone(),
Box::new(PolyOp::Pack(2, 1)),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
fn packcircuit() {
// parameters
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = MyCircuit::<F> {
inputs: [ValTensor::from(a)],
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
#[cfg(test)]
mod matmul_relu {
use super::*;

View File

@@ -402,9 +402,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {

View File

@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from SRS file (vector of len = {})",
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
buffer.write_all(reader.get_ref())?;
buffer.flush()?;
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
@@ -902,7 +904,12 @@ pub(crate) fn calibrate(
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
let key = (input_scale, param_scale, scale_rebase_multiplier);
let key = (
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
);
forward_pass_res.insert(key, vec![]);
let local_run_args = RunArgs {
@@ -971,29 +978,21 @@ pub(crate) fn calibrate(
#[cfg(unix)]
drop(_g);
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
let result = forward_pass_res.get(&key).ok_or("key not found")?;
let min_lookup_range = result
.iter()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
let max_lookup_range = result
.iter()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);
let res = circuit.calc_min_logrows(
(min_lookup_range, max_lookup_range),
@@ -1114,6 +1113,7 @@ pub(crate) fn calibrate(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
best_params.run_args.div_rebasing,
))
.ok_or("no params found")?
.iter()
@@ -1697,7 +1697,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1715,6 +1715,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1905,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1918,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);

View File

@@ -4,6 +4,37 @@ use super::{
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
val: F,
len: usize,
) -> ValTensor<F> {
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
constant.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(constant)
}
pub(crate) fn create_unit_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
unit.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(unit)
}
pub(crate) fn create_zero_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
zero.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(zero)
}
#[derive(Debug, Clone)]
/// A [ValType] is a wrapper around Halo2 value(s).
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
@@ -463,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.