mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
transformer kvcache and mask have same dtype as input (#2771)
* transformer kvcache and mask have same dtype as input * don't use `=0` in cstyle ternary where * (bool) * where float16 test
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.device import CompiledASTRunner, Compiled
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
|
||||
@@ -95,8 +96,10 @@ class TestNonFloatUOps(TestUOps):
|
||||
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), PtrDType(dtypes.int32))
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), PtrDType(dtypes.bool))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
|
||||
def test_where_float16(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, PtrDType(dtypes.float16))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user