mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
feat: optional rebase scale override (#993)
This commit is contained in:
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -200,7 +200,7 @@ jobs:
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
|
||||
@@ -142,6 +142,9 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// int: The denominator in the fixed point representation used when quantizing parameters
|
||||
pub param_scale: crate::Scale,
|
||||
/// int: The scale to rebase to (optional). If None, we rebase to the max of input_scale and param_scale
|
||||
/// This is an advanced parameter that should be used with caution
|
||||
pub rebase_scale: Option<crate::Scale>,
|
||||
#[pyo3(get, set)]
|
||||
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
|
||||
pub scale_rebase_multiplier: u32,
|
||||
@@ -208,6 +211,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
rebase_scale: py_run_args.rebase_scale,
|
||||
num_inner_cols: py_run_args.num_inner_cols,
|
||||
scale_rebase_multiplier: py_run_args.scale_rebase_multiplier,
|
||||
lookup_range: py_run_args.lookup_range,
|
||||
@@ -234,6 +238,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
rebase_scale: self.rebase_scale,
|
||||
num_inner_cols: self.num_inner_cols,
|
||||
scale_rebase_multiplier: self.scale_rebase_multiplier,
|
||||
lookup_range: self.lookup_range,
|
||||
|
||||
@@ -695,8 +695,8 @@ impl Node {
|
||||
opkind = opkind.homogenous_rescale(in_scales.clone())?.into();
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
let rebase_scale = scales.get_rebase_scale();
|
||||
opkind = RebaseScale::rebase(opkind, rebase_scale, out_scale, scales.rebase_multiplier);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -250,6 +250,8 @@ pub struct VarScales {
|
||||
pub params: crate::Scale,
|
||||
/// Multiplier for scale rebasing
|
||||
pub rebase_multiplier: u32,
|
||||
/// rebase scale factor (optional). if None, we rebase to the max of input_scale and param_scale
|
||||
pub rebase_scale: Option<crate::Scale>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VarScales {
|
||||
@@ -269,11 +271,21 @@ impl VarScales {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Returns the scale to rebase to, if specified
|
||||
pub fn get_rebase_scale(&self) -> crate::Scale {
|
||||
if let Some(rebase_scale) = self.rebase_scale {
|
||||
rebase_scale
|
||||
} else {
|
||||
self.get_max()
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates VarScales from runtime arguments
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
params: args.param_scale,
|
||||
rebase_scale: args.rebase_scale,
|
||||
rebase_multiplier: args.scale_rebase_multiplier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,6 +288,10 @@ pub struct RunArgs {
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub param_scale: Scale,
|
||||
/// Scale to rebase to when the input scale exceeds rebase_scale * multiplier. If None we rebase to the max of input_scale and param_scale
|
||||
/// This is an advanced parameter that should be used with caution
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, required = false, value_hint = clap::ValueHint::Other))]
|
||||
pub rebase_scale: Option<Scale>,
|
||||
/// Scale rebase threshold multiplier
|
||||
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
|
||||
/// Advanced parameter that should be used with caution
|
||||
@@ -378,6 +382,7 @@ impl Default for RunArgs {
|
||||
bounded_log_lookup: false,
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
rebase_scale: None,
|
||||
scale_rebase_multiplier: 1,
|
||||
lookup_range: (-32768, 32768),
|
||||
logrows: 17,
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user