mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
46
extra/gemm/tvm_gemm.py
Normal file
46
extra/gemm/tvm_gemm.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te
|
||||
import tvm
|
||||
from tvm import te
|
||||
#print(tvm.target.Target.list_kinds())
|
||||
|
||||
M, N, K = 1024, 1024, 1024
|
||||
|
||||
# c, opencl
|
||||
target = tvm.target.Target(target="c")
|
||||
|
||||
# TVM Matrix Multiplication using TE
|
||||
k = te.reduce_axis((0, K), "k")
|
||||
A = te.placeholder((M, K), name="A")
|
||||
B = te.placeholder((K, N), name="B")
|
||||
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
|
||||
|
||||
# Default schedule
|
||||
s = te.create_schedule(C.op)
|
||||
#print(tvm.lower(s, [A, B, C], simple_mode=True))
|
||||
|
||||
# Output C code
|
||||
func = tvm.build(s, [A, B, C], target=target, name="mmult")
|
||||
print(func.get_source())
|
||||
|
||||
# tinygrad version
|
||||
|
||||
import os
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
# disable optimizations
|
||||
os.environ["NOOPT"] = "1"
|
||||
|
||||
# define the compute
|
||||
A = Tensor.rand(M, K, device="clang")
|
||||
B = Tensor.rand(K, N, device="clang")
|
||||
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||
|
||||
# capture the kernel. TODO: https://github.com/tinygrad/tinygrad/issues/1812
|
||||
from tinygrad.jit import CacheCollector
|
||||
CacheCollector.start()
|
||||
C.realize()
|
||||
result = CacheCollector.finish()
|
||||
|
||||
print(result[0][0].prg)
|
||||
|
||||
|
||||
@@ -23,10 +23,10 @@ class LinearizerOptions(NamedTuple):
|
||||
local_max: Optional[List[int]] = None
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
|
||||
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
|
||||
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
|
||||
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
|
||||
self.opts = opts
|
||||
self.opts = opts if opts else LinearizerOptions()
|
||||
|
||||
# get the output buffers
|
||||
self.bufs = [output_buffer] + dedup(ast.buffers)
|
||||
|
||||
@@ -79,7 +79,7 @@ class CStyleLanguage(NamedTuple):
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
prg = ''.join([f"{self.kernel_prefix} void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
|
||||
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
|
||||
@@ -101,7 +101,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
kernel,prekernel = [],[]
|
||||
#pend_close = None
|
||||
bufs = []
|
||||
depth = 0
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
|
||||
@@ -12,7 +12,7 @@ ARM64 = getenv('ARM64', False)
|
||||
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
|
||||
|
||||
args = {
|
||||
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
|
||||
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
|
||||
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
|
||||
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
|
||||
}[platform.system()]
|
||||
|
||||
@@ -86,7 +86,7 @@ class CUDAProgram:
|
||||
return start.time_till(end)*1e-3
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__global__", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
half_prekernel = """
|
||||
|
||||
@@ -104,7 +104,7 @@ class CLProgram:
|
||||
return None
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
|
||||
kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
|
||||
|
||||
@@ -81,7 +81,7 @@ class MetalProgram:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel ", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
|
||||
|
||||
Reference in New Issue
Block a user