mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
3 Commits
feat/princ
...
hw-team/pg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc6509714c | ||
|
|
a2bf556f4c | ||
|
|
77131ea678 |
@@ -22,6 +22,7 @@ cxx-build = "1.0"
|
||||
|
||||
[dependencies]
|
||||
cxx = "1.0"
|
||||
libc = "0.2"
|
||||
hw_regmap = "0.2.1"
|
||||
|
||||
strum = { version = "0.26.2", features = ["derive"] }
|
||||
@@ -52,7 +53,7 @@ ipc-channel = "0.18.3"
|
||||
num-traits = { version = "0.2", optional = true }
|
||||
clap = { version = "4.4.4", features = ["derive"], optional = true }
|
||||
clap-num = { version = "1.1.1", optional = true }
|
||||
nix = { version = "0.29.0", features = ["ioctl", "uio", "fs"] }
|
||||
nix = { version = "0.29.0", features = ["mman", "ioctl", "uio", "fs"] }
|
||||
|
||||
# Dependencies used for rtl_graph features
|
||||
dot2 = { version = "1.0", optional = true }
|
||||
|
||||
@@ -228,6 +228,11 @@ impl HpuHw {
|
||||
pub fn iop_ack_rd(&mut self) -> u32 {
|
||||
self.0.ami.iop_ackq_rd()
|
||||
}
|
||||
|
||||
/// Maps the device register BAR by delegating to the wrapped AMI driver.
///
/// Only compiled when the `hw-v80` feature is enabled.
///
/// # Errors
/// Forwards, unchanged, any error returned by the driver's `map_bar_reg`.
#[cfg(feature = "hw-v80")]
pub fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    let ami = &mut self.0.ami;
    ami.map_bar_reg()
}
|
||||
}
|
||||
|
||||
pub struct MemZone(
|
||||
|
||||
@@ -3,14 +3,19 @@
|
||||
//! AMI driver is used to issue gcq commands to the RPU
|
||||
//! Those commands are used for configuration and register R/W
|
||||
use lazy_static::lazy_static;
|
||||
use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
|
||||
use std::error::Error;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Read};
|
||||
use std::num::NonZero;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
use std::ptr::NonNull;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
const AMI_VERSION_FILE: &str = "/sys/module/ami/version";
|
||||
const AMI_VERSION_PATTERN: &str = r"3\.1\.\d+-zama";
|
||||
const AMI_VERSION_PATTERN: &str = r"3\.2\.\d+-zama";
|
||||
|
||||
const AMI_ID_FILE: &str = "/sys/bus/pci/drivers/ami/devices";
|
||||
const AMI_ID_PATTERN: &str = r"(?<bus>[[:xdigit:]]{2}):(?<dev>[[:xdigit:]]{2})\.(?<func>[[:xdigit:]])\s(?<devn>\d+)\s(?<hwmon>\d+)";
|
||||
@@ -78,7 +83,8 @@ impl AmiInfo {
|
||||
|
||||
pub struct AmiDriver {
|
||||
ami_dev: File,
|
||||
ami_info: AmiInfo,
|
||||
bar_reg_ptr: Option<NonNull<u8>>,
|
||||
iop_ack_atomic_ptr: NonNull<AtomicU32>,
|
||||
retry_rate: Duration,
|
||||
}
|
||||
|
||||
@@ -97,15 +103,67 @@ impl AmiDriver {
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(false)
|
||||
.custom_flags(libc::O_SYNC)
|
||||
.open(ami_path)?;
|
||||
|
||||
let ami_proc_path = format!("/proc/ami_iop_ack_{}", ami_info.devn);
|
||||
let ami_proc = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(false)
|
||||
.open(&ami_proc_path)
|
||||
.unwrap();
|
||||
|
||||
let addr = unsafe {
|
||||
mmap(
|
||||
None,
|
||||
NonZero::new(4096 as usize).unwrap(),
|
||||
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
|
||||
MapFlags::MAP_SHARED,
|
||||
&ami_proc,
|
||||
0,
|
||||
)?
|
||||
};
|
||||
|
||||
let iop_ack_atomic_ptr: NonNull<AtomicU32> = addr.cast();
|
||||
|
||||
Ok(Self {
|
||||
ami_dev,
|
||||
ami_info,
|
||||
bar_reg_ptr: None,
|
||||
iop_ack_atomic_ptr,
|
||||
retry_rate,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let length: usize = 0x140000;
|
||||
|
||||
let map_addr = unsafe {
|
||||
mmap(
|
||||
None,
|
||||
NonZero::new(length).unwrap(),
|
||||
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, // Read & Write
|
||||
MapFlags::MAP_SHARED,
|
||||
&self.ami_dev,
|
||||
0, // Offset in BAR0
|
||||
)?
|
||||
};
|
||||
tracing::info!("mapping HPU BAR0 at address -> {:p}", map_addr);
|
||||
|
||||
let bar_addr: NonNull<u8> = map_addr.cast();
|
||||
self.bar_reg_ptr = Some(bar_addr);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn munmap_cnt(&self) -> Result<(), Box<dyn Error>> {
|
||||
let cnt_addr = self.iop_ack_atomic_ptr.cast();
|
||||
unsafe {
|
||||
munmap(cnt_addr, 4096)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read currently loaded UUID in BAR
|
||||
pub fn uuid(&self) -> String {
|
||||
let ami_fd = self.ami_dev.as_raw_fd();
|
||||
@@ -234,24 +292,32 @@ impl AmiDriver {
|
||||
let data = Box::<u32>::new(0xdeadc0de);
|
||||
let data_ptr = Box::into_raw(data);
|
||||
|
||||
// Populate payload
|
||||
let payload = AmiPeakPokePayload {
|
||||
data_ptr,
|
||||
len: 0x1,
|
||||
offset: addr as u32,
|
||||
};
|
||||
if let Some(base) = self.bar_reg_ptr {
|
||||
unsafe {
|
||||
let raw_base = base.as_ptr();
|
||||
let reg_ptr = raw_base.add((addr + 0x100000).try_into().unwrap()) as *const u32;
|
||||
*data_ptr = std::ptr::read_volatile(reg_ptr);
|
||||
}
|
||||
} else {
|
||||
// Populate payload
|
||||
let payload = AmiPeakPokePayload {
|
||||
data_ptr,
|
||||
len: 0x1,
|
||||
offset: addr as u32,
|
||||
};
|
||||
|
||||
tracing::trace!("AMI: Read request with following payload {payload:x?}");
|
||||
loop {
|
||||
let ret = unsafe { ami_peak(ami_fd, &payload) };
|
||||
match ret {
|
||||
Err(err) => {
|
||||
tracing::debug!("AMI: Read failed -> {err:?}");
|
||||
std::thread::sleep(self.retry_rate);
|
||||
}
|
||||
Ok(val) => {
|
||||
tracing::trace!("AMI: Read ack received {payload:x?} -> {val:?}");
|
||||
break;
|
||||
tracing::trace!("AMI: Read request with following payload {payload:x?}");
|
||||
loop {
|
||||
let ret = unsafe { ami_peak(ami_fd, &payload) };
|
||||
match ret {
|
||||
Err(err) => {
|
||||
tracing::debug!("AMI: Read failed -> {err:?}");
|
||||
std::thread::sleep(self.retry_rate);
|
||||
}
|
||||
Ok(val) => {
|
||||
tracing::trace!("AMI: Read ack received {payload:x?} -> {val:?}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -265,24 +331,32 @@ impl AmiDriver {
|
||||
let data = Box::<u32>::new(value);
|
||||
let data_ptr = Box::into_raw(data);
|
||||
|
||||
// Populate payload
|
||||
let payload = AmiPeakPokePayload {
|
||||
data_ptr,
|
||||
len: 0x1,
|
||||
offset: addr as u32,
|
||||
};
|
||||
if let Some(base) = self.bar_reg_ptr {
|
||||
unsafe {
|
||||
let raw_base = base.as_ptr();
|
||||
let reg_ptr = raw_base.add((addr + 0x100000).try_into().unwrap()) as *mut u32;
|
||||
std::ptr::write_volatile(reg_ptr, value);
|
||||
}
|
||||
} else {
|
||||
// Populate payload
|
||||
let payload = AmiPeakPokePayload {
|
||||
data_ptr,
|
||||
len: 0x1,
|
||||
offset: addr as u32,
|
||||
};
|
||||
|
||||
tracing::trace!("AMI: Write request with following payload {payload:x?}");
|
||||
loop {
|
||||
let ret = unsafe { ami_poke(ami_fd, &payload) };
|
||||
match ret {
|
||||
Err(err) => {
|
||||
tracing::debug!("AMI: Write failed -> {err:?}");
|
||||
std::thread::sleep(self.retry_rate);
|
||||
}
|
||||
Ok(val) => {
|
||||
tracing::trace!("AMI: Write ack received {payload:x?} -> {val:?}");
|
||||
break;
|
||||
tracing::trace!("AMI: Write request with following payload {payload:x?}");
|
||||
loop {
|
||||
let ret = unsafe { ami_poke(ami_fd, &payload) };
|
||||
match ret {
|
||||
Err(err) => {
|
||||
tracing::debug!("AMI: Write failed -> {err:?}");
|
||||
std::thread::sleep(self.retry_rate);
|
||||
}
|
||||
Ok(val) => {
|
||||
tracing::trace!("AMI: Write ack received {payload:x?} -> {val:?}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -359,32 +433,9 @@ impl AmiDriver {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO ugly quick patch
|
||||
// Clean this when driver interface is specified
|
||||
// read shared atomic counter of iop acknowledge
|
||||
pub fn iop_ackq_rd(&self) -> u32 {
|
||||
let ami_devn = self.ami_info.devn;
|
||||
let ami_proc_path = format!("/proc/ami_iop_ack_{}", ami_devn);
|
||||
let mut iop_ack_f = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(false)
|
||||
.open(&ami_proc_path)
|
||||
.unwrap();
|
||||
|
||||
// Read a line and extract a 32b integer
|
||||
let mut ack_str = String::new();
|
||||
iop_ack_f.read_to_string(&mut ack_str).unwrap();
|
||||
if ack_str.is_empty() {
|
||||
0
|
||||
} else {
|
||||
let ack_nb = ack_str
|
||||
.as_str()
|
||||
.lines()
|
||||
.map(|line| line.trim_ascii().parse::<u32>().unwrap())
|
||||
.sum();
|
||||
tracing::trace!("Get value {ack_str} from {ami_proc_path} => {ack_nb}",);
|
||||
ack_nb
|
||||
}
|
||||
unsafe { self.iop_ack_atomic_ptr.as_ref().swap(0, Ordering::SeqCst) }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ impl HpuHw {
|
||||
tracing::info!("Current pdi -> [\n{uuid}]");
|
||||
Ok(hw)
|
||||
} else {
|
||||
hw.ami.munmap_cnt().unwrap();
|
||||
Err(format!(
|
||||
"UUID mismatch loaded {:?} expected {:?}",
|
||||
current_uuid,
|
||||
|
||||
@@ -24,6 +24,10 @@ pub struct HpuBackend {
|
||||
|
||||
// Extracted parameters
|
||||
pub(crate) params: HpuParameters,
|
||||
#[cfg(feature = "hw-v80")]
|
||||
hpu_version_major: u32,
|
||||
#[cfg(feature = "hw-v80")]
|
||||
hpu_version_minor: u32,
|
||||
// Avoid parsing the regmap at each polling iteration
|
||||
#[cfg(not(feature = "hw-v80"))]
|
||||
workq_addr: u64,
|
||||
@@ -109,6 +113,20 @@ impl HpuBackend {
|
||||
let regmap = hw_regmap::FlatRegmap::from_file(&regmap_str);
|
||||
let mut params = HpuParameters::from_rtl(&mut hpu_hw, &regmap);
|
||||
|
||||
#[cfg(feature = "hw-v80")]
|
||||
let (hpu_version_major, hpu_version_minor) = {
|
||||
let version_reg = regmap
|
||||
.register()
|
||||
.get("info::version")
|
||||
.expect("Unknown register, check regmap definition");
|
||||
let hpu_version_val = hpu_hw.read_reg(*version_reg.offset() as u64);
|
||||
let hpu_version_fields = version_reg.as_field(hpu_version_val);
|
||||
(
|
||||
*hpu_version_fields.get("major").expect("Unknown field"),
|
||||
*hpu_version_fields.get("minor").expect("Unknown field"),
|
||||
)
|
||||
};
|
||||
|
||||
// In case this is not filled by from_rtl()
|
||||
if params.ntt_params.min_pbs_nb.is_none() {
|
||||
params.ntt_params.min_pbs_nb = Some(config.firmware.min_batch_size);
|
||||
@@ -282,6 +300,10 @@ impl HpuBackend {
|
||||
hpu_hw,
|
||||
regmap,
|
||||
params,
|
||||
#[cfg(feature = "hw-v80")]
|
||||
hpu_version_major,
|
||||
#[cfg(feature = "hw-v80")]
|
||||
hpu_version_minor,
|
||||
#[cfg(not(feature = "hw-v80"))]
|
||||
workq_addr,
|
||||
#[cfg(not(feature = "hw-v80"))]
|
||||
@@ -916,6 +938,16 @@ impl HpuBackend {
|
||||
while self.poll_ack_q()? {}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the HPU version as a `(major, minor)` pair.
///
/// Both values are the ones stored in the backend; nothing is read from the
/// hardware here.
#[cfg(feature = "hw-v80")]
pub(crate) fn get_hpu_version(&self) -> (u32, u32) {
    let major = self.hpu_version_major;
    let minor = self.hpu_version_minor;
    (major, minor)
}
|
||||
|
||||
/// Maps the device register BAR by delegating to the underlying `hpu_hw`.
///
/// # Errors
/// Forwards, unchanged, any error returned by `hpu_hw.map_bar_reg()`.
#[cfg(feature = "hw-v80")]
pub(crate) fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    let hw = &mut self.hpu_hw;
    hw.map_bar_reg()
}
|
||||
}
|
||||
|
||||
impl Drop for HpuBackend {
|
||||
|
||||
@@ -94,6 +94,18 @@ impl HpuDevice {
|
||||
) where
|
||||
F: Fn(HpuParameters, &crate::asm::Pbs) -> HpuGlweLookuptableOwned<u64>,
|
||||
{
|
||||
// print HPU version
|
||||
#[cfg(feature = "hw-v80")]
|
||||
{
|
||||
let mut backend = self.backend.lock().unwrap();
|
||||
let (major, minor) = backend.get_hpu_version();
|
||||
tracing::info!("HPU version -> {}.{}", major, minor);
|
||||
|
||||
if major >= 2 && minor >= 3 {
|
||||
backend.map_bar_reg().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Properly reset keys
|
||||
self.bsk_unset();
|
||||
self.ksk_unset();
|
||||
|
||||
Reference in New Issue
Block a user