Compare commits

...

1 Commits

Author SHA1 Message Date
dante
ddbcc1d2d8 fix: calibration should only consider local scales (#691) 2024-01-18 23:28:49 +00:00
2 changed files with 18 additions and 9 deletions

View File

@@ -643,10 +643,10 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Tictactoe tutorials

View File

@@ -834,7 +834,6 @@ pub(crate) fn calibrate(
Ok(r) => Some(r),
Err(_) => None,
};
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
@@ -847,7 +846,15 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(_) => return Err(format!("failed to create circuit from run args").into()),
Err(e) => {
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
debug!("circuit creation from run args failed: {:?}", e);
continue;
}
};
chunks
@@ -874,16 +881,18 @@ pub(crate) fn calibrate(
.collect::<Result<Vec<()>, String>>()?;
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.1.iter().map(|x| x.min_lookup_inputs))
.flatten()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.1.iter().map(|x| x.max_lookup_inputs))
.flatten()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
@@ -930,7 +939,7 @@ pub(crate) fn calibrate(
found_settings.as_json()?.to_colored_json_auto()?
);
} else {
debug!("calibration failed");
debug!("calibration failed {}", res.err().unwrap());
}
pb.inc(1);