diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py
index 4312997d57..e4a90bf443 100755
--- a/examples/llm.c/export.py
+++ b/examples/llm.c/export.py
@@ -48,7 +48,7 @@ if __name__ == "__main__":
   for ast in ast_dedup:
     k = Device["CLANG"].get_linearizer(*ast)
     k.linearize()
-    src = Device["CLANG"].compiler.render(to_function_name(k.name), k.uops)
+    src = Device["CLANG"].compiler.compiler_opts.renderer(to_function_name(k.name), k.uops)
     srcs[ast] = (k.name, src)
   print("functions:", len(srcs))
   used_buffers = dedup(flatten([si.bufs for si in sched]))
diff --git a/test/external/external_test_hip_compile.py b/test/external/external_test_hip_compile.py
index cb7a7a19ac..e4ed054a1e 100644
--- a/test/external/external_test_hip_compile.py
+++ b/test/external/external_test_hip_compile.py
@@ -32,7 +32,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
       compile_hip(code)
       return (time.perf_counter() - st) * 1000

-    tinygrad_tm = min([time_compile(Device[Device.DEFAULT].compiler.render(f"test{i}", lin.uops)) for i in range(10)])
+    tinygrad_tm = min([time_compile(Device[Device.DEFAULT].compiler.compiler_opts.renderer(f"test{i}", lin.uops)) for i in range(10)])
     ref_tm = min([time_compile(reference.format(name=f"test{i}")) for i in range(10)])
     print(f"tinygrad {tinygrad_tm:6.2f} ms")
     print(f"reference {ref_tm:6.2f} ms")
diff --git a/test/test_device_speed.py b/test/test_device_speed.py
index cbb64f7228..d77b1316dd 100644
--- a/test/test_device_speed.py
+++ b/test/test_device_speed.py
@@ -7,7 +7,7 @@ class TestDeviceSpeed(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
     cls.dev = Device[Device.DEFAULT]
-    cls.empty = Device[Device.DEFAULT].compiler.render("test", UOpGraph())
+    cls.empty = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", UOpGraph())

   def test_empty_compile(self):
     with Timing("compiler "):
diff --git a/test/test_uops.py b/test/test_uops.py
index b8c82740ed..55efe1ad61 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -12,7 +12,7 @@ from tinygrad.codegen.uops import exec_alu, UOpGraph
 from test.helpers import is_dtype_supported

 def _uops_to_prg(uops):
-  src = Device[Device.DEFAULT].compiler.render("test", uops)
+  src = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", uops)
   has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
   return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))

diff --git a/tinygrad/device.py b/tinygrad/device.py
index f1137a381d..cc355199ed 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import multiprocessing
 from dataclasses import dataclass, replace
-from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar
+from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable
 import importlib, inspect, functools, pathlib, os
 from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -54,6 +54,8 @@ class Runner:

 # **************** for Compiled Devices ****************

+def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
+
 @dataclass(frozen=True)
 class CompilerOptions:
   device: str = ""
@@ -67,11 +69,11 @@ class CompilerOptions:
   global_max: Optional[List[int]] = None
   local_max: Optional[List[int]] = None
   shared_max: int = 32768
+  renderer: Callable = fake_renderer

 class Compiler:
   compiler_opts: ClassVar[CompilerOptions]
   def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
-  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("need a render function")
   def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
   def compile_cached(self, src:str) -> bytes:
     if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
@@ -85,7 +87,8 @@ class Compiler:
     ops, mem = k.uops.flops_mem()
     run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    return Program(k.name, self.render(to_function_name(k.name), k.uops), override_device if override_device else self.compiler_opts.device,
+    return Program(k.name, self.compiler_opts.renderer(to_function_name(k.name), k.uops),
+                   override_device if override_device else self.compiler_opts.device,
                    k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))

 @dataclass(frozen=True)
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 71f13fc9a8..b25e61c5e5 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -1,5 +1,5 @@
 from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
-import math, functools
+import math
 from collections import defaultdict, Counter
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
@@ -179,7 +179,7 @@ class ClangLanguage(CStyleLanguage):
   buffer_suffix = " restrict"
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
   code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
-ClangRenderer = functools.partial(uops_to_cstyle, ClangLanguage())
+def ClangRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(ClangLanguage(), name, uops)

 class OpenCLLanguage(CStyleLanguage):
   kernel_prefix = "__kernel "
@@ -197,7 +197,7 @@ class OpenCLLanguage(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
-OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
+def OpenCLRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(OpenCLLanguage(), name, uops)

 class MetalLanguage(CStyleLanguage):
   kernel_prefix = "kernel "
@@ -227,7 +227,7 @@ class MetalLanguage(CStyleLanguage):
   b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
   return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
-MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
+def MetalRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(MetalLanguage(), name, uops)

 code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
                     UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
@@ -271,7 +271,7 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
 return c;}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

-CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
+def CUDARenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(CUDALanguage(), name, uops)

 code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
                     UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
@@ -358,4 +358,4 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

-HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
+def HIPRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(HIPLanguage(), name, uops)
diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index f533f72e0b..7e29853dc9 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -69,11 +69,10 @@ def create_sdma_packets():
 sdma_pkts = create_sdma_packets()

 class AMDCompiler(Compiler):
-  compiler_opts = CompilerOptions("AMD", has_tensor_cores=True, shared_max=65536)
+  compiler_opts = CompilerOptions("AMD", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
   def __init__(self, arch:str):
     self.arch = arch
     super().__init__(f"compile_hip_{self.arch}")
-  def render(self, name:str, uops) -> str: return HIPRenderer(name, uops)
   def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)

 SDMA_MAX_COPY_SIZE = 0x400000
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 1f6b3f88f6..a25f5ab7ee 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -5,8 +5,7 @@ from tinygrad.helpers import cpu_time_execution
 from tinygrad.renderer.cstyle import ClangRenderer

 class ClangCompiler(Compiler):
-  compiler_opts = CompilerOptions("CLANG", supports_float4=False, has_local=False)
-  def render(self, name:str, uops) -> str: return ClangRenderer(name, uops)
+  compiler_opts = CompilerOptions("CLANG", supports_float4=False, has_local=False, renderer=ClangRenderer)
   def compile(self, src:str) -> bytes:
     # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
     with tempfile.NamedTemporaryFile(delete=True) as output_file:
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 5a4bf6bc6d..c48f58b752 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -54,17 +54,18 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)

 class PTXCompiler(Compiler):
-  compiler_opts = CompilerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
+  compiler_opts = CompilerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
+                                  shared_max=49152, renderer=PTXRenderer)
   def __init__(self, arch:str):
     self.arch = arch
     self.version = "7.8" if arch >= "sm_89" else "7.5"
     PTXCompiler.compiler_opts = replace(PTXCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
     super().__init__(f"compile_ptx_{self.arch}")
-  def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch).replace("VERSION", self.version)
-  def compile(self, src:str) -> bytes: return src.encode()
+  def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()

 class CUDACompiler(Compiler):
-  compiler_opts = CompilerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
+  compiler_opts = CompilerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
+                                  shared_max=49152, renderer=CUDARenderer)
   def __init__(self, arch:str):
     self.arch = arch
     CUDACompiler.compiler_opts = replace(CUDACompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
@@ -72,7 +73,6 @@ class CUDACompiler(Compiler):
     self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
     if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
     super().__init__(f"compile_cuda_{self.arch}")
-  def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
   def compile(self, src:str) -> bytes:
     check(cuda.nvrtcCreateProgram(ctypes.byref(prog := cuda.nvrtcProgram()), src.encode(), "".encode(), 0, None, None))
     status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 8de6c98f58..534bfe9938 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -15,11 +15,10 @@ def check(status):
 def checked(ret, status): return (check(status.value), ret)[1]

 class CLCompiler(Compiler):
-  compiler_opts = CompilerOptions("GPU")
+  compiler_opts = CompilerOptions("GPU", renderer=OpenCLRenderer)
   def __init__(self, device:CLDevice, compile_key:str):
     self.device = device
     super().__init__(f"compile_cl_{compile_key}")
-  def render(self, name:str, uops) -> str: return OpenCLRenderer(name, uops)
   def compile(self, src:str) -> bytes:
     program = checked(cl.clCreateProgramWithSource(self.device.context, 1, to_char_p_p([src.encode()]), None, status := ctypes.c_int32()), status)
     build_status: int = cl.clBuildProgram(program, 1, self.device.device_id, None, cl.clBuildProgram.argtypes[4](), None)
diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py
index ed22db4129..1425fec92e 100644
--- a/tinygrad/runtime/ops_hsa.py
+++ b/tinygrad/runtime/ops_hsa.py
@@ -43,11 +43,10 @@ class HSAProfiler:
 Profiler = HSAProfiler()

 class HSACompiler(Compiler):
-  compiler_opts = CompilerOptions("HSA", has_tensor_cores=True, shared_max=65536)
+  compiler_opts = CompilerOptions("HSA", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
   def __init__(self, arch:str):
     self.arch = arch
     super().__init__(f"compile_hip_{self.arch}")
-  def render(self, name:str, uops) -> str: return HIPRenderer(name, uops)
   def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)

 class HSAProgram:
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index f52cf05287..df2d69b856 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -8,11 +8,10 @@ from tinygrad.renderer.llvmir import uops_to_llvm_ir
 import llvmlite.binding as llvm

 class LLVMCompiler(Compiler):
-  compiler_opts = CompilerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False)
+  compiler_opts = CompilerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False, renderer=uops_to_llvm_ir)
   def __init__(self, device:LLVMDevice):
     self.device = device
     super().__init__("compile_llvm")
-  def render(self, name:str, uops) -> str: return uops_to_llvm_ir(name, uops)
   def compile(self, src:str) -> bytes:
     mod = llvm.parse_assembly(src)
     mod.verify()
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index c35c40ccd0..13a588817e 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -13,11 +13,10 @@ def wait_check(cbuf: Any):
     raise RuntimeError(error)

 class MetalCompiler(Compiler):
-  compiler_opts = CompilerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64", shared_max=32768)
+  compiler_opts = CompilerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64", shared_max=32768, renderer=MetalRenderer)
   def __init__(self, device:Optional[MetalDevice]):
     self.device = device
     super().__init__("compile_metal")
-  def render(self, name:str, uops) -> str: return MetalRenderer(name, uops)
   def compile(self, src:str) -> bytes:
     if self.device is None:
       # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index a4259cbe55..ee041feb31 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -66,7 +66,7 @@ def nvdata64(data): return (data >> 32, data & 0xFFFFFFFF)
 def nvdata64_le(data): return (data & 0xFFFFFFFF, data >> 32)

 class NVCompiler(Compiler):
-  compiler_opts = CompilerOptions("NV", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
+  compiler_opts = CompilerOptions("NV", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152, renderer=CUDARenderer)
   def __init__(self, arch:str):
     self.arch = arch
     NVCompiler.compiler_opts = replace(NVCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
@@ -74,7 +74,6 @@ class NVCompiler(Compiler):
     self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
     if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
     super().__init__(f"compile_nv_{self.arch}")
-  def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
   def compile(self, src:str) -> bytes:
     cuda_check(cuda.nvrtcCreateProgram(ctypes.byref(prog := cuda.nvrtcProgram()), src.encode(), "".encode(), 0, None, None))
     status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 393cb8e528..e94fd22c5a 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -178,13 +178,15 @@ class PythonProgram:
       i += 1
     return time.perf_counter() - st

+def PythonRenderer(name:str, uops:UOpGraph) -> str:
+  lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
+  return base64.b64encode(pickle.dumps(lops)).decode()
+
 class PythonCompiler(Compiler):
-  compiler_opts = CompilerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
-    (CompilerOptions("HSA", has_tensor_cores=True) if getenv("EMULATE_HSA") else \
-    (CompilerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else CompilerOptions("PYTHON")))
-  def render(self, name:str, uops:UOpGraph) -> str:
-    lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
-    return base64.b64encode(pickle.dumps(lops)).decode()
+  compiler_opts = CompilerOptions("METAL", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_METAL") else \
+    (CompilerOptions("HSA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_HSA") else \
+    (CompilerOptions("CUDA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_CUDA") else \
+    CompilerOptions("PYTHON", renderer=PythonRenderer)))
   def compile(self, src:str) -> bytes: return base64.b64decode(src)

 class PythonAllocator(Allocator):
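Not part of the patch, for orientation only: a minimal sketch of the pattern this diff introduces, assuming the tinygrad.device API shown above. A backend now passes its renderer into CompilerOptions instead of overriding Compiler.render, and call sites reach it through compiler.compiler_opts.renderer, as the updated tests do. MyRenderer, MyCompiler, and the "MYDEV" device name are hypothetical.

from tinygrad.device import Compiler, CompilerOptions

def MyRenderer(name:str, uops) -> str:
  # hypothetical renderer: would turn a UOpGraph into backend source; stubbed out here
  return f"// kernel {name}"

class MyCompiler(Compiler):
  # the renderer travels with the options; no render() method on the Compiler subclass anymore
  compiler_opts = CompilerOptions("MYDEV", supports_float4=False, has_local=False, renderer=MyRenderer)
  def compile(self, src:str) -> bytes: return src.encode()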