mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-06 21:34:05 -05:00
feat(hpu): Now the mockup takes into account the field position from the regmap toml to generate its register read and write answers.
This commit is contained in:
@@ -148,10 +148,10 @@ pub struct HpuPcParameters {
|
||||
pub ksk_pc: usize,
|
||||
pub ksk_bytes_w: usize,
|
||||
pub bsk_pc: usize,
|
||||
pub glwe_pc: usize, // Currently hardcoded to 1
|
||||
pub bsk_bytes_w: usize,
|
||||
pub pem_pc: usize,
|
||||
pub pem_bytes_w: usize,
|
||||
// pub glwe_pc: usize, // Currently hardcoded to 1
|
||||
pub glwe_bytes_w: usize,
|
||||
}
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ impl FromRtl for HpuPcParameters {
|
||||
let ksk_pc = *hbm_pc_fields.get("ksk_pc").expect("Unknown field") as usize;
|
||||
let bsk_pc = *hbm_pc_fields.get("bsk_pc").expect("Unknown field") as usize;
|
||||
let pem_pc = *hbm_pc_fields.get("pem_pc").expect("Unknown field") as usize;
|
||||
let glwe_pc = *hbm_pc_fields.get("glwe_pc").expect("Unknown field") as usize;
|
||||
|
||||
// Extract bus width for each channel
|
||||
let ksk_bytes_w = {
|
||||
@@ -229,6 +230,7 @@ impl FromRtl for HpuPcParameters {
|
||||
ksk_pc,
|
||||
bsk_pc,
|
||||
pem_pc,
|
||||
glwe_pc,
|
||||
ksk_bytes_w,
|
||||
bsk_bytes_w,
|
||||
pem_bytes_w,
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
bsk_bytes_w= 64
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 64
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 64
|
||||
[regf_params]
|
||||
reg_nb= 64
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
bsk_bytes_w= 64
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 64
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 64
|
||||
[regf_params]
|
||||
reg_nb= 64
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
[regf_params]
|
||||
reg_nb= 64
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
[regf_params]
|
||||
reg_nb= 64
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_pc= 1
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use super::*;
|
||||
@@ -101,10 +101,21 @@ impl RegisterMap {
|
||||
/// Return register value from parameter value
|
||||
pub fn read_reg(&mut self, addr: u64) -> u32 {
|
||||
let register_name = self.get_register_name(addr);
|
||||
let cur_reg = self
|
||||
.regmap
|
||||
.register()
|
||||
.get(register_name)
|
||||
.expect("Unknown register, check regmap definition");
|
||||
|
||||
match register_name {
|
||||
"info::ntt_structure" => {
|
||||
let ntt_p = &self.rtl_params.ntt_params;
|
||||
(ntt_p.radix + (ntt_p.psi << 8) /*+(ntt_p.div << 16)*/ + (ntt_p.delta << 24)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("radix", ntt_p.radix as u32),
|
||||
("psi", ntt_p.psi as u32),
|
||||
("delta", ntt_p.delta as u32),
|
||||
//("div", ntt_p.div as u32),
|
||||
]))
|
||||
}
|
||||
"info::ntt_rdx_cut" => {
|
||||
let ntt_p = &self.rtl_params.ntt_params;
|
||||
@@ -112,10 +123,11 @@ impl RegisterMap {
|
||||
HpuNttCoreArch::GF64(cut_w) => cut_w,
|
||||
_ => &vec![ntt_p.delta as u8],
|
||||
};
|
||||
cut_w
|
||||
.iter()
|
||||
.enumerate()
|
||||
.fold(0, |acc, (id, val)| acc + ((*val as u32) << (id * 4)))
|
||||
let mut hash_cut: HashMap<&str, u32> = HashMap::new();
|
||||
for (f, val) in cur_reg.field().iter().zip(cut_w) {
|
||||
hash_cut.insert(f.name(), *val as u32);
|
||||
}
|
||||
cur_reg.from_field(hash_cut)
|
||||
}
|
||||
"info::ntt_architecture" => match self.rtl_params.ntt_params.core_arch {
|
||||
HpuNttCoreArch::WmmCompactPcg => NTT_CORE_ARCH_OFS + 4,
|
||||
@@ -124,7 +136,10 @@ impl RegisterMap {
|
||||
},
|
||||
"info::ntt_pbs" => {
|
||||
let ntt_p = &self.rtl_params.ntt_params;
|
||||
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("batch_pbs_nb", ntt_p.batch_pbs_nb as u32),
|
||||
("total_pbs_nb", ntt_p.total_pbs_nb as u32),
|
||||
]))
|
||||
}
|
||||
"info::ntt_modulo" => {
|
||||
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
|
||||
@@ -159,17 +174,29 @@ impl RegisterMap {
|
||||
}
|
||||
"info::ks_structure" => {
|
||||
let ks_p = &self.rtl_params.ks_params;
|
||||
(ks_p.lbx + (ks_p.lby << 8) + (ks_p.lbz << 16)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("x", ks_p.lbx as u32),
|
||||
("y", ks_p.lby as u32),
|
||||
("z", ks_p.lbz as u32),
|
||||
]))
|
||||
}
|
||||
"info::ks_crypto_param" => {
|
||||
let ks_p = &self.rtl_params.ks_params;
|
||||
let pbs_p = &self.rtl_params.pbs_params;
|
||||
(ks_p.width + (pbs_p.ks_level << 8) + (pbs_p.ks_base_log << 16)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("mod_ksk_w", ks_p.width as u32),
|
||||
("ks_l", pbs_p.ks_level as u32),
|
||||
("ks_b", pbs_p.ks_base_log as u32),
|
||||
]))
|
||||
}
|
||||
"info::hbm_axi4_nb" => {
|
||||
let pc_p = &self.rtl_params.pc_params;
|
||||
// TODO: Cut number currently not reverted
|
||||
(pc_p.bsk_pc + (pc_p.ksk_pc << 8) + (pc_p.pem_pc << 16)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("ksk_pc", pc_p.ksk_pc as u32),
|
||||
("bsk_pc", pc_p.bsk_pc as u32),
|
||||
("pem_pc", pc_p.pem_pc as u32),
|
||||
("glwe_pc", pc_p.glwe_pc as u32),
|
||||
]))
|
||||
}
|
||||
"info::hbm_axi4_dataw_ksk" => {
|
||||
let bytes_w = &self.rtl_params.pc_params.ksk_bytes_w;
|
||||
@@ -190,11 +217,17 @@ impl RegisterMap {
|
||||
|
||||
"info::regf_structure" => {
|
||||
let regf_p = &self.rtl_params.regf_params;
|
||||
(regf_p.reg_nb + (regf_p.coef_nb << 8)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("reg_nb", regf_p.reg_nb as u32),
|
||||
("coef_nb", regf_p.coef_nb as u32),
|
||||
]))
|
||||
}
|
||||
"info::isc_structure" => {
|
||||
let isc_p = &self.rtl_params.isc_params;
|
||||
(isc_p.depth + (isc_p.min_iop_size << 8)) as u32
|
||||
cur_reg.from_field(HashMap::from([
|
||||
("depth", isc_p.depth as u32),
|
||||
("min_iop_size", isc_p.min_iop_size as u32),
|
||||
]))
|
||||
}
|
||||
|
||||
"bsk_avail::avail" => self.bsk.avail.load(Ordering::SeqCst) as u32,
|
||||
@@ -217,9 +250,10 @@ impl RegisterMap {
|
||||
}
|
||||
|
||||
// Bpip configuration registers
|
||||
"bpip::use" => {
|
||||
((self.bpip.used as u8) + ((self.bpip.use_opportunism as u8) << 1)) as u32
|
||||
}
|
||||
"bpip::use" => cur_reg.from_field(HashMap::from([
|
||||
("use_bpip", self.bpip.used as u32),
|
||||
("use_opportunism", self.bpip.use_opportunism as u32),
|
||||
])),
|
||||
"bpip::timeout" => self.bpip.timeout,
|
||||
|
||||
// Add offset configuration registers
|
||||
@@ -318,13 +352,23 @@ impl RegisterMap {
|
||||
|
||||
pub fn write_reg(&mut self, addr: u64, value: u32) -> RegisterEvent {
|
||||
let register_name = self.get_register_name(addr);
|
||||
let cur_reg = self
|
||||
.regmap
|
||||
.register()
|
||||
.get(register_name)
|
||||
.expect("Unknown register, check regmap definition");
|
||||
let hash_val = cur_reg.as_field(value);
|
||||
|
||||
match register_name {
|
||||
"bsk_avail::avail" => {
|
||||
self.bsk.avail.store((value & 0x1) == 0x1, Ordering::SeqCst);
|
||||
let avail = hash_val.get("avail");
|
||||
|
||||
self.bsk.avail.store(avail == Some(&0x1), Ordering::SeqCst);
|
||||
RegisterEvent::None
|
||||
}
|
||||
"bsk_avail::reset" => {
|
||||
if (value & 0x1) == 0x1 {
|
||||
let req = hash_val.get("request");
|
||||
if req == Some(&0x1) {
|
||||
self.bsk.rst_pdg.store(true, Ordering::SeqCst);
|
||||
self.bsk.avail.store(false, Ordering::SeqCst);
|
||||
RegisterEvent::KeyReset
|
||||
@@ -333,11 +377,14 @@ impl RegisterMap {
|
||||
}
|
||||
}
|
||||
"ksk_avail::avail" => {
|
||||
self.ksk.avail.store((value & 0x1) == 0x1, Ordering::SeqCst);
|
||||
let avail = hash_val.get("avail");
|
||||
|
||||
self.ksk.avail.store(avail == Some(&0x1), Ordering::SeqCst);
|
||||
RegisterEvent::None
|
||||
}
|
||||
"ksk_avail::reset" => {
|
||||
if (value & 0x1) == 0x1 {
|
||||
let req = hash_val.get("request");
|
||||
if req == Some(&0x1) {
|
||||
self.ksk.rst_pdg.store(true, Ordering::SeqCst);
|
||||
self.ksk.avail.store(false, Ordering::SeqCst);
|
||||
RegisterEvent::KeyReset
|
||||
@@ -348,8 +395,10 @@ impl RegisterMap {
|
||||
|
||||
// Bpip configuration registers
|
||||
"bpip::use" => {
|
||||
self.bpip.used = (value & 0x1) == 0x1;
|
||||
self.bpip.use_opportunism = (value & 0x2) == 0x2;
|
||||
let use_bpip = hash_val.get("use_bpip");
|
||||
let use_opportunism = hash_val.get("use_opportunism");
|
||||
self.bpip.used = use_bpip == Some(&0x1);
|
||||
self.bpip.use_opportunism = use_opportunism == Some(&0x1);
|
||||
RegisterEvent::None
|
||||
}
|
||||
"bpip::timeout" => {
|
||||
|
||||
Reference in New Issue
Block a user