mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remu refactors (#10028)
* remu refactors * scc is sgpr 253 * remove that * rename to vcc_lo * run cargo test in CI * llvm-mc * meh * work * work_group work 1 * seeded_lanes is dumb * better than seeded_lanes * does not need to be address * 128 sgpr per wave * scc is sgpr, we don't know which one * null_src once more * derive clone, wave init is cleaner * init comes first
This commit is contained in:
@@ -36,7 +36,7 @@ The DEBUG output has 3 sections:
|
||||
|
||||
```
|
||||
<------------ 1 ----------> <--- 2 ---> <--------------------------------------- 3 ------------------------------------------>
|
||||
[0 0 0 ] [0 0 0 ] 0 F4080100 SMEM sbase=0 sdata=4 op=2 offset=0 soffset=0
|
||||
[0 0 0 ] [0 0 0 ] 0 F4080100 SMEM { op: 2, sdata: 4, sbase: 0, offset: 0, soffset: 124, glc: false, dlc: false }
|
||||
```
|
||||
|
||||
#### Section 1: Grid info
|
||||
@@ -61,11 +61,11 @@ Gray = The thread has been "turned off" by the EXEC mask, it skips execution of
|
||||
|
||||
To see the colors in action, try running `DEBUG=6 PYTHONPATH="." MOCKGPU=1 AMD=1 python test/test_ops.py TestOps.test_arange_big`. See how only lane 0 writes to global memory:
|
||||
```
|
||||
[255 0 0 ] [0 0 0 ] 0 DC6A0000 GLOBAL offset=0 op=26 addr=8 data=0 saddr=0 vdst=0
|
||||
[255 0 0 ] [0 0 0 ] 0 DC6A0000 FLAT { op: 26, offset: 0, dlc: false, glc: false, slc: false, seg: 2, addr: 8, data: 0, saddr: 0, sve: false, vdst: 0 }
|
||||
[255 0 0 ] [1 0 0 ] 1 DC6A0000
|
||||
[255 0 0 ] [2 0 0 ] 2 DC6A0000
|
||||
[255 0 0 ] [3 0 0 ] 3 DC6A0000
|
||||
[255 0 0 ] [4 0 0 ] 4 DC6A0000
|
||||
[255 0 0 ] [3 0 0 ] 4 DC6A0000
|
||||
```
|
||||
|
||||
#### Section 3: Decoded Instruction
|
||||
@@ -76,5 +76,5 @@ Remu output vs llvm-objdump:
|
||||
|
||||
```
|
||||
s_load_b64 s[0:1], s[0:1], 0x10 // 00000000160C: F4040000 F8000010
|
||||
SMEM sbase=0 sdata=0 op=1 offset=16 soffset=0
|
||||
SMEM { op: 1, sdata: 0, sbase: 0, offset: 16, soffset: 124, glc: false, dlc: false }
|
||||
```
|
||||
|
||||
@@ -140,19 +140,13 @@ mod tests {
|
||||
use std::sync::LazyLock;
|
||||
pub static DEBUG: LazyLock<bool> = LazyLock::new(|| std::env::var("DEBUG").map(|v| v.parse::<usize>().unwrap_or(0) >= 6).unwrap_or(false));
|
||||
|
||||
pub trait Colorize {
|
||||
fn color(self, color: &str) -> String;
|
||||
}
|
||||
impl<'a> Colorize for &'a str {
|
||||
fn color(self, color: &str) -> String {
|
||||
let ansi_code = match color {
|
||||
"blue" => format!("\x1b[{};2;112;184;255m", 38),
|
||||
"green" => format!("\x1b[{};2;39;176;139m", 38),
|
||||
"gray" => format!("\x1b[{};2;169;169;169m", 38),
|
||||
_ => format!("\x1b[{};2;255;255;255m", 38),
|
||||
};
|
||||
format!("{}{}{}", ansi_code, self, "\x1b[0m")
|
||||
}
|
||||
pub fn colored(st:&str, color:&str) -> String {
|
||||
let ansi_code = match color {
|
||||
"green" => format!("\x1b[{};2;39;176;139m", 38),
|
||||
"gray" => format!("\x1b[{};2;169;169;169m", 38),
|
||||
_ => format!("\x1b[{};2;255;255;255m", 38),
|
||||
};
|
||||
format!("{}{}{}", ansi_code, st, "\x1b[0m")
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
@@ -162,15 +156,3 @@ macro_rules! todo_instr {
|
||||
Err(1)
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! print_instr {
|
||||
($name:expr $(, $arg:ident)* $(,)?) => {
|
||||
if *DEBUG {
|
||||
print!("{}", format!("{:<8}", $name).color("blue"));
|
||||
$(
|
||||
print!(" {:<16}", format!("{}={:?}", stringify!($arg), $arg));
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,22 +4,20 @@ pub trait Register {
|
||||
fn read64(&self, idx: usize) -> u64;
|
||||
fn write64(&mut self, idx: usize, addr: u64);
|
||||
}
|
||||
impl<T> Register for T
|
||||
where
|
||||
T: Index<usize, Output = u32> + IndexMut<usize>,
|
||||
{
|
||||
impl<T> Register for T where T: Index<usize, Output = u32> + IndexMut<usize> {
|
||||
fn read64(&self, idx: usize) -> u64 {
|
||||
let addr_lsb = self[idx];
|
||||
let addr_msb = self[idx + 1];
|
||||
((addr_msb as u64) << 32) | addr_lsb as u64
|
||||
let lsb = self[idx] as u64;
|
||||
let msb = self[idx + 1] as u64;
|
||||
(msb << 32) | lsb
|
||||
}
|
||||
fn write64(&mut self, idx: usize, addr: u64) {
|
||||
self[idx] = (addr & 0xffffffff) as u32;
|
||||
self[idx + 1] = ((addr & (0xffffffff << 32)) >> 32) as u32;
|
||||
|
||||
fn write64(&mut self, idx: usize, value: u64) {
|
||||
self[idx] = (value & 0xffffffff) as u32;
|
||||
self[idx + 1] = ((value & (0xffffffff << 32)) >> 32) as u32;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VGPR {
|
||||
values: [[u32; 256]; 32],
|
||||
pub default_lane: Option<usize>,
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
use crate::helpers::{extract_mantissa, f16_hi, f16_lo, ldexp, nth, sign_ext, IEEEClass, VOPModifier};
|
||||
use crate::helpers::{Colorize, DEBUG};
|
||||
use crate::helpers::DEBUG;
|
||||
use crate::state::{Register, Value, VecDataStore, WaveValue, VGPR};
|
||||
use crate::{print_instr, todo_instr};
|
||||
use crate::todo_instr;
|
||||
use half::{bf16, f16};
|
||||
use crate::rdna3::{Instruction, decode};
|
||||
use num_traits::Float;
|
||||
|
||||
const SGPR_COUNT: usize = 105;
|
||||
pub const SGPR_COUNT: usize = 128;
|
||||
pub const VCC: usize = 106;
|
||||
pub const EXEC: usize = 126;
|
||||
pub const NULL_SRC: usize = 124;
|
||||
pub const SGPR_SRC: usize = 105;
|
||||
|
||||
const VGPR_COUNT: usize = 256;
|
||||
const NULL_SRC: u32 = 124;
|
||||
const SIMM_SRC: usize = 255;
|
||||
|
||||
pub const END_PRG: u32 = 0xbfb00000;
|
||||
|
||||
pub struct Thread<'a> {
|
||||
pub scalar_reg: &'a mut Vec<u32>,
|
||||
pub scc: &'a mut u32,
|
||||
pub scalar_reg: &'a mut [u32; SGPR_COUNT],
|
||||
pub scc: &'a mut u32, // SCC is physically an sgpr, unclear which one
|
||||
|
||||
pub vec_reg: &'a mut VGPR,
|
||||
pub vcc: &'a mut WaveValue,
|
||||
@@ -36,19 +40,17 @@ impl<'a> Thread<'a> {
|
||||
pub fn interpret(&mut self) -> Result<(), i32> {
|
||||
let instruction = self.stream[self.pc_offset];
|
||||
let decoded = decode(self.stream[self.pc_offset], self.stream.get(self.pc_offset+1));
|
||||
if *DEBUG {
|
||||
print!("{:?}", decoded);
|
||||
}
|
||||
if let Instruction::SMEM { sbase, sdata, op, offset, soffset, .. } = decoded {
|
||||
let _ = self.u64_instr();
|
||||
let soffset = match self.val(soffset as usize) {
|
||||
NULL_SRC => 0,
|
||||
val => val,
|
||||
};
|
||||
|
||||
print_instr!("SMEM", sbase, sdata, op, offset, soffset);
|
||||
let soffset: u32 = self.val(soffset as usize);
|
||||
|
||||
// TODO: refactor vcc_lo to store in scalar register 106
|
||||
let base_addr = match sbase {
|
||||
106 => ((self.scalar_reg[107] as u64) << 32) | self.vcc.value as u64,
|
||||
_ => self.scalar_reg.read64(sbase as usize),
|
||||
let base_addr = match sbase as usize {
|
||||
VCC => ((self.scalar_reg[107] as u64) << 32) | self.vcc.value as u64,
|
||||
s => self.scalar_reg.read64(s),
|
||||
};
|
||||
let addr = (base_addr as i64 + offset as i64 + soffset as i64) as u64;
|
||||
|
||||
@@ -65,8 +67,6 @@ impl<'a> Thread<'a> {
|
||||
let src = ssrc0 as usize;
|
||||
let sdst = sdst as usize;
|
||||
|
||||
print_instr!("SOP1", src, sdst, op);
|
||||
|
||||
match op {
|
||||
1 => {
|
||||
let s0 = self.val(src);
|
||||
@@ -121,8 +121,6 @@ impl<'a> Thread<'a> {
|
||||
let s0 = ssrc0 as usize;
|
||||
let s1 = ssrc1 as usize;
|
||||
|
||||
print_instr!("SOPC", s0, s1, op);
|
||||
|
||||
fn scmp<T>(s0: T, s1: T, offset: u8, op: u8) -> bool
|
||||
where
|
||||
T: PartialOrd + PartialEq,
|
||||
@@ -162,7 +160,6 @@ impl<'a> Thread<'a> {
|
||||
self.scalar = true;
|
||||
}
|
||||
else if let Instruction::SOPP { simm16, op } = decoded {
|
||||
print_instr!("SOPP", simm16, op);
|
||||
|
||||
match op {
|
||||
32..=42 => {
|
||||
@@ -189,8 +186,6 @@ impl<'a> Thread<'a> {
|
||||
let sdst = sdst as usize;
|
||||
let s0: u32 = self.val(sdst);
|
||||
|
||||
print_instr!("SOPK", simm, sdst, s0, op);
|
||||
|
||||
match op {
|
||||
0 => self.write_to_sdst(sdst, simm as i16 as i32 as u32),
|
||||
3..=8 => {
|
||||
@@ -240,8 +235,6 @@ impl<'a> Thread<'a> {
|
||||
let s1 = ssrc1 as usize;
|
||||
let sdst = sdst as usize;
|
||||
|
||||
print_instr!("SOP2", s0, s1, sdst, op);
|
||||
|
||||
match op {
|
||||
23 | 25 | 27 => {
|
||||
let (s0, s1): (u64, u64) = (self.val(s0), self.val(s1));
|
||||
@@ -418,8 +411,6 @@ impl<'a> Thread<'a> {
|
||||
let opsel = [b(11), b(12), b(13)];
|
||||
let opsel_hi = [b(59), b(60), b(14)];
|
||||
|
||||
print_instr!("VOPP", op, vdst, src_parts, opsel, opsel_hi, neg, neg_hi);
|
||||
|
||||
match op {
|
||||
0..=18 => {
|
||||
let fxn = |x, y, z| -> u16 {
|
||||
@@ -554,8 +545,6 @@ impl<'a> Thread<'a> {
|
||||
let s0 = src as usize;
|
||||
let vdst = vdst as usize;
|
||||
|
||||
print_instr!("VOP1", s0, op, vdst);
|
||||
|
||||
match op {
|
||||
3 | 15 | 21 | 23 | 25 | 26 | 60 | 61 | 47 | 49 => {
|
||||
let s0: u64 = self.val(s0);
|
||||
@@ -723,8 +712,6 @@ impl<'a> Thread<'a> {
|
||||
// LSB is the opposite of VDSTX[0]
|
||||
let vdsty = (((instr >> 49) & 0x7f) << 1 | ((vdstx as u64 & 1) ^ 1)) as usize;
|
||||
|
||||
print_instr!("VOPD", opx, vdstx, sx, srcx0, vx, vsrcx1, opy, vdsty, sy, srcy0, vy, vsrcy1);
|
||||
|
||||
for (op, s0, s1, dst) in ([(opx, srcx0, vsrcx1, vdstx), (opy, srcy0, vsrcy1, vdsty)]).iter() {
|
||||
let ret = match *op {
|
||||
0 | 1 | 2 | 3 | 4 | 5 | 6 | 10 | 11 => {
|
||||
@@ -765,8 +752,6 @@ impl<'a> Thread<'a> {
|
||||
let s1 = vsrc as usize;
|
||||
let op = op as u32;
|
||||
|
||||
print_instr!("VOPC", s0, s1, op);
|
||||
|
||||
let dest_offset = if op >= 128 { 128 } else { 0 };
|
||||
let ret = match op {
|
||||
(0..=15) | 125 | (128..=143) => {
|
||||
@@ -835,8 +820,6 @@ impl<'a> Thread<'a> {
|
||||
let s1 = self.vec_reg[vsrc as usize];
|
||||
let vdst = vdst as usize;
|
||||
|
||||
print_instr!("VOP2", s0, s1, vdst, op);
|
||||
|
||||
match op {
|
||||
(50..=60) => {
|
||||
let (s0, s1) = (f16::from_bits(self.val(s0)), f16::from_bits(s1 as u16));
|
||||
@@ -951,8 +934,6 @@ impl<'a> Thread<'a> {
|
||||
assert_eq!(omod, 0);
|
||||
assert_eq!(clmp, 0);
|
||||
|
||||
print_instr!("VOPSD", vdst, sdst, op, s0, s1, s2);
|
||||
|
||||
let vcc = match op {
|
||||
766 => {
|
||||
let (s0, s1, s2): (u32, u32, u64) = (self.val(s0), self.val(s1), self.val(s2));
|
||||
@@ -1006,8 +987,8 @@ impl<'a> Thread<'a> {
|
||||
};
|
||||
|
||||
match sdst {
|
||||
106 => self.vcc.set_lane(vcc),
|
||||
124 => {}
|
||||
VCC => self.vcc.set_lane(vcc),
|
||||
NULL_SRC => {}
|
||||
_ => self.set_sgpr_co(sdst, vcc),
|
||||
}
|
||||
}
|
||||
@@ -1026,8 +1007,6 @@ impl<'a> Thread<'a> {
|
||||
assert_eq!(cm, 0);
|
||||
assert_eq!(opsel, 0);
|
||||
|
||||
print_instr!("VOP3", vdst, abs, opsel, op, src, neg);
|
||||
|
||||
match op {
|
||||
// VOPC using VOP3 encoding
|
||||
0..=255 => {
|
||||
@@ -1094,9 +1073,9 @@ impl<'a> Thread<'a> {
|
||||
};
|
||||
|
||||
match vdst {
|
||||
0..=SGPR_COUNT | 107 => self.set_sgpr_co(vdst, ret),
|
||||
106 => self.vcc.set_lane(ret),
|
||||
126 => self.exec.set_lane(ret),
|
||||
0..=SGPR_SRC | 107 => self.set_sgpr_co(vdst, ret),
|
||||
VCC => self.vcc.set_lane(ret),
|
||||
EXEC => self.exec.set_lane(ret),
|
||||
_ => todo_instr!(instruction)?,
|
||||
}
|
||||
}
|
||||
@@ -1403,8 +1382,6 @@ impl<'a> Thread<'a> {
|
||||
let data1 = ((instr >> 48) & 0xff) as usize;
|
||||
let vdst = ((instr >> 56) & 0xff) as usize;
|
||||
|
||||
print_instr!("LDS", op, addr, data0, data1, vdst);
|
||||
|
||||
let lds_base = self.vec_reg[addr];
|
||||
let single_addr = || (lds_base + (instr & 0xffff) as u32) as usize;
|
||||
let double_addr = |adj: u32| {
|
||||
@@ -1494,14 +1471,12 @@ impl<'a> Thread<'a> {
|
||||
let vdst = ((instr >> 56) & 0xff) as usize;
|
||||
|
||||
let saddr_val: u32 = self.val(saddr);
|
||||
let saddr_off = saddr_val == 0x7F || saddr as u32 == NULL_SRC;
|
||||
let saddr_off = saddr_val == 0x7F || saddr == NULL_SRC;
|
||||
|
||||
match seg {
|
||||
1 => {
|
||||
let sve = ((instr >> 50) & 0x1) != 0;
|
||||
|
||||
print_instr!("SCRATCH", offset, op, addr, data, saddr, vdst, sve);
|
||||
|
||||
let addr = match (sve, saddr_off) {
|
||||
(true, true) => offset as u64 as usize,
|
||||
(false, false) => saddr_val as usize,
|
||||
@@ -1520,8 +1495,6 @@ impl<'a> Thread<'a> {
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
print_instr!("GLOBAL", offset, op, addr, data, saddr, vdst);
|
||||
|
||||
let addr = match saddr_off {
|
||||
true => self.vec_reg.read64(addr) as i64 + (offset as i64),
|
||||
false => {
|
||||
@@ -1562,7 +1535,6 @@ impl<'a> Thread<'a> {
|
||||
else if instruction >> 26 == 0b111000 {
|
||||
let instr = self.u64_instr();
|
||||
let op = ((instr >> 18) & 0x7f) as usize;
|
||||
print_instr!("MUBUF", op);
|
||||
match op {
|
||||
43 => {} // NOTE: remu doesn't have an l0 cache, it just has the software managed lds
|
||||
_ => todo_instr!(instruction)?,
|
||||
@@ -1704,11 +1676,10 @@ impl<'a> Thread<'a> {
|
||||
/* ALU utils */
|
||||
fn _common_srcs(&mut self, code: usize) -> u32 {
|
||||
match code {
|
||||
106 => self.vcc.value,
|
||||
VCC => self.vcc.value,
|
||||
107 => self.scalar_reg[code as usize],
|
||||
126 => self.exec.value,
|
||||
128 => 0,
|
||||
124 => NULL_SRC,
|
||||
EXEC => self.exec.value,
|
||||
NULL_SRC | 128 => 0,
|
||||
255 => match self.simm {
|
||||
None => {
|
||||
let val = self.stream[self.pc_offset + 1];
|
||||
@@ -1724,8 +1695,8 @@ impl<'a> Thread<'a> {
|
||||
fn write_to_sdst(&mut self, sdst_bf: usize, val: u32) {
|
||||
match sdst_bf {
|
||||
// NOTE: remu is only wave32, vcc_hi is treated as a regular SGPR
|
||||
0..=SGPR_COUNT | 107 => self.scalar_reg[sdst_bf] = val,
|
||||
106 => self.vcc.value = val,
|
||||
0..=SGPR_SRC | 107 => self.scalar_reg[sdst_bf] = val,
|
||||
VCC => self.vcc.value = val,
|
||||
126 => self.exec.value = val,
|
||||
_ => todo!("write to sdst {}", sdst_bf),
|
||||
}
|
||||
@@ -1782,7 +1753,7 @@ pub trait ALUSrc<T> {
|
||||
impl ALUSrc<u16> for Thread<'_> {
|
||||
fn val(&mut self, code: usize) -> u16 {
|
||||
match code {
|
||||
0..=SGPR_COUNT => self.scalar_reg[code] as u16,
|
||||
0..=SGPR_SRC => self.scalar_reg[code] as u16,
|
||||
VGPR_COUNT..=511 => self.vec_reg[code - VGPR_COUNT] as u16,
|
||||
129..=192 => (code - 128) as u16,
|
||||
193..=208 => ((code - 192) as i16 * -1) as u16,
|
||||
@@ -1810,7 +1781,7 @@ impl ALUSrc<u16> for Thread<'_> {
|
||||
impl ALUSrc<u32> for Thread<'_> {
|
||||
fn val(&mut self, code: usize) -> u32 {
|
||||
match code {
|
||||
0..=SGPR_COUNT => self.scalar_reg[code],
|
||||
0..=SGPR_SRC => self.scalar_reg[code],
|
||||
VGPR_COUNT..=511 => self.vec_reg[code - VGPR_COUNT],
|
||||
129..=192 => (code - 128) as u32,
|
||||
193..=208 => ((code - 192) as i32 * -1) as u32,
|
||||
@@ -1836,7 +1807,7 @@ impl ALUSrc<u32> for Thread<'_> {
|
||||
impl ALUSrc<u64> for Thread<'_> {
|
||||
fn val(&mut self, code: usize) -> u64 {
|
||||
match code {
|
||||
0..=SGPR_COUNT => self.scalar_reg.read64(code),
|
||||
0..=SGPR_SRC => self.scalar_reg.read64(code),
|
||||
VGPR_COUNT..=511 => self.vec_reg.read64(code - VGPR_COUNT),
|
||||
129..=192 => (code - 128) as u64,
|
||||
193..=208 => ((code - 192) as i64 * -1) as u64,
|
||||
@@ -1884,7 +1855,7 @@ mod test_alu_utils {
|
||||
fn test_write_to_sdst_vcc_val() {
|
||||
let mut thread = _helper_test_thread();
|
||||
let val = 0b1011101011011011111011101111;
|
||||
thread.write_to_sdst(106, val);
|
||||
thread.write_to_sdst(VCC, val);
|
||||
assert_eq!(thread.vcc.value, 195935983);
|
||||
}
|
||||
|
||||
@@ -1968,7 +1939,7 @@ mod test_smem {
|
||||
}
|
||||
let addr = buf.as_ptr() as u64;
|
||||
// NOTE: vcc is an alias for s[106:107]
|
||||
thread.scalar_reg.write64(106, addr);
|
||||
thread.scalar_reg.write64(VCC, addr);
|
||||
// TODO: vcc_lo should just read from s106
|
||||
thread.vcc.value = (addr & 0xffffffff) as u32;
|
||||
r(&vec![0xF4000035, 0xF8000000, END_PRG], &mut thread);
|
||||
@@ -3666,7 +3637,7 @@ fn r(prg: &Vec<u32>, thread: &mut Thread) {
|
||||
}
|
||||
fn _helper_test_thread() -> Thread<'static> {
|
||||
let static_lds: &'static mut VecDataStore = Box::leak(Box::new(VecDataStore::new()));
|
||||
let static_sgpr: &'static mut Vec<u32> = Box::leak(Box::new(vec![0; 256]));
|
||||
let static_sgpr: &'static mut [u32; SGPR_COUNT] = Box::leak(Box::new([0; SGPR_COUNT]));
|
||||
let static_vgpr: &'static mut VGPR = Box::leak(Box::new(VGPR::new()));
|
||||
let static_scc: &'static mut u32 = Box::leak(Box::new(0));
|
||||
let static_exec: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(u32::MAX, 32)));
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use crate::helpers::{Colorize, DEBUG};
|
||||
use crate::helpers::{colored, DEBUG};
|
||||
use crate::state::{Register, VecDataStore, WaveValue, VGPR};
|
||||
use crate::thread::{Thread, END_PRG};
|
||||
use crate::thread::{Thread, END_PRG, SGPR_COUNT};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const WAVE_SIZE: usize = 32;
|
||||
|
||||
pub struct WorkGroup<'a> {
|
||||
dispatch_dim: u32,
|
||||
id: [u32; 3],
|
||||
@@ -13,12 +15,13 @@ pub struct WorkGroup<'a> {
|
||||
wave_state: HashMap<usize, WaveState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct WaveState {
|
||||
scalar_reg: Vec<u32>,
|
||||
scalar_reg: [u32; SGPR_COUNT],
|
||||
scc: u32,
|
||||
vec_reg: VGPR,
|
||||
vcc: WaveValue,
|
||||
exec: WaveValue,
|
||||
vec_reg: VGPR,
|
||||
pc: usize,
|
||||
sds: HashMap<usize, VecDataStore>,
|
||||
}
|
||||
@@ -33,27 +36,19 @@ const BARRIERS: [[u32; 2]; 5] = [
|
||||
];
|
||||
impl<'a> WorkGroup<'a> {
|
||||
pub fn new(dispatch_dim: u32, id: [u32; 3], launch_bounds: [u32; 3], kernel: &'a Vec<u32>, kernel_args: *const u64) -> Self {
|
||||
return Self {
|
||||
dispatch_dim,
|
||||
id,
|
||||
kernel,
|
||||
launch_bounds,
|
||||
kernel_args,
|
||||
lds: VecDataStore::new(),
|
||||
wave_state: HashMap::new(),
|
||||
};
|
||||
Self { dispatch_dim, id, kernel, launch_bounds, kernel_args, lds: VecDataStore::new(), wave_state: HashMap::new() }
|
||||
}
|
||||
|
||||
pub fn exec_waves(&mut self) -> Result<(), i32> {
|
||||
let mut blocks = vec![];
|
||||
let mut threads = vec![];
|
||||
for z in 0..self.launch_bounds[2] {
|
||||
for y in 0..self.launch_bounds[1] {
|
||||
for x in 0..self.launch_bounds[0] {
|
||||
blocks.push([x, y, z])
|
||||
threads.push([x, y, z])
|
||||
}
|
||||
}
|
||||
}
|
||||
let waves = blocks.chunks(32).map(|w| w.to_vec()).collect::<Vec<_>>();
|
||||
let waves = threads.chunks(WAVE_SIZE).collect::<Vec<_>>();
|
||||
|
||||
let mut sync = false;
|
||||
for (i, x) in self.kernel.iter().enumerate() {
|
||||
@@ -62,6 +57,7 @@ impl<'a> WorkGroup<'a> {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..=(sync as usize) {
|
||||
for w in waves.iter().enumerate() {
|
||||
self.exec_wave(w)?
|
||||
@@ -70,65 +66,47 @@ impl<'a> WorkGroup<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn exec_wave(&mut self, (wave_id, threads): (usize, &Vec<[u32; 3]>)) -> Result<(), i32> {
|
||||
let wave_state = self.wave_state.get(&wave_id);
|
||||
let mut sds = match wave_state {
|
||||
Some(val) => val.sds.clone(),
|
||||
None => {
|
||||
let mut sds = HashMap::new();
|
||||
for i in 0..=31 {
|
||||
sds.insert(i, VecDataStore::new());
|
||||
}
|
||||
sds
|
||||
fn exec_wave(&mut self, (wave_id, threads): (usize, &&[[u32; 3]])) -> Result<(), i32> {
|
||||
let (mut scalar_reg, mut scc, mut pc, mut vec_reg, mut vcc, mut exec, mut sds) = match self.wave_state.get(&wave_id) {
|
||||
None => {
|
||||
let mut scalar_reg = [0; SGPR_COUNT];
|
||||
scalar_reg.write64(0, self.kernel_args as u64);
|
||||
|
||||
let [gx, gy, gz] = self.id;
|
||||
match self.dispatch_dim {
|
||||
3 => (scalar_reg[13], scalar_reg[14], scalar_reg[15]) = (gx, gy, gz),
|
||||
2 => (scalar_reg[14], scalar_reg[15]) = (gx, gy),
|
||||
_ => scalar_reg[15] = gx,
|
||||
}
|
||||
};
|
||||
let (mut scalar_reg, mut scc, mut pc) = match wave_state {
|
||||
Some(val) => (val.scalar_reg.to_vec(), val.scc, val.pc),
|
||||
None => {
|
||||
let mut scalar_reg = vec![0; 256];
|
||||
scalar_reg.write64(0, self.kernel_args as u64);
|
||||
let [gx, gy, gz] = self.id;
|
||||
match self.dispatch_dim {
|
||||
3 => (scalar_reg[13], scalar_reg[14], scalar_reg[15]) = (gx, gy, gz),
|
||||
2 => (scalar_reg[14], scalar_reg[15]) = (gx, gy),
|
||||
_ => scalar_reg[15] = gx,
|
||||
}
|
||||
(scalar_reg, 0, 0)
|
||||
}
|
||||
};
|
||||
let (mut vec_reg, mut vcc) = match wave_state {
|
||||
Some(val) => (val.vec_reg.clone(), val.vcc.clone()),
|
||||
None => (VGPR::new(), WaveValue::new(0, threads.len())),
|
||||
};
|
||||
let mut exec = match wave_state {
|
||||
Some(val) => val.exec.clone(),
|
||||
None => {
|
||||
let active = match threads.len() == 32 {
|
||||
true => u32::MAX,
|
||||
false => (1 << threads.len()) - 1,
|
||||
};
|
||||
WaveValue::new(active, threads.len())
|
||||
|
||||
let mut vec_reg = VGPR::new();
|
||||
for (t, [x, y, z]) in threads.iter().enumerate() {
|
||||
vec_reg.get_lane_mut(t)[0] = match &self.launch_bounds {
|
||||
[_, 1, 1] => *x,
|
||||
_ => (z << 20) | (y << 10) | x,
|
||||
}
|
||||
}
|
||||
|
||||
let vcc = WaveValue::new(0, threads.len());
|
||||
let active = (!0u32).wrapping_shr(32 - (threads.len() as u32));
|
||||
let exec = WaveValue::new(active, threads.len());
|
||||
|
||||
let sds = (0..=31).map(|i| (i, VecDataStore::new())).collect();
|
||||
(scalar_reg, 0, 0, vec_reg, vcc, exec, sds)
|
||||
}
|
||||
|
||||
Some(val) => {
|
||||
let val = val.clone();
|
||||
(val.scalar_reg, val.scc, val.pc, val.vec_reg, val.vcc, val.exec, val.sds)
|
||||
}
|
||||
};
|
||||
|
||||
let mut seeded_lanes = vec![];
|
||||
loop {
|
||||
if self.kernel[pc] == END_PRG {
|
||||
break Ok(());
|
||||
}
|
||||
if BARRIERS.contains(&[self.kernel[pc], self.kernel[pc + 1]]) && wave_state.is_none() {
|
||||
self.wave_state.insert(
|
||||
wave_id,
|
||||
WaveState {
|
||||
scalar_reg,
|
||||
scc,
|
||||
vec_reg,
|
||||
vcc,
|
||||
exec,
|
||||
pc,
|
||||
sds,
|
||||
},
|
||||
);
|
||||
if BARRIERS.contains(&[self.kernel[pc], self.kernel[pc + 1]]) && self.wave_state.get(&wave_id).is_none() {
|
||||
self.wave_state.insert(wave_id, WaveState { scalar_reg, scc, vec_reg, vcc, exec, pc, sds });
|
||||
break Ok(());
|
||||
}
|
||||
if SYNCS.contains(&self.kernel[pc]) || self.kernel[pc] >> 20 == 0xbf8 || self.kernel[pc] == 0x7E000000 {
|
||||
@@ -148,14 +126,7 @@ impl<'a> WorkGroup<'a> {
|
||||
false => "gray",
|
||||
};
|
||||
let [id0, id1, id2] = self.id;
|
||||
print!("[{id0:<3} {id1:<3} {id2:<3}] [{x:<3} {y:<3} {z:<3}] {}", lane.color(state));
|
||||
}
|
||||
if !seeded_lanes.contains(&lane_id) && self.wave_state.get(&wave_id).is_none() {
|
||||
match (self.launch_bounds[1] != 1, self.launch_bounds[2] != 1) {
|
||||
(false, false) => vec_reg[0] = *x,
|
||||
_ => vec_reg[0] = (z << 20) | (y << 10) | x,
|
||||
}
|
||||
seeded_lanes.push(lane_id);
|
||||
print!("[{id0:<3} {id1:<3} {id2:<3}] [{x:<3} {y:<3} {z:<3}] {}", colored(&lane, state));
|
||||
}
|
||||
let mut thread = Thread {
|
||||
scalar_reg: &mut scalar_reg,
|
||||
|
||||
Reference in New Issue
Block a user