transformer kvcache and mask have same dtype as input (#2771)

* transformer kvcache and mask have same dtype as input

* don't use `=0` in cstyle ternary where

* (bool)

* where float16 test
This commit is contained in:
chenyu
2023-12-14 22:41:51 -05:00
committed by GitHub
parent 2dd0dd4ae0
commit c0f76ed4ea
4 changed files with 13 additions and 13 deletions

View File

@@ -7,6 +7,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
from test.test_dtype import is_dtype_supported
def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
@@ -95,8 +96,10 @@ class TestNonFloatUOps(TestUOps):
def test_div_int32(self):
    # DIV on int32 buffers: reference semantics are C-style division,
    # i.e. truncation toward zero. b == 0 is excluded from the sampled inputs.
    truncating_div = lambda x, y: int(x / y)
    self._test_bop_fxn(BinaryOps.DIV, truncating_div, PtrDType(dtypes.int32), no_b_zero=True)
def test_mod_int32(self):
    # MOD on int32 buffers: reference semantics are the C-style remainder —
    # magnitude of |a| % |b|, with the sign taken from the dividend a.
    # b == 0 is excluded from the sampled inputs.
    def c_style_mod(a, b):
        magnitude = abs(int(a)) % abs(int(b))
        return -magnitude if a < 0 else magnitude
    self._test_bop_fxn(BinaryOps.MOD, c_style_mod, PtrDType(dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self):
    # CMPLT on int32 buffers: the comparison result is materialized
    # numerically as 0.0 / 1.0.
    less_than = lambda x, y: float(x < y)
    self._test_bop_fxn(BinaryOps.CMPLT, less_than, PtrDType(dtypes.int32))
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self):
    # MUL on bool buffers behaves as logical AND; both operands are
    # coerced to bool before combining.
    def logical_and(p, q):
        return bool(p) and bool(q)
    self._test_bop_fxn(BinaryOps.MUL, logical_and, PtrDType(dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
    # WHERE on float16 buffers: selects the second operand when the
    # condition is non-zero, the third otherwise.
    def select(cond, on_true, on_false):
        return on_true if cond != 0 else on_false
    self._test_top_fxn(TernaryOps.WHERE, select, PtrDType(dtypes.float16))
# Allow running this test module directly; verbosity=2 prints one line per test.
if __name__ == '__main__':
    unittest.main(verbosity=2)