mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(hpu): Now the mockup takes into account the field position from the regmap toml to generate its register read and write answers.
This commit is contained in:
@@ -148,10 +148,10 @@ pub struct HpuPcParameters {
|
|||||||
pub ksk_pc: usize,
|
pub ksk_pc: usize,
|
||||||
pub ksk_bytes_w: usize,
|
pub ksk_bytes_w: usize,
|
||||||
pub bsk_pc: usize,
|
pub bsk_pc: usize,
|
||||||
|
pub glwe_pc: usize, // Currently hardcoded to 1
|
||||||
pub bsk_bytes_w: usize,
|
pub bsk_bytes_w: usize,
|
||||||
pub pem_pc: usize,
|
pub pem_pc: usize,
|
||||||
pub pem_bytes_w: usize,
|
pub pem_bytes_w: usize,
|
||||||
// pub glwe_pc: usize, // Currently hardcoded to 1
|
|
||||||
pub glwe_bytes_w: usize,
|
pub glwe_bytes_w: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -186,6 +186,7 @@ impl FromRtl for HpuPcParameters {
|
|||||||
let ksk_pc = *hbm_pc_fields.get("ksk_pc").expect("Unknown field") as usize;
|
let ksk_pc = *hbm_pc_fields.get("ksk_pc").expect("Unknown field") as usize;
|
||||||
let bsk_pc = *hbm_pc_fields.get("bsk_pc").expect("Unknown field") as usize;
|
let bsk_pc = *hbm_pc_fields.get("bsk_pc").expect("Unknown field") as usize;
|
||||||
let pem_pc = *hbm_pc_fields.get("pem_pc").expect("Unknown field") as usize;
|
let pem_pc = *hbm_pc_fields.get("pem_pc").expect("Unknown field") as usize;
|
||||||
|
let glwe_pc = *hbm_pc_fields.get("glwe_pc").expect("Unknown field") as usize;
|
||||||
|
|
||||||
// Extract bus width for each channel
|
// Extract bus width for each channel
|
||||||
let ksk_bytes_w = {
|
let ksk_bytes_w = {
|
||||||
@@ -229,6 +230,7 @@ impl FromRtl for HpuPcParameters {
|
|||||||
ksk_pc,
|
ksk_pc,
|
||||||
bsk_pc,
|
bsk_pc,
|
||||||
pem_pc,
|
pem_pc,
|
||||||
|
glwe_pc,
|
||||||
ksk_bytes_w,
|
ksk_bytes_w,
|
||||||
bsk_bytes_w,
|
bsk_bytes_w,
|
||||||
pem_bytes_w,
|
pem_bytes_w,
|
||||||
|
|||||||
@@ -37,6 +37,7 @@
|
|||||||
bsk_bytes_w= 64
|
bsk_bytes_w= 64
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 64
|
pem_bytes_w= 64
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 64
|
glwe_bytes_w= 64
|
||||||
[regf_params]
|
[regf_params]
|
||||||
reg_nb= 64
|
reg_nb= 64
|
||||||
|
|||||||
@@ -37,6 +37,7 @@
|
|||||||
bsk_bytes_w= 64
|
bsk_bytes_w= 64
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 64
|
pem_bytes_w= 64
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 64
|
glwe_bytes_w= 64
|
||||||
[regf_params]
|
[regf_params]
|
||||||
reg_nb= 64
|
reg_nb= 64
|
||||||
|
|||||||
@@ -39,6 +39,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
[regf_params]
|
[regf_params]
|
||||||
|
|||||||
@@ -37,6 +37,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
[regf_params]
|
[regf_params]
|
||||||
reg_nb= 64
|
reg_nb= 64
|
||||||
|
|||||||
@@ -39,6 +39,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
[regf_params]
|
[regf_params]
|
||||||
|
|||||||
@@ -40,6 +40,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
[regf_params]
|
[regf_params]
|
||||||
|
|||||||
@@ -37,6 +37,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
[regf_params]
|
[regf_params]
|
||||||
reg_nb= 64
|
reg_nb= 64
|
||||||
|
|||||||
@@ -40,6 +40,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
[regf_params]
|
[regf_params]
|
||||||
|
|||||||
@@ -40,6 +40,7 @@
|
|||||||
bsk_bytes_w= 32
|
bsk_bytes_w= 32
|
||||||
pem_pc= 2
|
pem_pc= 2
|
||||||
pem_bytes_w= 32
|
pem_bytes_w= 32
|
||||||
|
glwe_pc= 1
|
||||||
glwe_bytes_w= 32
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
[regf_params]
|
[regf_params]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use std::collections::VecDeque;
|
use std::collections::{HashMap, VecDeque};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -101,10 +101,21 @@ impl RegisterMap {
|
|||||||
/// Return register value from parameter value
|
/// Return register value from parameter value
|
||||||
pub fn read_reg(&mut self, addr: u64) -> u32 {
|
pub fn read_reg(&mut self, addr: u64) -> u32 {
|
||||||
let register_name = self.get_register_name(addr);
|
let register_name = self.get_register_name(addr);
|
||||||
|
let cur_reg = self
|
||||||
|
.regmap
|
||||||
|
.register()
|
||||||
|
.get(register_name)
|
||||||
|
.expect("Unknown register, check regmap definition");
|
||||||
|
|
||||||
match register_name {
|
match register_name {
|
||||||
"info::ntt_structure" => {
|
"info::ntt_structure" => {
|
||||||
let ntt_p = &self.rtl_params.ntt_params;
|
let ntt_p = &self.rtl_params.ntt_params;
|
||||||
(ntt_p.radix + (ntt_p.psi << 8) /*+(ntt_p.div << 16)*/ + (ntt_p.delta << 24)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("radix", ntt_p.radix as u32),
|
||||||
|
("psi", ntt_p.psi as u32),
|
||||||
|
("delta", ntt_p.delta as u32),
|
||||||
|
//("div", ntt_p.div as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::ntt_rdx_cut" => {
|
"info::ntt_rdx_cut" => {
|
||||||
let ntt_p = &self.rtl_params.ntt_params;
|
let ntt_p = &self.rtl_params.ntt_params;
|
||||||
@@ -112,10 +123,11 @@ impl RegisterMap {
|
|||||||
HpuNttCoreArch::GF64(cut_w) => cut_w,
|
HpuNttCoreArch::GF64(cut_w) => cut_w,
|
||||||
_ => &vec![ntt_p.delta as u8],
|
_ => &vec![ntt_p.delta as u8],
|
||||||
};
|
};
|
||||||
cut_w
|
let mut hash_cut: HashMap<&str, u32> = HashMap::new();
|
||||||
.iter()
|
for (f, val) in cur_reg.field().iter().zip(cut_w) {
|
||||||
.enumerate()
|
hash_cut.insert(f.name(), *val as u32);
|
||||||
.fold(0, |acc, (id, val)| acc + ((*val as u32) << (id * 4)))
|
}
|
||||||
|
cur_reg.from_field(hash_cut)
|
||||||
}
|
}
|
||||||
"info::ntt_architecture" => match self.rtl_params.ntt_params.core_arch {
|
"info::ntt_architecture" => match self.rtl_params.ntt_params.core_arch {
|
||||||
HpuNttCoreArch::WmmCompactPcg => NTT_CORE_ARCH_OFS + 4,
|
HpuNttCoreArch::WmmCompactPcg => NTT_CORE_ARCH_OFS + 4,
|
||||||
@@ -124,7 +136,10 @@ impl RegisterMap {
|
|||||||
},
|
},
|
||||||
"info::ntt_pbs" => {
|
"info::ntt_pbs" => {
|
||||||
let ntt_p = &self.rtl_params.ntt_params;
|
let ntt_p = &self.rtl_params.ntt_params;
|
||||||
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("batch_pbs_nb", ntt_p.batch_pbs_nb as u32),
|
||||||
|
("total_pbs_nb", ntt_p.total_pbs_nb as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::ntt_modulo" => {
|
"info::ntt_modulo" => {
|
||||||
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
|
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
|
||||||
@@ -159,17 +174,29 @@ impl RegisterMap {
|
|||||||
}
|
}
|
||||||
"info::ks_structure" => {
|
"info::ks_structure" => {
|
||||||
let ks_p = &self.rtl_params.ks_params;
|
let ks_p = &self.rtl_params.ks_params;
|
||||||
(ks_p.lbx + (ks_p.lby << 8) + (ks_p.lbz << 16)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("x", ks_p.lbx as u32),
|
||||||
|
("y", ks_p.lby as u32),
|
||||||
|
("z", ks_p.lbz as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::ks_crypto_param" => {
|
"info::ks_crypto_param" => {
|
||||||
let ks_p = &self.rtl_params.ks_params;
|
let ks_p = &self.rtl_params.ks_params;
|
||||||
let pbs_p = &self.rtl_params.pbs_params;
|
let pbs_p = &self.rtl_params.pbs_params;
|
||||||
(ks_p.width + (pbs_p.ks_level << 8) + (pbs_p.ks_base_log << 16)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("mod_ksk_w", ks_p.width as u32),
|
||||||
|
("ks_l", pbs_p.ks_level as u32),
|
||||||
|
("ks_b", pbs_p.ks_base_log as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::hbm_axi4_nb" => {
|
"info::hbm_axi4_nb" => {
|
||||||
let pc_p = &self.rtl_params.pc_params;
|
let pc_p = &self.rtl_params.pc_params;
|
||||||
// TODO: Cut number currently not reverted
|
cur_reg.from_field(HashMap::from([
|
||||||
(pc_p.bsk_pc + (pc_p.ksk_pc << 8) + (pc_p.pem_pc << 16)) as u32
|
("ksk_pc", pc_p.ksk_pc as u32),
|
||||||
|
("bsk_pc", pc_p.bsk_pc as u32),
|
||||||
|
("pem_pc", pc_p.pem_pc as u32),
|
||||||
|
("glwe_pc", pc_p.glwe_pc as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::hbm_axi4_dataw_ksk" => {
|
"info::hbm_axi4_dataw_ksk" => {
|
||||||
let bytes_w = &self.rtl_params.pc_params.ksk_bytes_w;
|
let bytes_w = &self.rtl_params.pc_params.ksk_bytes_w;
|
||||||
@@ -190,11 +217,17 @@ impl RegisterMap {
|
|||||||
|
|
||||||
"info::regf_structure" => {
|
"info::regf_structure" => {
|
||||||
let regf_p = &self.rtl_params.regf_params;
|
let regf_p = &self.rtl_params.regf_params;
|
||||||
(regf_p.reg_nb + (regf_p.coef_nb << 8)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("reg_nb", regf_p.reg_nb as u32),
|
||||||
|
("coef_nb", regf_p.coef_nb as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
"info::isc_structure" => {
|
"info::isc_structure" => {
|
||||||
let isc_p = &self.rtl_params.isc_params;
|
let isc_p = &self.rtl_params.isc_params;
|
||||||
(isc_p.depth + (isc_p.min_iop_size << 8)) as u32
|
cur_reg.from_field(HashMap::from([
|
||||||
|
("depth", isc_p.depth as u32),
|
||||||
|
("min_iop_size", isc_p.min_iop_size as u32),
|
||||||
|
]))
|
||||||
}
|
}
|
||||||
|
|
||||||
"bsk_avail::avail" => self.bsk.avail.load(Ordering::SeqCst) as u32,
|
"bsk_avail::avail" => self.bsk.avail.load(Ordering::SeqCst) as u32,
|
||||||
@@ -217,9 +250,10 @@ impl RegisterMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Bpip configuration registers
|
// Bpip configuration registers
|
||||||
"bpip::use" => {
|
"bpip::use" => cur_reg.from_field(HashMap::from([
|
||||||
((self.bpip.used as u8) + ((self.bpip.use_opportunism as u8) << 1)) as u32
|
("use_bpip", self.bpip.used as u32),
|
||||||
}
|
("use_opportunism", self.bpip.use_opportunism as u32),
|
||||||
|
])),
|
||||||
"bpip::timeout" => self.bpip.timeout,
|
"bpip::timeout" => self.bpip.timeout,
|
||||||
|
|
||||||
// Add offset configuration registers
|
// Add offset configuration registers
|
||||||
@@ -318,13 +352,23 @@ impl RegisterMap {
|
|||||||
|
|
||||||
pub fn write_reg(&mut self, addr: u64, value: u32) -> RegisterEvent {
|
pub fn write_reg(&mut self, addr: u64, value: u32) -> RegisterEvent {
|
||||||
let register_name = self.get_register_name(addr);
|
let register_name = self.get_register_name(addr);
|
||||||
|
let cur_reg = self
|
||||||
|
.regmap
|
||||||
|
.register()
|
||||||
|
.get(register_name)
|
||||||
|
.expect("Unknown register, check regmap definition");
|
||||||
|
let hash_val = cur_reg.as_field(value);
|
||||||
|
|
||||||
match register_name {
|
match register_name {
|
||||||
"bsk_avail::avail" => {
|
"bsk_avail::avail" => {
|
||||||
self.bsk.avail.store((value & 0x1) == 0x1, Ordering::SeqCst);
|
let avail = hash_val.get("avail");
|
||||||
|
|
||||||
|
self.bsk.avail.store(avail == Some(&0x1), Ordering::SeqCst);
|
||||||
RegisterEvent::None
|
RegisterEvent::None
|
||||||
}
|
}
|
||||||
"bsk_avail::reset" => {
|
"bsk_avail::reset" => {
|
||||||
if (value & 0x1) == 0x1 {
|
let req = hash_val.get("request");
|
||||||
|
if req == Some(&0x1) {
|
||||||
self.bsk.rst_pdg.store(true, Ordering::SeqCst);
|
self.bsk.rst_pdg.store(true, Ordering::SeqCst);
|
||||||
self.bsk.avail.store(false, Ordering::SeqCst);
|
self.bsk.avail.store(false, Ordering::SeqCst);
|
||||||
RegisterEvent::KeyReset
|
RegisterEvent::KeyReset
|
||||||
@@ -333,11 +377,14 @@ impl RegisterMap {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"ksk_avail::avail" => {
|
"ksk_avail::avail" => {
|
||||||
self.ksk.avail.store((value & 0x1) == 0x1, Ordering::SeqCst);
|
let avail = hash_val.get("avail");
|
||||||
|
|
||||||
|
self.ksk.avail.store(avail == Some(&0x1), Ordering::SeqCst);
|
||||||
RegisterEvent::None
|
RegisterEvent::None
|
||||||
}
|
}
|
||||||
"ksk_avail::reset" => {
|
"ksk_avail::reset" => {
|
||||||
if (value & 0x1) == 0x1 {
|
let req = hash_val.get("request");
|
||||||
|
if req == Some(&0x1) {
|
||||||
self.ksk.rst_pdg.store(true, Ordering::SeqCst);
|
self.ksk.rst_pdg.store(true, Ordering::SeqCst);
|
||||||
self.ksk.avail.store(false, Ordering::SeqCst);
|
self.ksk.avail.store(false, Ordering::SeqCst);
|
||||||
RegisterEvent::KeyReset
|
RegisterEvent::KeyReset
|
||||||
@@ -348,8 +395,10 @@ impl RegisterMap {
|
|||||||
|
|
||||||
// Bpip configuration registers
|
// Bpip configuration registers
|
||||||
"bpip::use" => {
|
"bpip::use" => {
|
||||||
self.bpip.used = (value & 0x1) == 0x1;
|
let use_bpip = hash_val.get("use_bpip");
|
||||||
self.bpip.use_opportunism = (value & 0x2) == 0x2;
|
let use_opportunism = hash_val.get("use_opportunism");
|
||||||
|
self.bpip.used = use_bpip == Some(&0x1);
|
||||||
|
self.bpip.use_opportunism = use_opportunism == Some(&0x1);
|
||||||
RegisterEvent::None
|
RegisterEvent::None
|
||||||
}
|
}
|
||||||
"bpip::timeout" => {
|
"bpip::timeout" => {
|
||||||
|
|||||||
Reference in New Issue
Block a user