mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
upstream remu (#9921)
This commit is contained in:
9
.github/actions/setup-tinygrad/action.yml
vendored
9
.github/actions/setup-tinygrad/action.yml
vendored
@@ -121,9 +121,8 @@ runs:
|
||||
echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600
|
||||
sudo apt update || true
|
||||
sudo apt install --no-install-recommends --allow-unauthenticated -y hsa-rocr comgr hsa-rocr-dev liburing-dev libc6-dev
|
||||
curl -s https://api.github.com/repos/Qazalin/remu/releases/latest | \
|
||||
jq -r '.assets[] | select(.name == "libremu.so").browser_download_url' | \
|
||||
sudo xargs curl -L -o /usr/local/lib/libremu.so
|
||||
cargo build --release --manifest-path ./extra/remu/Cargo.toml
|
||||
sudo ln -sf ${{ github.workspace }}/extra/remu/target/release/libremu.so /usr/local/lib/libremu.so
|
||||
sudo tee --append /etc/ld.so.conf.d/rocm.conf <<'EOF'
|
||||
/opt/rocm/lib
|
||||
/opt/rocm/lib64
|
||||
@@ -137,9 +136,7 @@ runs:
|
||||
curl -s -H "Authorization: token $GH_TOKEN" curl -s https://api.github.com/repos/nimlgen/amdcomgr_dylib/releases/latest | \
|
||||
jq -r '.assets[] | select(.name == "libamd_comgr.dylib").browser_download_url' | \
|
||||
sudo xargs curl -L -o /usr/local/lib/libamd_comgr.dylib
|
||||
curl -s -H "Authorization: token $GH_TOKEN" curl -s https://api.github.com/repos/Qazalin/remu/releases/latest | \
|
||||
jq -r '.assets[] | select(.name == "libremu.dylib").browser_download_url' | \
|
||||
sudo xargs curl -L -o /usr/local/lib/libremu.dylib
|
||||
cargo build --release --manifest-path ./extra/remu/Cargo.toml
|
||||
|
||||
# **** CUDA ****
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -60,3 +60,4 @@ comgr_*
|
||||
site/
|
||||
profile_stats
|
||||
*.log
|
||||
target
|
||||
|
||||
66
extra/remu/Cargo.lock
generated
Normal file
66
extra/remu/Cargo.lock
generated
Normal file
@@ -0,0 +1,66 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "float-cmp"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "remu"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"float-cmp",
|
||||
"half",
|
||||
"num-traits",
|
||||
]
|
||||
15
extra/remu/Cargo.toml
Normal file
15
extra/remu/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "remu"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.80.0"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
num-traits = "0.2.17"
|
||||
|
||||
[dev-dependencies]
|
||||
float-cmp = "0.9.0"
|
||||
80
extra/remu/README.md
Normal file
80
extra/remu/README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
## Intro
|
||||
|
||||
Remu is an RDNA3 emulator built to test correctness of RDNA3 code. It is used in [tinygrad's AMD CI](https://github.com/tinygrad/tinygrad).
|
||||
|
||||
Most of the common instructions are implemented, but some formats like IMG are not supported.
|
||||
|
||||
Remu is only for testing the correctness of program output; it is not a cycle-accurate simulator.
|
||||
|
||||
## Build Locally
|
||||
|
||||
Remu is written in Rust. Make sure you have [Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html).
|
||||
|
||||
To build the project, run:
|
||||
|
||||
```bash
|
||||
cargo build --release --manifest-path ./extra/remu/Cargo.toml
|
||||
```
|
||||
|
||||
This will produce a binary in the `extra/remu/target/release` directory.
|
||||
|
||||
## Usage with tinygrad
|
||||
|
||||
The latest binaries are released in https://github.com/Qazalin/remu/releases. Alternatively, you can [build locally](#build-locally).
|
||||
|
||||
Tinygrad does not yet output RDNA3 kernels directly. You can either install comgr or use `AMD_LLVM=1` if you have [LLVM@19](https://github.com/tinygrad/tinygrad/blob/e2ed673c946c8f1774d816c75e52a994c2dd8a88/.github/actions/setup-tinygrad/action.yml#L208).
|
||||
|
||||
`PYTHONPATH="." MOCKGPU=1 AMD=1 python test/test_tiny.py TestTiny.test_plus` runs an emulated RDNA3 kernel with Remu.
|
||||
|
||||
Add `DEBUG=6` to see Remu's logs.
|
||||
|
||||
### DEBUG output
|
||||
|
||||
Remu runs each thread one at a time in a nested for loop, see lib.rs. The DEBUG output prints information about the current thread.
|
||||
|
||||
The DEBUG output has 3 sections:
|
||||
|
||||
```
|
||||
<------------ 1 ----------> <--- 2 ---> <--------------------------------------- 3 ------------------------------------------>
|
||||
[0 0 0 ] [0 0 0 ] 0 F4080100 SMEM sbase=0 sdata=4 op=2 offset=0 soffset=0
|
||||
```
|
||||
|
||||
#### Section 1: Grid info
|
||||
|
||||
`[gid.x, gid.y, gid.z], [lid.x, lid.y, lid.z]` of the current thread.
|
||||
|
||||
#### Section 2: Wave info
|
||||
|
||||
`<lane> <instruction hex>`
|
||||
|
||||
RDNA3 divides threads into chunks of 32. Each thread is assigned to a "lane" from 0-31.
|
||||
|
||||
In Remu, even though all threads run one at a time, each 32 thread chunk (a wave) shares state like SGPR, VGPR, LDS, EXEC mask, etc.
|
||||
Remu can simulate up to one wave sync instruction.
|
||||
For more details, see work_group.rs.
|
||||
|
||||
Section 2 can have a green or gray color.
|
||||
|
||||
Green = The thread is actively executing the instruction.
|
||||
|
||||
Gray = The thread has been "turned off" by the EXEC mask, so it skips execution of some instructions. (Refer to "EXECute Mask" on [page 23](https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf#page=23) of the ISA docs for more details.)
|
||||
|
||||
To see the colors in action, try running `DEBUG=6 PYTHONPATH="." MOCKGPU=1 AMD=1 python test/test_ops.py TestOps.test_arange_big`. See how only lane 0 writes to global memory:
|
||||
```
|
||||
[255 0 0 ] [0 0 0 ] 0 DC6A0000 GLOBAL offset=0 op=26 addr=8 data=0 saddr=0 vdst=0
|
||||
[255 0 0 ] [1 0 0 ] 1 DC6A0000
|
||||
[255 0 0 ] [2 0 0 ] 2 DC6A0000
|
||||
[255 0 0 ] [3 0 0 ] 3 DC6A0000
|
||||
[255 0 0 ] [4 0 0 ] 4 DC6A0000
|
||||
```
|
||||
|
||||
#### Section 3: Decoded Instruction
|
||||
|
||||
This prints the instruction type and all the parsed bitfields.
|
||||
|
||||
Remu output vs llvm-objdump:
|
||||
|
||||
```
|
||||
s_load_b64 s[0:1], s[0:1], 0x10 // 00000000160C: F4040000 F8000010
|
||||
SMEM sbase=0 sdata=0 op=1 offset=16 soffset=0
|
||||
```
|
||||
1
extra/remu/rustfmt.toml
Normal file
1
extra/remu/rustfmt.toml
Normal file
@@ -0,0 +1 @@
|
||||
max_width = 150
|
||||
170
extra/remu/src/helpers.rs
Normal file
170
extra/remu/src/helpers.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use half::f16;
|
||||
use num_traits::float::FloatCore;
|
||||
|
||||
/// Returns the bit of `val` at position `pos`, counted from the most significant bit
/// (`pos == 0` selects bit 31, `pos == 31` selects bit 0). The result is `0` or `1`.
pub fn nth(val: u32, pos: usize) -> u32 {
    let shift = 31 - pos as u32;
    (val >> shift) & 1
}
|
||||
pub fn f16_lo(val: u32) -> f16 {
|
||||
f16::from_bits((val & 0xffff) as u16)
|
||||
}
|
||||
pub fn f16_hi(val: u32) -> f16 {
|
||||
f16::from_bits(((val >> 16) & 0xffff) as u16)
|
||||
}
|
||||
|
||||
/// Sign-extends a `bits`-wide field stored in `num` to a 64-bit signed integer.
///
/// Bit `bits - 1` is treated as the sign bit. When it is set, every bit above the field is
/// set as well; when it is clear, `num` is returned unchanged (bits above the field are not
/// cleared). `bits` is expected to be at least 1. A width of 64 or more means there is
/// nothing above the sign bit to fill, so `num` is simply reinterpreted as `i64`.
pub fn sign_ext(num: u64, bits: usize) -> i64 {
    // `!0 << 64` overflows the shift amount (a panic in debug builds), so handle full-width fields up front.
    if bits >= 64 {
        return num as i64;
    }
    let sign_bit_set = (num >> (bits - 1)) & 1 != 0;
    let extended = if sign_bit_set { num | (!0u64 << bits) } else { num };
    extended as i64
}
|
||||
|
||||
/// Access to the raw (biased) exponent field of a floating point value.
pub trait IEEEClass<T> {
    /// Returns the exponent bits exactly as stored, without removing the bias.
    fn exponent(&self) -> T;
}
impl IEEEClass<u32> for f32 {
    /// Extracts the 8 exponent bits (bits 23..=30) of a single precision value.
    fn exponent(&self) -> u32 {
        (self.to_bits() >> 23) & 0xff
    }
}
|
||||
impl IEEEClass<u16> for f16 {
    /// Extracts the 5 exponent bits (bits 10..=14) of a half precision value.
    fn exponent(&self) -> u16 {
        (self.to_bits() >> 10) & 0x1f
    }
}
|
||||
impl IEEEClass<u64> for f64 {
|
||||
fn exponent(&self) -> u64 {
|
||||
(self.to_bits() & 0b0111111111110000000000000000000000000000000000000000000000000000) >> 52
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VOPModifier<T> {
|
||||
fn negate(&self, pos: usize, modifier: usize) -> T;
|
||||
fn absolute(&self, pos: usize, modifier: usize) -> T;
|
||||
}
|
||||
impl<T> VOPModifier<T> for T
|
||||
where
|
||||
T: FloatCore,
|
||||
{
|
||||
fn negate(&self, pos: usize, modifier: usize) -> T {
|
||||
match (modifier >> pos) & 1 {
|
||||
1 => match self.is_zero() {
|
||||
true => *self,
|
||||
false => -*self,
|
||||
},
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
fn absolute(&self, pos: usize, modifier: usize) -> T {
|
||||
match (modifier >> pos) & 1 {
|
||||
1 => self.abs(),
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rebuilds `x` with its 52 fraction bits kept and its exponent forced to `bias - 1`,
/// which places the result in `[0.5, 1.0)`.
///
/// Infinities and NaN are returned unchanged. The sign bit is masked away, and the fraction
/// bits are used verbatim, so zero (all fraction bits clear) yields `0.5`.
pub fn extract_mantissa(x: f64) -> f64 {
    const FRACTION_MASK: u64 = 0x000FFFFFFFFFFFFF;
    const BIAS: u64 = 1023;
    if !x.is_finite() {
        return x;
    }
    let fraction = x.to_bits() & FRACTION_MASK;
    f64::from_bits(fraction | ((BIAS - 1) << 52))
}
|
||||
/// Scales `x` by two raised to the power `exp`, i.e. computes `x * 2^exp`.
pub fn ldexp(x: f64, exp: i32) -> f64 {
    let scale = 2f64.powi(exp);
    x * scale
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_mantissa() {
        let mantissa = extract_mantissa(2.0f64);
        assert_eq!(mantissa, 0.5);
    }

    #[test]
    fn test_normal_exponent() {
        let cases = [(2.5f32, 128), (1.17549435e-38f32, 1), (f32::INFINITY, 255), (f32::NEG_INFINITY, 255)];
        for (input, expected) in cases {
            assert_eq!(input.exponent(), expected);
        }
    }

    #[test]
    fn test_denormal_exponent() {
        // Every value below the smallest normal f32 must report a zero exponent field.
        let denormals = [1.0e-40f32, 1.0e-42f32, 1.0e-44f32, 1.17549435e-38f32 / 2.0];
        for input in denormals {
            assert_eq!(input.exponent(), 0);
        }
    }

    #[test]
    fn test_normal_exponent_f16() {
        let cases = [(f16::from_f32(3.14f32), 16), (f16::NEG_INFINITY, 31), (f16::INFINITY, 31)];
        for (input, expected) in cases {
            assert_eq!(input.exponent(), expected);
        }
    }

    #[test]
    fn test_neg() {
        // (bit position, modifier mask, expected result) for an input of 0.3
        let cases = [(0, 0b001, -0.3_f32), (1, 0b010, -0.3_f32), (2, 0b100, -0.3_f32), (0, 0b110, 0.3_f32), (1, 0b010, -0.3_f32)];
        for (pos, modifier, expected) in cases {
            assert_eq!(0.3_f32.negate(pos, modifier), expected);
        }
        // Zero keeps its exact bit pattern even when the negate bit is set.
        assert_eq!(0.0_f32.negate(0, 0b001).to_bits(), 0);
    }

    #[test]
    fn test_sign_ext() {
        let cases_21_bits = [
            (0b000000000000000101000, 40),
            (0b111111111111111011000, -40),
            (0b000000000000000000000, 0),
            (0b111111111111111111111, -1),
            (0b111000000000000000000, -262144),
            (0b000111111111111111111, 262143),
        ];
        for (raw, expected) in cases_21_bits {
            assert_eq!(sign_ext(raw, 21), expected);
        }
        assert_eq!(sign_ext(7608, 13), -584);
    }
}
|
||||
|
||||
use std::sync::LazyLock;
|
||||
/// `true` when the `DEBUG` environment variable parses as an integer of at least 6.
/// An unset variable or an unparsable value counts as level 0. Evaluated once, on first access.
pub static DEBUG: LazyLock<bool> = LazyLock::new(|| {
    let level = std::env::var("DEBUG").ok().and_then(|raw| raw.parse::<usize>().ok()).unwrap_or(0);
    level >= 6
});
|
||||
|
||||
/// Wraps text in a 24-bit ANSI foreground color escape sequence.
pub trait Colorize {
    /// Returns `self` prefixed with the escape code for `color` and followed by a reset code.
    fn color(self, color: &str) -> String;
}
impl<'a> Colorize for &'a str {
    fn color(self, color: &str) -> String {
        // RGB triplet for each known color name; any other name falls back to white.
        let rgb = match color {
            "blue" => "112;184;255",
            "green" => "39;176;139",
            "gray" => "169;169;169",
            _ => "255;255;255",
        };
        format!("\x1b[38;2;{}m{}\x1b[0m", rgb, self)
    }
}
|
||||
|
||||
/// Prints `$x` as an 8-digit uppercase hex word on its own line and evaluates to `Err(1)`.
#[macro_export]
macro_rules! todo_instr {
    ($x:expr) => {{
        let rendered = format!("{:08X}", $x);
        println!("{}", rendered);
        Err(1)
    }};
}
|
||||
|
||||
/// When `DEBUG` is enabled, prints the instruction name in blue, left-aligned to 8 columns,
/// followed by each listed identifier as a `name=value` field padded to 16 columns.
///
/// `DEBUG` and the `Colorize` trait must be in scope at the call site.
#[macro_export]
macro_rules! print_instr {
    ($name:expr $(, $arg:ident)* $(,)?) => {
        if *DEBUG {
            let label = format!("{:<8}", $name);
            print!("{}", label.color("blue"));
            $(
                let field = format!("{}={:?}", stringify!($arg), $arg);
                print!(" {:<16}", field);
            )*
        }
    };
}
|
||||
31
extra/remu/src/lib.rs
Normal file
31
extra/remu/src/lib.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use crate::work_group::WorkGroup;
|
||||
use std::os::raw::c_char;
|
||||
use std::slice;
|
||||
mod helpers;
|
||||
mod state;
|
||||
mod thread;
|
||||
mod work_group;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn run_asm(lib: *const c_char, lib_sz: u32, gx: u32, gy: u32, gz: u32, lx: u32, ly: u32, lz: u32, args_ptr: *const u64) -> i32 {
|
||||
if lib.is_null() || (lib_sz % 4) != 0 {
|
||||
panic!("Pointer is null or length is not properly aligned to 4 bytes");
|
||||
}
|
||||
let kernel = unsafe { slice::from_raw_parts(lib as *const u32, (lib_sz / 4) as usize).to_vec() };
|
||||
let dispatch_dim = match (gy != 1, gz != 1) {
|
||||
(true, true) => 3,
|
||||
(true, false) => 2,
|
||||
_ => 1,
|
||||
};
|
||||
for gx in 0..gx {
|
||||
for gy in 0..gy {
|
||||
for gz in 0..gz {
|
||||
let mut wg = WorkGroup::new(dispatch_dim, [gx, gy, gz], [lx, ly, lz], &kernel, args_ptr);
|
||||
if let Err(err) = wg.exec_waves() {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
256
extra/remu/src/state.rs
Normal file
256
extra/remu/src/state.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
/// 64-bit access to a register file made of 32-bit entries.
pub trait Register {
    /// Reads entries `idx` (low half) and `idx + 1` (high half) as one 64-bit value.
    fn read64(&self, idx: usize) -> u64;
    /// Stores the low half of `addr` in entry `idx` and the high half in entry `idx + 1`.
    fn write64(&mut self, idx: usize, addr: u64);
}
impl<T> Register for T
where
    T: Index<usize, Output = u32> + IndexMut<usize>,
{
    fn read64(&self, idx: usize) -> u64 {
        let low = u64::from(self[idx]);
        let high = u64::from(self[idx + 1]);
        (high << 32) | low
    }
    fn write64(&mut self, idx: usize, addr: u64) {
        // The narrowing cast keeps the low 32 bits; the shift brings the high half down.
        self[idx] = addr as u32;
        self[idx + 1] = (addr >> 32) as u32;
    }
}
|
||||
|
||||
/// Vector register file: 32 lanes with 256 32-bit registers each.
///
/// Indexing with `vgpr[i]` addresses register `i` of the lane selected by `default_lane`
/// and panics while `default_lane` is `None`.
#[derive(Clone)]
pub struct VGPR {
    values: [[u32; 256]; 32],
    pub default_lane: Option<usize>,
}
impl Index<usize> for VGPR {
    type Output = u32;
    fn index(&self, index: usize) -> &Self::Output {
        let lane = self.default_lane.unwrap();
        &self.values[lane][index]
    }
}
impl IndexMut<usize> for VGPR {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let lane = self.default_lane.unwrap();
        &mut self.values[lane][index]
    }
}
impl VGPR {
    /// Creates a zeroed register file with no lane selected.
    pub fn new() -> Self {
        Self { values: [[0; 256]; 32], default_lane: None }
    }
    /// Returns a copy of all registers of `lane`; panics when `lane` is out of range.
    pub fn get_lane(&self, lane: usize) -> [u32; 256] {
        self.values.get(lane).copied().unwrap()
    }
    /// Returns mutable access to all registers of `lane`; panics when `lane` is out of range.
    pub fn get_lane_mut(&mut self, lane: usize) -> &mut [u32; 256] {
        let registers = self.values.get_mut(lane);
        registers.unwrap()
    }
}
|
||||
|
||||
/// In-place replacement of one 16-bit half of a 32-bit value.
pub trait Value {
    /// Overwrites the upper 16 bits with `val`, keeping the lower 16 bits.
    fn mut_hi16(&mut self, val: u16);
    /// Overwrites the lower 16 bits with `val`, keeping the upper 16 bits.
    fn mut_lo16(&mut self, val: u16);
}
impl Value for u32 {
    fn mut_hi16(&mut self, val: u16) {
        let kept_low = *self & 0x0000_ffff;
        *self = (u32::from(val) << 16) | kept_low;
    }
    fn mut_lo16(&mut self, val: u16) {
        let kept_high = *self & 0xffff_0000;
        *self = kept_high | u32::from(val);
    }
}
|
||||
|
||||
/// A per-wave bit mask with one bit per lane, plus a staging area for per-lane writes.
///
/// Writes made through `set_lane` are collected in `mutations` and only become visible in
/// `value` once `apply_muts` is called.
#[derive(Debug, Clone, Copy)]
pub struct WaveValue {
    pub value: u32,
    pub warp_size: usize,
    pub default_lane: Option<usize>,
    pub mutations: Option<[bool; 32]>,
}
impl WaveValue {
    /// Creates a mask holding `value` for a wave of `warp_size` lanes, with no lane selected
    /// and no pending mutations.
    pub fn new(value: u32, warp_size: usize) -> Self {
        Self { value, warp_size, default_lane: None, mutations: None }
    }
    /// Returns the committed bit of the lane selected by `default_lane`.
    /// Panics while `default_lane` is `None`.
    pub fn read(&self) -> bool {
        let shifted = self.value >> self.default_lane.unwrap();
        shifted & 1 == 1
    }
    /// Stages `value` for the selected lane, creating an all-`false` staging area on first use.
    /// Panics while `default_lane` is `None`.
    pub fn set_lane(&mut self, value: bool) {
        let staged = self.mutations.get_or_insert([false; 32]);
        staged[self.default_lane.unwrap()] = value;
    }
    /// Rebuilds `value` from the staged bits of the first `warp_size` lanes.
    /// Panics when nothing has been staged and `warp_size` is non-zero.
    pub fn apply_muts(&mut self) {
        self.value = 0;
        // `Option<[bool; 32]>` is `Copy`, so this snapshot is independent of `self`.
        let staged = self.mutations;
        for lane in 0..self.warp_size {
            if staged.unwrap()[lane] {
                self.value |= 1 << lane;
            }
        }
    }
}
|
||||
|
||||
/// Growable, byte-addressed little-endian memory backed by a `Vec<u8>`.
#[derive(Clone, Debug, Default)]
pub struct VecDataStore {
    pub data: Vec<u8>,
}

impl VecDataStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Writes `val` as four little-endian bytes starting at `addr`, growing the store when needed.
    pub fn write(&mut self, addr: usize, val: u32) {
        // Growth deliberately over-allocates (current length + addr + 5). The zero-filled slack
        // is readable afterwards, so the policy is kept exactly as it was.
        if addr + 4 >= self.data.len() {
            self.data.resize(self.data.len() + addr + 5, 0);
        }
        self.data[addr..addr + 4].copy_from_slice(&val.to_le_bytes());
    }
    /// Writes `val` as two consecutive 32-bit words, low half first.
    pub fn write64(&mut self, addr: usize, val: u64) {
        self.write(addr, val as u32);
        self.write(addr + 4, (val >> 32) as u32);
    }
    /// Reads four little-endian bytes starting at `addr`.
    ///
    /// # Panics
    /// Panics when `addr..addr + 4` lies outside the store.
    pub fn read(&self, addr: usize) -> u32 {
        let mut bytes = [0u8; 4];
        bytes.copy_from_slice(&self.data[addr..addr + 4]);
        u32::from_le_bytes(bytes)
    }
    /// Reads two consecutive 32-bit words starting at `addr` as one 64-bit value, low half first.
    ///
    /// Takes `&self`: reading never mutates the store, so no exclusive borrow is required.
    pub fn read64(&self, addr: usize) -> u64 {
        let low = u64::from(self.read(addr));
        let high = u64::from(self.read(addr + 4));
        (high << 32) | low
    }
}
|
||||
|
||||
#[cfg(test)]
mod test_state {
    use super::*;

    /// Builds a `WaveValue` with `lane` already selected.
    fn wave_at(value: u32, warp_size: usize, lane: usize) -> WaveValue {
        let mut wave = WaveValue::new(value, warp_size);
        wave.default_lane = Some(lane);
        wave
    }

    /// Checks that a cleared lane 0 reads `false`, then `true` once a staged write is applied.
    fn assert_lane_zero_can_be_set(warp_size: usize) {
        let mut val = wave_at(0, warp_size, 0);
        assert!(!val.read());
        assert_eq!(val.value, 0);
        val.set_lane(true);
        val.apply_muts();
        assert!(val.read());
        assert_eq!(val.value, 1);
    }

    /// Builds a register file with lane 0 selected and register 0 holding `initial`.
    fn vgpr_with(initial: u32) -> VGPR {
        let mut vgpr = VGPR::new();
        vgpr.default_lane = Some(0);
        vgpr[0] = initial;
        vgpr
    }

    #[test]
    fn test_wave_value() {
        let mut val = wave_at(0b11000000000000011111111111101110, 32, 0);
        assert!(!val.read());
        val.default_lane = Some(31);
        assert!(val.read());
    }

    #[test]
    fn test_wave_value_small() {
        assert_lane_zero_can_be_set(1);
    }

    #[test]
    fn test_wave_value_small_alt() {
        assert_lane_zero_can_be_set(2);
    }

    #[test]
    fn test_wave_value_exec() {
        let full = WaveValue::new(u32::MAX, 32);
        assert_eq!(full.value, u32::MAX);
        let warp_size = 3;
        let partial = WaveValue::new((1 << warp_size) - 1, warp_size);
        assert_eq!(partial.value, 7)
    }

    #[test]
    fn test_wave_value_toggle_one() {
        let mut val = wave_at(0b11, 2, 0);
        // lane 0
        val.set_lane(false);
        // lane 1
        val.default_lane = Some(1);
        val.set_lane(true);
        val.apply_muts();
        assert_eq!(val.value, 2);
    }

    #[test]
    fn test_wave_value_mutate_small() {
        assert_lane_zero_can_be_set(2);
    }

    #[test]
    fn test_wave_value_mutations() {
        let mut val = wave_at(0b10001, 32, 0);
        val.set_lane(false);
        assert!(val.mutations.unwrap().iter().all(|x| !x));
        val.default_lane = Some(1);
        val.set_lane(true);
        // Staged writes must not leak into the committed value before apply_muts.
        assert_eq!(val.value, 0b10001);
        let mut expected = [false; 32];
        expected[1] = true;
        assert_eq!(val.mutations, Some(expected));

        val.apply_muts();
        assert_eq!(val.value, 0b10);
    }

    #[test]
    fn test_write16() {
        let mut vgpr = vgpr_with(0b11100000000000001111111111111111);
        vgpr[0].mut_lo16(0b1011101111111110);
        assert_eq!(vgpr[0], 0b11100000000000001011101111111110);
    }

    #[test]
    fn test_write16hi() {
        let mut vgpr = vgpr_with(0b11100000000000001111111111111111);
        vgpr[0].mut_hi16(0b1011101111111110);
        assert_eq!(vgpr[0], 0b10111011111111101111111111111111);
    }

    #[test]
    fn test_vgpr() {
        let mut vgpr = vgpr_with(42);
        vgpr.default_lane = Some(10);
        vgpr[0] = 10;
        assert_eq!(vgpr.get_lane(0)[0], 42);
        assert_eq!(vgpr.get_lane(10)[0], 10);
    }
}
|
||||
3716
extra/remu/src/thread.rs
Normal file
3716
extra/remu/src/thread.rs
Normal file
File diff suppressed because it is too large
Load Diff
268
extra/remu/src/work_group.rs
Normal file
268
extra/remu/src/work_group.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
use crate::helpers::{Colorize, DEBUG};
|
||||
use crate::state::{Register, VecDataStore, WaveValue, VGPR};
|
||||
use crate::thread::{Thread, END_PRG};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub struct WorkGroup<'a> {
|
||||
dispatch_dim: u32,
|
||||
id: [u32; 3],
|
||||
lds: VecDataStore,
|
||||
kernel: &'a Vec<u32>,
|
||||
kernel_args: *const u64,
|
||||
launch_bounds: [u32; 3],
|
||||
wave_state: HashMap<usize, WaveState>,
|
||||
}
|
||||
|
||||
struct WaveState {
|
||||
scalar_reg: Vec<u32>,
|
||||
scc: u32,
|
||||
vec_reg: VGPR,
|
||||
vcc: WaveValue,
|
||||
exec: WaveValue,
|
||||
pc: usize,
|
||||
sds: HashMap<usize, VecDataStore>,
|
||||
}
|
||||
|
||||
/// Instruction words the executor skips instead of interpreting.
// NOTE(review): presumably encodings of wait/barrier style instructions - confirm against the ISA docs.
const SYNCS: [u32; 5] = [0xBF89FC07, 0xBFBD0000, 0xBC7C0000, 0xBF890007, 0xBFB60003];
/// Pairs of adjacent instruction words that mark a work-group barrier.
const BARRIERS: [[u32; 2]; 5] = [
    [SYNCS[0], SYNCS[0]],
    [SYNCS[0], SYNCS[1]],
    [SYNCS[0], SYNCS[2]],
    [SYNCS[3], SYNCS[1]],
    [SYNCS[1], SYNCS[0]],
];
|
||||
impl<'a> WorkGroup<'a> {
|
||||
pub fn new(dispatch_dim: u32, id: [u32; 3], launch_bounds: [u32; 3], kernel: &'a Vec<u32>, kernel_args: *const u64) -> Self {
|
||||
return Self {
|
||||
dispatch_dim,
|
||||
id,
|
||||
kernel,
|
||||
launch_bounds,
|
||||
kernel_args,
|
||||
lds: VecDataStore::new(),
|
||||
wave_state: HashMap::new(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn exec_waves(&mut self) -> Result<(), i32> {
|
||||
let mut blocks = vec![];
|
||||
for z in 0..self.launch_bounds[2] {
|
||||
for y in 0..self.launch_bounds[1] {
|
||||
for x in 0..self.launch_bounds[0] {
|
||||
blocks.push([x, y, z])
|
||||
}
|
||||
}
|
||||
}
|
||||
let waves = blocks.chunks(32).map(|w| w.to_vec()).collect::<Vec<_>>();
|
||||
|
||||
let mut sync = false;
|
||||
for (i, x) in self.kernel.iter().enumerate() {
|
||||
if i != 0 && BARRIERS.contains(&[*x, self.kernel[i - 1]]) {
|
||||
sync = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for _ in 0..=(sync as usize) {
|
||||
for w in waves.iter().enumerate() {
|
||||
self.exec_wave(w)?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn exec_wave(&mut self, (wave_id, threads): (usize, &Vec<[u32; 3]>)) -> Result<(), i32> {
|
||||
let wave_state = self.wave_state.get(&wave_id);
|
||||
let mut sds = match wave_state {
|
||||
Some(val) => val.sds.clone(),
|
||||
None => {
|
||||
let mut sds = HashMap::new();
|
||||
for i in 0..=31 {
|
||||
sds.insert(i, VecDataStore::new());
|
||||
}
|
||||
sds
|
||||
}
|
||||
};
|
||||
let (mut scalar_reg, mut scc, mut pc) = match wave_state {
|
||||
Some(val) => (val.scalar_reg.to_vec(), val.scc, val.pc),
|
||||
None => {
|
||||
let mut scalar_reg = vec![0; 256];
|
||||
scalar_reg.write64(0, self.kernel_args as u64);
|
||||
let [gx, gy, gz] = self.id;
|
||||
match self.dispatch_dim {
|
||||
3 => (scalar_reg[13], scalar_reg[14], scalar_reg[15]) = (gx, gy, gz),
|
||||
2 => (scalar_reg[14], scalar_reg[15]) = (gx, gy),
|
||||
_ => scalar_reg[15] = gx,
|
||||
}
|
||||
(scalar_reg, 0, 0)
|
||||
}
|
||||
};
|
||||
let (mut vec_reg, mut vcc) = match wave_state {
|
||||
Some(val) => (val.vec_reg.clone(), val.vcc.clone()),
|
||||
None => (VGPR::new(), WaveValue::new(0, threads.len())),
|
||||
};
|
||||
let mut exec = match wave_state {
|
||||
Some(val) => val.exec.clone(),
|
||||
None => {
|
||||
let active = match threads.len() == 32 {
|
||||
true => u32::MAX,
|
||||
false => (1 << threads.len()) - 1,
|
||||
};
|
||||
WaveValue::new(active, threads.len())
|
||||
}
|
||||
};
|
||||
|
||||
let mut seeded_lanes = vec![];
|
||||
loop {
|
||||
if self.kernel[pc] == END_PRG {
|
||||
break Ok(());
|
||||
}
|
||||
if BARRIERS.contains(&[self.kernel[pc], self.kernel[pc + 1]]) && wave_state.is_none() {
|
||||
self.wave_state.insert(
|
||||
wave_id,
|
||||
WaveState {
|
||||
scalar_reg,
|
||||
scc,
|
||||
vec_reg,
|
||||
vcc,
|
||||
exec,
|
||||
pc,
|
||||
sds,
|
||||
},
|
||||
);
|
||||
break Ok(());
|
||||
}
|
||||
if SYNCS.contains(&self.kernel[pc]) || self.kernel[pc] >> 20 == 0xbf8 || self.kernel[pc] == 0x7E000000 {
|
||||
pc += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut sgpr_co = None;
|
||||
for (lane_id, [x, y, z]) in threads.iter().enumerate() {
|
||||
vec_reg.default_lane = Some(lane_id);
|
||||
vcc.default_lane = Some(lane_id);
|
||||
exec.default_lane = Some(lane_id);
|
||||
if *DEBUG {
|
||||
let lane = format!("{:<2} {:08X} ", lane_id, self.kernel[pc]);
|
||||
let state = match exec.read() {
|
||||
true => "green",
|
||||
false => "gray",
|
||||
};
|
||||
let [id0, id1, id2] = self.id;
|
||||
print!("[{id0:<3} {id1:<3} {id2:<3}] [{x:<3} {y:<3} {z:<3}] {}", lane.color(state));
|
||||
}
|
||||
if !seeded_lanes.contains(&lane_id) && self.wave_state.get(&wave_id).is_none() {
|
||||
match (self.launch_bounds[1] != 1, self.launch_bounds[2] != 1) {
|
||||
(false, false) => vec_reg[0] = *x,
|
||||
_ => vec_reg[0] = (z << 20) | (y << 10) | x,
|
||||
}
|
||||
seeded_lanes.push(lane_id);
|
||||
}
|
||||
let mut thread = Thread {
|
||||
scalar_reg: &mut scalar_reg,
|
||||
scc: &mut scc,
|
||||
vec_reg: &mut vec_reg,
|
||||
vcc: &mut vcc,
|
||||
exec: &mut exec,
|
||||
lds: &mut self.lds,
|
||||
sds: &mut sds.get_mut(&lane_id).unwrap(),
|
||||
pc_offset: 0,
|
||||
stream: self.kernel[pc..self.kernel.len()].to_vec(),
|
||||
scalar: false,
|
||||
simm: None,
|
||||
warp_size: threads.len(),
|
||||
sgpr_co: &mut sgpr_co,
|
||||
};
|
||||
thread.interpret()?;
|
||||
if *DEBUG {
|
||||
println!();
|
||||
}
|
||||
if thread.scalar {
|
||||
pc = ((pc as isize) + 1 + (thread.pc_offset as isize)) as usize;
|
||||
break;
|
||||
}
|
||||
if lane_id == threads.len() - 1 {
|
||||
pc = ((pc as isize) + 1 + (thread.pc_offset as isize)) as usize;
|
||||
}
|
||||
}
|
||||
|
||||
if vcc.mutations.is_some() {
|
||||
vcc.apply_muts();
|
||||
vcc.mutations = None;
|
||||
}
|
||||
if exec.mutations.is_some() {
|
||||
exec.apply_muts();
|
||||
exec.mutations = None;
|
||||
}
|
||||
if let Some((idx, mut wv)) = sgpr_co.take() {
|
||||
wv.apply_muts();
|
||||
scalar_reg[idx] = wv.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test_workgroup {
    use super::*;

    // TODO: make this generic by adding the assembler
    /// Appends to `instructions` a fixed sequence that stores register `src` to the 64-bit
    /// address `addr`, followed by `END_PRG`.
    fn global_store_sgpr(addr: u64, instructions: Vec<u32>, src: u32) -> Vec<u32> {
        [
            instructions,
            vec![
                0x7E020200 + src,
                0x7E0402FF,
                addr as u32,
                0x7E0602FF,
                (addr >> 32) as u32,
                0xDC6A0000,
                0x007C0102,
            ],
            vec![END_PRG],
        ]
        .concat()
    }

    #[test]
    fn test_wave_value_state_vcc() {
        let mut ret: u32 = 0;
        let kernel = vec![
            0xBEEA00FF,
            0b11111111111111111111111111111111, // initial vcc state
            0x7E140282,
            0x7C94010A, // cmp blockDim.x == 2
        ];
        let addr = (&mut ret as *mut u32) as u64;
        let kernel = global_store_sgpr(addr, kernel, 106);
        // Keep the argument buffer in a named binding: a pointer taken from a temporary array
        // would dangle as soon as the statement creating it ends.
        let args = [addr];
        let mut wg = WorkGroup::new(1, [0, 0, 0], [3, 1, 1], &kernel, args.as_ptr());
        wg.exec_waves().unwrap();
        assert_eq!(ret, 0b100);
    }

    #[test]
    fn test_wave_value_state_exec() {
        let mut ret: u32 = 0;
        let kernel = vec![
            0xBEFE00FF,
            0b11111111111111111111111111111111,
            0x7E140282,
            0x7D9C010A, // cmpx blockDim.x <= 2
        ];
        let addr = (&mut ret as *mut u32) as u64;
        let kernel = global_store_sgpr(addr, kernel, 126);
        // Named binding so the pointer stays valid for the lifetime of the work-group.
        let args = [addr];
        let mut wg = WorkGroup::new(1, [0, 0, 0], [4, 1, 1], &kernel, args.as_ptr());
        wg.exec_waves().unwrap();
        assert_eq!(ret, 0b0111);
    }

    #[test]
    fn test_wave_value_sgpr_co() {
        let mut ret: u32 = 0;
        let kernel = vec![0xBE8D00FF, 0x7FFFFFFF, 0x7E1402FF, u32::MAX, 0xD700000A, 0x0002010A];
        let addr = (&mut ret as *mut u32) as u64;
        let kernel = global_store_sgpr(addr, kernel, 0);
        // Named binding so the pointer stays valid for the lifetime of the work-group.
        let args = [addr];
        let mut wg = WorkGroup::new(1, [0, 0, 0], [5, 1, 1], &kernel, args.as_ptr());
        wg.exec_waves().unwrap();
        assert_eq!(ret, 0b11110);
    }
}
|
||||
@@ -18,7 +18,8 @@ WAIT_REG_MEM_FUNCTION_ALWAYS = 0
|
||||
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
|
||||
WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=
|
||||
|
||||
REMU_PATHS = ["libremu.so", "/usr/local/lib/libremu.so", "libremu.dylib", "/usr/local/lib/libremu.dylib", "/opt/homebrew/lib/libremu.dylib"]
|
||||
REMU_PATHS = ["extra/remu/target/release/libremu.so", "libremu.so", "/usr/local/lib/libremu.so",
|
||||
"extra/remu/target/release/libremu.dylib", "libremu.dylib", "/usr/local/lib/libremu.dylib", "/opt/homebrew/lib/libremu.dylib"]
|
||||
def _try_dlopen_remu():
|
||||
for path in REMU_PATHS:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user