diff --git a/test/test_uops.py b/test/test_uops.py index b939f80cf9..7cfffe85b4 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -22,8 +22,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar def _test_single_value(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0')) - buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}')) for i,dtype in enumerate(dts)] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True)) + buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) @@ -38,7 +38,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0')) + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True)) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 3930206ce4..3cc08a012b 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -195,7 +195,7 @@ class Linearizer(Kernel): if isinstance(buf, MemBuffer): self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), - (buf.idx, f"data{buf.idx}")) + (buf.idx, f"data{buf.idx}", i == 0)) # add var vals for i,var in enumerate(self.ast.vars()): assert var.expr is not None diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 6f286b6882..e4bd5b3d7f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -62,11 +62,12 @@ class CStyleLanguage(NamedTuple): out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val - def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], uops:List[UOp], prefix=None) -> str: - tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501 - buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else - ("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else - self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)] + def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], + local_size:List[int], uops:List[UOp], prefix=None) -> str: + tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 + buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else + ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else + self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) @@ -90,7 +91,7 @@ class CStyleLanguage(NamedTuple): def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str: local_size: List[int] = [] kernel = [] - bufs: List[Tuple[str, DType]] = [] + bufs: List[Tuple[str, Tuple[DType, bool]]] = [] #pend_close = None depth = 1 def kk(s): kernel.append(" "*depth+s) @@ -162,11 +163,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st kk(lang.render_local(args[0], dtype, args[1])) r[u] = args[0] elif uop is UOps.DEFINE_VAR: - bufs.append((args.expr, dtype)) + bufs.append((args.expr, (dtype,False))) r[u] = args.expr elif uop is UOps.DEFINE_GLOBAL: assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}" - bufs.append((args[1], dtype)) + bufs.append((args[1], (dtype,args[2]))) r[u] = args[1] elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")