diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py index f6f0289a36..560908d4af 100644 --- a/extra/assembly/assembly.py +++ b/extra/assembly/assembly.py @@ -97,7 +97,7 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]): lang.ins.clear() lang.tor.clear() lang.cnts.clear() - buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL} + buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL} global_size, local_size = [], [] skipload_branch = 0 lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype] diff --git a/extra/triton/triton.py b/extra/triton/triton.py index 9b48f3140c..21a66a2d96 100644 --- a/extra/triton/triton.py +++ b/extra/triton/triton.py @@ -98,8 +98,8 @@ def uops_to_triton(function_name:str, uops:List[UOp]): kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ") elif uop == UOps.DEFINE_GLOBAL: bufs.append(args) - signatures.append("*" if isinstance(args[1], PtrDType) else "" + signature_dtypes[args[1]]) - r[u] = args[0] + signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype]) + r[u] = args elif uop == UOps.SPECIAL: dims.append(args[1]) valid.append(f"{args[1]}<{get_max(args[2])}") @@ -111,7 +111,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]): elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1]) else: raise NotImplementedError(f"unimplemented: {uop}") - prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n" + prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n" for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]" prg += "\n".join(kernel) diff --git a/test/test_uops.py b/test/test_uops.py index ed07469d2b..c863565bd7 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -20,8 +20,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar def _test_single_value(vals, op, dtype): uops = [] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype)) - buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f'data{i+1}', dtype)) for i in range(len(vals))] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0') + buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), f'data{i+1}') for i in range(len(vals))] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals))) alu = uop(uops, UOps.ALU, dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) @@ -33,7 +33,7 @@ def _test_single_value(vals, op, dtype): def _test_single_value_const(vals, op, dtype): uops = [] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype)) + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0') loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals) alu = uop(uops, UOps.ALU, dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 75c10fe4ca..c900b6ab38 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -184,11 +184,11 @@ class Linearizer(Kernel): # add global buffers for i,buf in enumerate(self.bufs): if isinstance(buf, MemBuffer): - self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, dtype:=PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", dtype)) # noqa: E501 + self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}") # add var vals for var in vars_from_ast(self.ast): assert var.expr is not None - self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes.int32)) + self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr) # define local buffers for lb in self.local_alias.values(): self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size())) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 59e7a90bdb..deaff7b38c 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -199,8 +199,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu kk(lang.render_local(args[0], args[1])) r[u] = args[0] elif uop == UOps.DEFINE_GLOBAL: - bufs.append(args) - r[u] = args[0] + assert dtype is not None + bufs.append((args, dtype)) + r[u] = args elif uop == UOps.GEP: if cast(DType, vin[0].dtype).sz > 4: r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 490f83676c..271607d9e4 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -79,11 +79,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: module = ir.Module(name=__file__) # extract global buffers - buf_to_dtype = {u.arg[0]:u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL} + buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL} buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} # create llvm function - func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()] + func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None] func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501 for a in func.args: if a.type.is_pointer: a.add_attribute("noalias") @@ -125,7 +125,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block) if uop == UOps.DEFINE_GLOBAL: - lvars[u] = func.args[buf_index[args[0]]] + lvars[u] = func.args[buf_index[args]] if uop == UOps.DEFINE_ACC: lvars[u] = const(args, dtype) reduce_phis.append(u)