diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py
new file mode 100644
index 0000000000..92c5b16b3d
--- /dev/null
+++ b/extra/gemm/tvm_gemm.py
@@ -0,0 +1,46 @@
+# https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te
+import tvm
+from tvm import te
+#print(tvm.target.Target.list_kinds())
+
+M, N, K = 1024, 1024, 1024
+
+# c, opencl
+target = tvm.target.Target(target="c")
+
+# TVM Matrix Multiplication using TE
+k = te.reduce_axis((0, K), "k")
+A = te.placeholder((M, K), name="A")
+B = te.placeholder((K, N), name="B")
+C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+
+# Default schedule
+s = te.create_schedule(C.op)
+#print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+# Output C code
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
+print(func.get_source())
+
+# tinygrad version
+
+import os
+from tinygrad.tensor import Tensor
+
+# disable optimizations
+os.environ["NOOPT"] = "1"
+
+# define the compute
+A = Tensor.rand(M, K, device="clang")
+B = Tensor.rand(K, N, device="clang")
+C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
+
+# capture the kernel. TODO: https://github.com/tinygrad/tinygrad/issues/1812
+from tinygrad.jit import CacheCollector
+CacheCollector.start()
+C.realize()
+result = CacheCollector.finish()
+
+print(result[0][0].prg)
+
+
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 216630277b..49fb5eb1a4 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -23,10 +23,10 @@ class LinearizerOptions(NamedTuple):
   local_max: Optional[List[int]] = None
 
 class Kernel:
-  def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
+  def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
     # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
     self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
-    self.opts = opts
+    self.opts = opts if opts else LinearizerOptions()
 
     # get the output buffers
     self.bufs = [output_buffer] + dedup(ast.buffers)
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 19a481127e..1645d7bc7e 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -79,7 +79,7 @@ class CStyleLanguage(NamedTuple):
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 self.arg_int_prefix if dtype == dtypes._arg_int32 else
                 ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
-    prg = ''.join([f"{self.kernel_prefix} void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
+    prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
     [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
     [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
     if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
@@ -101,7 +101,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
   kernel,prekernel = [],[]
   #pend_close = None
   bufs = []
-  depth = 0
+  depth = 1
   def kk(s): kernel.append("  "*depth+s)
 
   c: DefaultDict[str, int] = defaultdict(int)
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index bb53ab7a94..2959f75634 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -12,7 +12,7 @@ ARM64 = getenv('ARM64', False)
 if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
 
 args = {
-  'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
+  'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
   'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
   'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
 }[platform.system()]
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 88eac03ccf..68d2a1d8ba 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -86,7 +86,7 @@ class CUDAProgram:
       return start.time_till(end)*1e-3
 
 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
-  kernel_prefix = "__global__", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
+  kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
   gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
   half_prekernel = """
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 656defc468..23310dd7e3 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -104,7 +104,7 @@ class CLProgram:
     return None
 
 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
-  kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
+  kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
   half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
   gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index fffae4ba34..efc9e56d9f 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -81,7 +81,7 @@ class MetalProgram:
     METAL.mtl_buffers_in_flight.append(command_buffer)
 
 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
-  kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
+  kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel ", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
   barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
   gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))