remove duplicated dtype in DEFINE_GLOBAL args (#2768)

now DEFINE_GLOBAL uop.arg[1] is always the same as uop.dtype, we can remove the one in arg and just use uop.dtype
This commit is contained in:
chenyu
2023-12-14 15:42:36 -05:00
committed by GitHub
parent 5235cdee3d
commit 57017c87e9
6 changed files with 15 additions and 14 deletions

View File

@@ -97,7 +97,7 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]

View File

@@ -98,8 +98,8 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
signatures.append("*" if isinstance(args[1], PtrDType) else "" + signature_dtypes[args[1]])
r[u] = args[0]
signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
r[u] = args
elif uop == UOps.SPECIAL:
dims.append(args[1])
valid.append(f"{args[1]}<{get_max(args[2])}")
@@ -111,7 +111,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
else: raise NotImplementedError(f"unimplemented: {uop}")
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
prg += "\n".join(kernel)

View File

@@ -20,8 +20,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar
def _test_single_value(vals, op, dtype):
uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f'data{i+1}', dtype)) for i in range(len(vals))]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), f'data{i+1}') for i in range(len(vals))]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals)))
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@@ -33,7 +33,7 @@ def _test_single_value(vals, op, dtype):
def _test_single_value_const(vals, op, dtype):
uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals)
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))

View File

@@ -184,11 +184,11 @@ class Linearizer(Kernel):
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, dtype:=PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", dtype)) # noqa: E501
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}")
# add var vals
for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes.int32))
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))

View File

@@ -199,8 +199,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
kk(lang.render_local(args[0], args[1]))
r[u] = args[0]
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
r[u] = args[0]
assert dtype is not None
bufs.append((args, dtype))
r[u] = args
elif uop == UOps.GEP:
if cast(DType, vin[0].dtype).sz > 4:
r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP

View File

@@ -79,11 +79,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
module = ir.Module(name=__file__)
# extract global buffers
buf_to_dtype = {u.arg[0]:u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
@@ -125,7 +125,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args[0]]]
lvars[u] = func.args[buf_index[args]]
if uop == UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)