mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -1,5 +1,5 @@
|
||||
use half::f16;
|
||||
use num_traits::{float::FloatCore, PrimInt, Unsigned};
|
||||
use num_traits::{float::FloatCore, PrimInt, Unsigned, clamp};
|
||||
|
||||
pub fn bits<T>(word: T, hi: usize, lo: usize) -> T where T: PrimInt + Unsigned {
|
||||
assert!(hi >= lo);
|
||||
@@ -48,6 +48,7 @@ impl IEEEClass<u64> for f64 {
|
||||
pub trait VOPModifier<T> {
|
||||
fn negate(&self, pos: usize, modifier: usize) -> T;
|
||||
fn absolute(&self, pos: usize, modifier: usize) -> T;
|
||||
fn clmp(&self, cm: bool) -> T;
|
||||
}
|
||||
impl<T> VOPModifier<T> for T
|
||||
where
|
||||
@@ -65,6 +66,11 @@ where
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
fn clmp(&self, cm:bool) -> T {
|
||||
if !cm { return *self }
|
||||
let r = clamp(*self, T::zero(), T::one());
|
||||
if r == T::zero() { T::zero() } else { r }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_mantissa(x: f64) -> f64 {
|
||||
|
||||
@@ -1024,7 +1024,7 @@ impl<'a> Thread<'a> {
|
||||
let vdst = (instr & 0xff) as usize;
|
||||
let abs = ((instr >> 8) & 0x7) as usize;
|
||||
let opsel = ((instr >> 11) & 0xf) as usize;
|
||||
let cm = (instr >> 15) & 0x1;
|
||||
let cm = ((instr >> 15) & 0x1) != 0;
|
||||
|
||||
let s = |n: usize| ((instr >> n) & 0x1ff) as usize;
|
||||
let src = (s(32), s(41), s(50));
|
||||
@@ -1032,7 +1032,9 @@ impl<'a> Thread<'a> {
|
||||
let omod = (instr >> 59) & 0x3;
|
||||
let neg = ((instr >> 61) & 0x7) as usize;
|
||||
assert_eq!(omod, 0);
|
||||
assert_eq!(cm, 0);
|
||||
if op != 272 && cm {
|
||||
return todo_instr!(op); // TODO: add VOP3 clamp for all ops
|
||||
}
|
||||
assert_eq!(opsel, 0);
|
||||
|
||||
match op {
|
||||
@@ -1266,7 +1268,7 @@ impl<'a> Thread<'a> {
|
||||
}
|
||||
|
||||
let ret = match op {
|
||||
257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 426 | 430 | 531 | 537 | 540 | 551 | 567 | 796 => {
|
||||
257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 426 | 430 | 531 | 537 | 540 | 543 | 551 | 567 | 606 | 796 => {
|
||||
let s0 = f32::from_bits(s0).negate(0, neg).absolute(0, abs);
|
||||
let s1 = f32::from_bits(s1).negate(1, neg).absolute(1, abs);
|
||||
let s2 = f32::from_bits(s2).negate(2, neg).absolute(2, abs);
|
||||
@@ -1275,12 +1277,26 @@ impl<'a> Thread<'a> {
|
||||
260 => s0 - s1,
|
||||
261 => s1 - s0,
|
||||
264 => s0 * s1,
|
||||
272 => f32::max(s0, s1),
|
||||
272 => f32::max(s0, s1).clmp(cm),
|
||||
299 => f32::mul_add(s0, s1, f32::from_bits(self.vec_reg[vdst])),
|
||||
426 => s0.recip(),
|
||||
430 => 1.0 / f32::sqrt(s0),
|
||||
531 => f32::mul_add(s0, s1, s2),
|
||||
537 => f32::min(f32::min(s0, s1), s2),
|
||||
543 => {
|
||||
if s0.is_nan() || s1.is_nan() || s2.is_nan() {
|
||||
f32::min(f32::min(s0, s1), s2)
|
||||
} else {
|
||||
let max = f32::max(f32::max(s0, s1), s2);
|
||||
if max == s0 {
|
||||
f32::max(s1, s2)
|
||||
} else if max == s1 {
|
||||
f32::max(s0, s2)
|
||||
} else {
|
||||
f32::max(s0, s1)
|
||||
}
|
||||
}
|
||||
},
|
||||
540 => f32::max(f32::max(s0, s1), s2),
|
||||
551 => s2 / s1,
|
||||
567 => {
|
||||
@@ -1290,6 +1306,7 @@ impl<'a> Thread<'a> {
|
||||
false => ret,
|
||||
}
|
||||
}
|
||||
606 => f32::min(f32::max(s0, s1), s2),
|
||||
796 => s0 * 2f32.powi(s1.to_bits() as i32),
|
||||
// cnd_mask isn't a float only ALU but supports neg
|
||||
257 => {
|
||||
|
||||
Reference in New Issue
Block a user