feat: optional rebase scale override (#993)

This commit is contained in:
dante
2025-07-27 16:05:54 -04:00
committed by GitHub
parent edd4d7f5b8
commit 2f1a3f430e
6 changed files with 25 additions and 3 deletions

View File

@@ -200,7 +200,7 @@ jobs:
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'

View File

@@ -142,6 +142,9 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
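// NOTE(review): unlike the neighbouring fields, this one carries no `#[pyo3(get, set)]` attribute; confirm whether it is meant to be readable/writable from Python
/// int: The scale to rebase to (optional). If None, we rebase to the max of input_scale and param_scale.
/// This is an advanced parameter that should be used with caution.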
pub rebase_scale: Option<crate::Scale>,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this is a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
@@ -208,6 +211,7 @@ impl From<PyRunArgs> for RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
rebase_scale: py_run_args.rebase_scale,
num_inner_cols: py_run_args.num_inner_cols,
scale_rebase_multiplier: py_run_args.scale_rebase_multiplier,
lookup_range: py_run_args.lookup_range,
@@ -234,6 +238,7 @@ impl Into<PyRunArgs> for RunArgs {
bounded_log_lookup: self.bounded_log_lookup,
input_scale: self.input_scale,
param_scale: self.param_scale,
rebase_scale: self.rebase_scale,
num_inner_cols: self.num_inner_cols,
scale_rebase_multiplier: self.scale_rebase_multiplier,
lookup_range: self.lookup_range,

View File

@@ -695,8 +695,8 @@ impl Node {
opkind = opkind.homogenous_rescale(in_scales.clone())?.into();
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points; by default we select the largest scale (highest precision), unless an explicit rebase scale override is set
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
let rebase_scale = scales.get_rebase_scale();
opkind = RebaseScale::rebase(opkind, rebase_scale, out_scale, scales.rebase_multiplier);
out_scale = opkind.out_scale(in_scales)?;

View File

@@ -250,6 +250,8 @@ pub struct VarScales {
pub params: crate::Scale,
/// Multiplier for scale rebasing
pub rebase_multiplier: u32,
/// Rebase scale factor (optional). If None, we rebase to the max of input_scale and param_scale
pub rebase_scale: Option<crate::Scale>,
}
impl std::fmt::Display for VarScales {
@@ -269,11 +271,21 @@ impl VarScales {
std::cmp::min(self.input, self.params)
}
/// Returns the scale to rebase to.
///
/// An explicitly configured `rebase_scale` takes precedence; when none is set
/// the result of `get_max()` is used instead.
pub fn get_rebase_scale(&self) -> crate::Scale {
    // The fallback is evaluated lazily, so `get_max()` only runs when no override exists.
    self.rebase_scale.unwrap_or_else(|| self.get_max())
}
/// Builds a `VarScales` by copying the scale-related settings out of the runtime arguments.
///
/// `input_scale`, `param_scale`, `rebase_scale` and `scale_rebase_multiplier`
/// are read from `args` and stored in the correspondingly named fields.
pub fn from_args(args: &RunArgs) -> Self {
    let input = args.input_scale;
    let params = args.param_scale;
    let rebase_scale = args.rebase_scale;
    let rebase_multiplier = args.scale_rebase_multiplier;

    Self {
        input,
        params,
        rebase_scale,
        rebase_multiplier,
    }
}

View File

@@ -288,6 +288,10 @@ pub struct RunArgs {
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// Scale to rebase to when the input scale exceeds rebase_scale * multiplier. If None, we rebase to the max of input_scale and param_scale.
/// This is an advanced parameter that should be used with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, required = false, value_hint = clap::ValueHint::Other))]
pub rebase_scale: Option<Scale>,
/// Scale rebase threshold multiplier
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
/// Advanced parameter that should be used with caution
@@ -378,6 +382,7 @@ impl Default for RunArgs {
bounded_log_lookup: false,
input_scale: 7,
param_scale: 7,
rebase_scale: None,
scale_rebase_multiplier: 1,
lookup_range: (-32768, 32768),
logrows: 17,

Binary file not shown.