diff --git a/extra/triton/triton.py b/extra/triton/triton.py
index d1ff8149d4..9b48f3140c 100644
--- a/extra/triton/triton.py
+++ b/extra/triton/triton.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Final, Callable, DefaultDict
 from collections import defaultdict
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
-from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
+from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
 from tinygrad.codegen.linearizer import UOp, UOps
 from triton.compiler import compile as triton_compile
 import linecache
@@ -9,7 +9,7 @@
 import math
 import re
 triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
-signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}
+signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}
 
 def next_power_of_2(x): return 1 << (x - 1).bit_length()
@@ -98,7 +98,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
       kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
     elif uop == UOps.DEFINE_GLOBAL:
       bufs.append(args)
-      signatures.append(signature_dtypes[args[1]])
+      signatures.append(("*" if isinstance(args[1], PtrDType) else "") + signature_dtypes[args[1]])
       r[u] = args[0]
     elif uop == UOps.SPECIAL:
      dims.append(args[1])
diff --git a/test/test_dtype.py b/test/test_dtype.py
index c9fe72da39..aaf364c21c 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -236,8 +236,7 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
     # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
 
-# TODO: better way to write a set of core dtypes?
-core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
+core_types = list(DTYPES_DICT.values())
 class TestTypePromotion(unittest.TestCase):
   @given(st.sampled_from(core_types))
   def test_self_promo_to_self(self, dtype):
diff --git a/test/test_uops.py b/test/test_uops.py
index 1e0d275ff1..ed07469d2b 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -46,18 +46,18 @@ class TestUOps(unittest.TestCase):
   def _equal(self, v1, v2):
     if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
 
-  def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
+  def _test_uop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0.0, 1.0]:
         self._equal(f([a], bop, dt), fxn(a))
 
-  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
+  def _test_bop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32), no_b_zero=False):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0.0, 1.0]:
         for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
           self._equal(f([a,b], bop, dt), fxn(a,b))
 
-  def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
+  def _test_top_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0, 1]:
         for b in [-3.0, 3.0]:
@@ -88,15 +88,15 @@ class TestFloatUOps(TestUOps):
 
 # TODO: fix this on all the backends
 @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
 class TestNonFloatUOps(TestUOps):
-  def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32)
-  def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
-  def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
-  def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
-  def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
-  def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
-  def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
+  def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, PtrDType(dtypes.int32))
+  def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), PtrDType(dtypes.int32))
+  def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), PtrDType(dtypes.int32))
+  def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), PtrDType(dtypes.int32))
+  def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), PtrDType(dtypes.int32), no_b_zero=True)
+  def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
+  def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), PtrDType(dtypes.int32))
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ ... @@
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
-                self.arg_int_prefix if dtype == dtypes._arg_int32 else
-                ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
+                ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
+                self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
     prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
     [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
     [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
@@ -351,7 +351,7 @@ class WGSLLanguage(CStyleLanguage):
     local_size = local_size[::-1] if local_size else [1]
     bind_it = iter(range(len(bufs)))
     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
-    prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) # noqa: E501
+    prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'} {name}: {f'array<{self.type_map[dtype]}>' if isinstance(dtype, PtrDType) else 'i32'};" for name,dtype in bufs]) # noqa: E501
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
     return prg
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 1a523db043..490f83676c 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,7 +1,7 @@
 from typing import Final, Dict, Callable, Any, List, Optional, Tuple
 from llvmlite import ir
 from tinygrad.codegen.linearizer import UOps, UOp
-from tinygrad.helpers import DType, dtypes
+from tinygrad.helpers import DType, PtrDType, dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
 
 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
 
 dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(),
                        dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64),
-                       dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
+                       dtypes.int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
                        dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
 
 def cast(bb, val, input_type, output_type, bitcast=False):
@@ -84,7 +84,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
 
   # create llvm function
   func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
-  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
+  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
 
   for a in func.args:
     if a.type.is_pointer: a.add_attribute("noalias")
@@ -99,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
 
   lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
   for bufname,dtype in buf_to_dtype.items():
-    if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
+    if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
 
   for u in uops:
     uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
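The common thread across these hunks: scalar integer kernel arguments are no longer tagged with the `dtypes._arg_int32` sentinel; each renderer instead asks `isinstance(dtype, PtrDType)` to tell pointer buffers apart from by-value ints. A rough standalone sketch of that dispatch follows (`render_arg` is a hypothetical helper, not part of this patch; the imports assume the tinygrad revision this diff targets, where `PtrDType` lives in `tinygrad.helpers`):

```python
from tinygrad.helpers import PtrDType, dtypes

def render_arg(name, dtype):
  # Check PtrDType first: a PtrDType wraps a base DType, so it can compare equal
  # to dtypes.int32 (which is why the llvmir hunk guards with `not isinstance(...)`).
  if isinstance(dtype, PtrDType): return f"{dtype.name}* {name}"   # pointer buffer
  if dtype == dtypes.int32: return f"int {name}"                   # scalar by-value arg
  raise NotImplementedError(dtype)

# Exact name strings depend on the tinygrad revision; roughly:
print(render_arg("data0", PtrDType(dtypes.float32)))  # e.g. "float* data0"
print(render_arg("idx", dtypes.int32))                # "int idx"
```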