add tvm example, formatting (#1813)

* add tvm example

* no realize
This commit is contained in:
George Hotz
2023-09-07 11:50:41 -07:00
committed by GitHub
parent 5b15a972b5
commit 4613c9e77c
7 changed files with 54 additions and 8 deletions

46
extra/gemm/tvm_gemm.py Normal file
View File

@@ -0,0 +1,46 @@
# https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te
import tvm
from tvm import te
#print(tvm.target.Target.list_kinds())
M, N, K = 1024, 1024, 1024
# c, opencl
target = tvm.target.Target(target="c")
# TVM Matrix Multiplication using TE
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
# Default schedule
s = te.create_schedule(C.op)
#print(tvm.lower(s, [A, B, C], simple_mode=True))
# Output C code
func = tvm.build(s, [A, B, C], target=target, name="mmult")
print(func.get_source())
# tinygrad version
import os
from tinygrad.tensor import Tensor
# disable optimizations
os.environ["NOOPT"] = "1"
# define the compute
A = Tensor.rand(M, K, device="clang")
B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
# capture the kernel. TODO: https://github.com/tinygrad/tinygrad/issues/1812
from tinygrad.jit import CacheCollector
CacheCollector.start()
C.realize()
result = CacheCollector.finish()
print(result[0][0].prg)

View File

@@ -23,10 +23,10 @@ class LinearizerOptions(NamedTuple):
local_max: Optional[List[int]] = None
class Kernel:
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
self.opts = opts
self.opts = opts if opts else LinearizerOptions()
# get the output buffers
self.bufs = [output_buffer] + dedup(ast.buffers)

View File

@@ -79,7 +79,7 @@ class CStyleLanguage(NamedTuple):
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix} void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
@@ -101,7 +101,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kernel,prekernel = [],[]
#pend_close = None
bufs = []
depth = 0
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)

View File

@@ -12,7 +12,7 @@ ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
args = {
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
}[platform.system()]

View File

@@ -86,7 +86,7 @@ class CUDAProgram:
return start.time_till(end)*1e-3
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__global__", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
half_prekernel = """

View File

@@ -104,7 +104,7 @@ class CLProgram:
return None
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))

View File

@@ -81,7 +81,7 @@ class MetalProgram:
METAL.mtl_buffers_in_flight.append(command_buffer)
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel ", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))