diff --git a/extra/remu/src/helpers.rs b/extra/remu/src/helpers.rs index 0534ea217e..09deca329f 100644 --- a/extra/remu/src/helpers.rs +++ b/extra/remu/src/helpers.rs @@ -55,10 +55,7 @@ where { fn negate(&self, pos: usize, modifier: usize) -> T { match (modifier >> pos) & 1 { - 1 => match self.is_zero() { - true => T::zero(), - false => -*self, - }, + 1 => -*self, _ => *self, } } @@ -122,7 +119,7 @@ mod tests { assert_eq!(0.3_f32.negate(2, 0b100), -0.3_f32); assert_eq!(0.3_f32.negate(0, 0b110), 0.3_f32); assert_eq!(0.3_f32.negate(1, 0b010), -0.3_f32); - assert_eq!(0.0_f32.negate(0, 0b001).to_bits(), 0); + assert_eq!(0.0_f32.negate(0, 0b001).to_bits(), (-0.0f32).to_bits()); assert_eq!((-0.0_f32).negate(0, 0b001).to_bits(), 0); } diff --git a/extra/remu/src/thread.rs b/extra/remu/src/thread.rs index e1a3e552b5..4cb84a561d 100644 --- a/extra/remu/src/thread.rs +++ b/extra/remu/src/thread.rs @@ -1246,7 +1246,7 @@ impl<'a> Thread<'a> { } let ret = match op { - 257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 531 | 537 | 540 | 551 | 567 | 796 => { + 257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 426 | 531 | 537 | 540 | 551 | 567 | 796 => { let s0 = f32::from_bits(s0).negate(0, neg).absolute(0, abs); let s1 = f32::from_bits(s1).negate(1, neg).absolute(1, abs); let s2 = f32::from_bits(s2).negate(2, neg).absolute(2, abs); @@ -1257,6 +1257,7 @@ impl<'a> Thread<'a> { 264 => s0 * s1, 272 => f32::max(s0, s1), 299 => f32::mul_add(s0, s1, f32::from_bits(self.vec_reg[vdst])), + 426 => s0.recip(), 531 => f32::mul_add(s0, s1, s2), 537 => f32::min(f32::min(s0, s1), s2), 540 => f32::max(f32::max(s0, s1), s2), @@ -3032,7 +3033,7 @@ mod test_vop3 { #[test] fn test_v_cndmask_b32_e64_neg() { - [[0.0f32, 0.0], [-0.0f32, 0.0], [1.0f32, -1.0], [-1.0f32, 1.0]].iter().for_each(|[input, ret]| { + [[0.0f32, -0.0], [-0.0f32, 0.0], [1.0f32, -1.0], [-1.0f32, 1.0]].iter().for_each(|[input, ret]| { let mut thread = _helper_test_thread(); thread.scalar_reg[0] = false as u32; thread.vec_reg[3] = input.to_bits(); diff --git a/extra/remu/test/hwtest.py b/extra/remu/test/hwtest.py index 9fe4196275..0cc2ba0fc2 100644 --- a/extra/remu/test/hwtest.py +++ b/extra/remu/test/hwtest.py @@ -1,6 +1,6 @@ import numpy as np import unittest -import subprocess, struct +import subprocess, struct, math from typing import cast from tinygrad.runtime.ops_amd import AMDProgram, AMDDevice from tinygrad import Tensor, dtypes, Device @@ -95,6 +95,8 @@ def get_output(s:str, n_threads:int=1): return test.numpy() def f16_to_bits(x:float) -> int: return struct.unpack(' float: return struct.unpack(' int: return struct.unpack('