pass _buf to program

This commit is contained in:
Szymon Ożóg
2023-08-16 06:44:42 +02:00
parent 207cd697bf
commit 034273726c

View File

@@ -38,7 +38,8 @@ class TritonProgram:
self.program = cuda.module_from_buffer(self.program.encode('utf-8')).get_function(self.program.split(".visible .entry ")[1].split("(")[0])
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
self.program(*[x for x in args], block = tuple(local_size), grid = tuple(global_size))
self.program(*[x._buf for x in args], block = tuple(local_size), grid = tuple(global_size))
def uops_to_triton(function_name:str, uops:List[UOp]):
kernel = []