mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Bitshift (#4728)
* WIP * Cleanup * Cleanup * Fix variable, refactor to use set * right shift should be signed/unsigned * Test for bitshifts * Allow a neg
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
||||
@@ -81,12 +81,12 @@ class TestUOps(unittest.TestCase):
|
||||
a = dtypes.as_const(a, dts[0])
|
||||
self._equal(f([a], op, dts), fxn(a))
|
||||
|
||||
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False):
|
||||
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 0.0, 1.0]:
|
||||
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
|
||||
a = dtypes.as_const(a, dts[0])
|
||||
b = dtypes.as_const(b, dts[1])
|
||||
b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
|
||||
self._equal(f([a,b], op, dts), fxn(a,b))
|
||||
|
||||
def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
|
||||
@@ -122,6 +122,10 @@ class TestNonFloatUOps(TestUOps):
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
|
||||
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
|
||||
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
|
||||
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self):
|
||||
|
||||
Reference in New Issue
Block a user