From f17bc16f46ade23579535b678129c400d3fce01c Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 3 Nov 2023 12:31:29 -0700
Subject: [PATCH] simple runtime args (#2211)

* simple runtime args

* fix some tests

* fix abstractions and triton

* fix search
---
 docs/abstractions.py | 2 +-
 extra/thneed.py | 2 +-
 .../external/external_test_allocator_on_models.py | 2 +-
 test/external/external_test_speed_llama.py | 2 +-
 test/test_search.py | 3 +++
 test/test_uops.py | 4 +++-
 tinygrad/features/search.py | 15 ++++++++++-----
 tinygrad/ops.py | 7 +++++--
 tinygrad/renderer/triton.py | 2 +-
 tinygrad/runtime/ops_clang.py | 2 +-
 tinygrad/runtime/ops_cuda.py | 8 ++++----
 tinygrad/runtime/ops_gpu.py | 4 ++--
 tinygrad/runtime/ops_hip.py | 2 +-
 tinygrad/runtime/ops_llvm.py | 2 +-
 tinygrad/runtime/ops_metal.py | 2 +-
 tinygrad/runtime/ops_webgpu.py | 2 +-
 16 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/docs/abstractions.py b/docs/abstractions.py
index ca2371eec5..a533f567f4 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -230,7 +230,7 @@ output = RawMallocBuffer(1, dtypes.float32)
 
 # compile the program, run it, and 2+3 does indeed equal 5
 program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
-program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
+program(output, input_a, input_b)
 print(output.toCPU())
 assert output.toCPU()[0] == 5, "it's still 5"
 np.testing.assert_allclose(output.toCPU(), numpy_a+numpy_b)
diff --git a/extra/thneed.py b/extra/thneed.py
index b880b58fdc..a202e796cf 100644
--- a/extra/thneed.py
+++ b/extra/thneed.py
@@ -206,7 +206,7 @@ class Thneed:
             l.x = get_global_id(0);
             out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
           }
-        """), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
+        """), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
 
         # multiple of 32 isn't enough
         jdat['objects'].append({
diff --git a/test/external/external_test_allocator_on_models.py b/test/external/external_test_allocator_on_models.py
index 4ac99078e5..01cd022d9a 100644
--- a/test/external/external_test_allocator_on_models.py
+++ b/test/external/external_test_allocator_on_models.py
@@ -41,7 +41,7 @@ class FakeBuffer(RawBuffer):
   def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
 class FakeProgram:
   def __init__(self, name:str, prg:str): pass
-  def __call__(self, global_size, local_size, *bufs, wait=False): pass
+  def __call__(self, *bufs, global_size, local_size, wait=False): pass
 
 def helper_test_correctness(gen, train):
   from tinygrad.runtime.ops_gpu import CL, CLAllocator
diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py
index 1b1ea3e728..7635fc8cb0 100644
--- a/test/external/external_test_speed_llama.py
+++ b/test/external/external_test_speed_llama.py
@@ -12,7 +12,7 @@ from tinygrad.runtime.lib import RawBuffer
 
 class FakeProgram:
   def __init__(self, name:str, prg:str): pass
-  def __call__(self, global_size, local_size, *bufs, wait=False): pass
+  def __call__(self, *bufs, global_size, local_size, wait=False): pass
 
 class RawFakeBuffer(RawBuffer):
   @classmethod
diff --git a/test/test_search.py b/test/test_search.py
index 10b94b9551..400c04a575 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -14,3 +14,6 @@ class TestTimeLinearizer(unittest.TestCase):
     rawbufs = [Device[Device.DEFAULT].buffer(si.out.st.size(), si.out.dtype)] + [Device[Device.DEFAULT].buffer(x.st.size(), x.dtype) for x in si.inputs]
     tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/test_uops.py b/test/test_uops.py
index 764e2efe9d..a18523b45f 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -8,7 +8,9 @@ from tinygrad.codegen.linearizer import UOps, UOp
 
 def _uops_to_prg(uops):
   src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
-  return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
+  return ASTRunner("test", src,
+    [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
+    runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
 
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 72767a7fdb..724bf602a1 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -51,13 +51,18 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
       if clear_l2:
         # TODO: this is too small for many L2 caches
         with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-      tms.append(prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor)
+      lra = prg.runtime_args.copy()
+      if global_size: lra['global_size'] = global_size
+      if local_size: lra['local_size'] = local_size
+      tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
     prg.global_size = real_global_size
   except Exception:
-    #import traceback; traceback.print_exc()
-    #print("FAILED")
-    #print(lin.ast)
-    #print(lin.applied_opts)
+    if DEBUG >= 4:
+      import traceback
+      traceback.print_exc()
+      print("FAILED")
+      print(lin.ast)
+      print(lin.applied_opts)
     tms = [float('inf')]
   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
   return min(tms)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index bcf3dc7006..1e2575f640 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -234,7 +234,7 @@ class ASTRunner:
 
   def build(self, compiler, runtime, batch_exec=BasicBatchExecutor):
     self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
-    self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec
+    self.clprg, self.batch_exec = runtime(self.name, self.lib), batch_exec
     return self
 
   def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
@@ -254,7 +254,10 @@
       # TODO: this is copied from get_program
       local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
-    if et := self.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
+    lra = self.runtime_args.copy()
+    if global_size: lra['global_size'] = global_size
+    if local_size and 'local_size' not in lra: lra['local_size'] = local_size
+    if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
     op_estimate = sym_infer(self.op_estimate, var_vals)
     if DEBUG >= 2:
       print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
diff --git a/tinygrad/renderer/triton.py b/tinygrad/renderer/triton.py
index af6a67b218..944626b6b9 100644
--- a/tinygrad/renderer/triton.py
+++ b/tinygrad/renderer/triton.py
@@ -127,4 +127,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
   for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
 
-  return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
+  return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index bc83f7c1b9..59116a9333 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -27,7 +27,7 @@ class ClangProgram:
         pathlib.Path(cached_file_path.name).write_bytes(prg)
     self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
 
-  def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
+  def __call__(self, *args, wait=False):
     if wait: st = time.perf_counter()
     self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
     if wait: return time.perf_counter()-st
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 178d899cca..f8a5e38295 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -92,7 +92,7 @@ class CUDAGraph(GraphBatchExecutor):
 def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
 
 class CUDAProgram:
-  def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
+  def __init__(self, name:str, _prg:bytes):
     prg = _prg.decode('utf-8')
     if DEBUG >= 5: print(pretty_ptx(prg))
     if DEBUG >= 6:
@@ -103,13 +103,13 @@ class CUDAProgram:
         print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
       except Exception as e: print("failed to generate SASS", str(e))
     # TODO: name is wrong, so we get it from the ptx using hacks
-    self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
+    self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
 
-  def __call__(self, global_size, local_size, *args, wait=False):
+  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], shared:int=0, wait=False):
     if wait:
       start, end = cuda.Event(), cuda.Event()
       start.record()
-    self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size if self.local_size_override is None else self.local_size_override), grid=tuple(global_size), shared=self.shared)
+    self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=shared)
     if wait:
       end.record()
       end.synchronize()
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 6ad6066d27..b8d252838c 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -4,7 +4,7 @@ os.environ['PYOPENCL_NO_CACHE'] = '1'
 import pathlib
 import numpy as np
 import pyopencl as cl # type: ignore
-from typing import Optional, List
+from typing import Optional, List, Tuple
 from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
 from tinygrad.ops import Compiled
 from tinygrad.renderer.opencl import OpenCLRenderer
@@ -90,7 +90,7 @@ class CLProgram:
   @staticmethod
   def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
 
-  def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
+  def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
     if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
     cl_bufs, wait_for = [], []
     for x in bufs:
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index e99f1ea43c..208b61f962 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -97,7 +97,7 @@ class HIPProgram:
       self.modules.append(hip.hipModuleLoadData(prg))
       self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name))
 
-  def __call__(self, global_size, local_size, *args, wait=False):
+  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
     hip.hipSetDevice(args[0]._device)
     if wait:
       start, end = hip.hipEventCreate(), hip.hipEventCreate()
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index 16c2900611..9cfaa45a28 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -59,7 +59,7 @@ class LLVMProgram:
     LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
     self.fxn = LLVM.engine.get_function_address(name)
 
-  def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
+  def __call__(self, *bufs, wait=False):
     cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
     if wait: st = time.perf_counter()
     cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 179e64c22a..ca045e5686 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -83,7 +83,7 @@ class MetalProgram:
       os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
     self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
 
-  def __call__(self, global_size, local_size, *bufs, wait=False):
+  def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
     assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
     command_buffer = METAL.mtl_queue.commandBuffer()
     encoder = command_buffer.computeCommandEncoder()
diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py
index 2fa9bbdd72..e189301ee9 100644
--- a/tinygrad/runtime/ops_webgpu.py
+++ b/tinygrad/runtime/ops_webgpu.py
@@ -13,7 +13,7 @@ wgpu_device = get_default_device()
 
 class WebGPUProgram:
   def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
-  def __call__(self, global_size, local_size, *bufs, wait=False):
+  def __call__(self, *bufs, global_size, local_size, wait=False):
     assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
     binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
     bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
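
The patch boils down to one new calling convention for every backend Program: buffers stay positional, while global_size/local_size (and backend-specific runtime args) become keyword arguments, so CPU-style backends like clang and LLVM can simply omit them. A minimal sketch of that convention, using a hypothetical ToyProgram in the spirit of the FakeProgram test doubles above (illustrative only, not code from this diff):

# Hypothetical ToyProgram: mirrors the patched __call__ signatures above.
class ToyProgram:
  def __init__(self, name:str, prg:str): self.name, self.prg = name, prg
  def __call__(self, *bufs, global_size=None, local_size=None, wait=False):
    # a real backend would enqueue the kernel here and, with wait=True, return a runtime
    print(f"launch {self.name}: {len(bufs)} bufs, global={global_size}, local={local_size}")
    return 0.0 if wait else None

bufs = ["out", "a", "b"]                 # stand-ins for RawBuffers
prg = ToyProgram("add", "<kernel src>")

# old convention: prg(global_size, local_size, *bufs, wait=True)
# new convention: buffers first, launch dims by keyword (omitted entirely on clang/LLVM)
prg(*bufs, global_size=[1,1,1], local_size=[1,1,1], wait=True)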
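
On the renderer side, per-program extras such as Triton's local_size (previously local_size_override) and CUDA's shared are no longer baked in at construction; ASTRunner keeps them in runtime_args and splats them into each call, adding launch dims only when the kernel has them. A sketch of that flow under the same assumptions (invented helper names, not the patch's code):

# The renderer returns source plus a runtime_args dict (cf. uops_to_triton above).
def render_kernel():
  return "<generated src>", {"shared": 1024, "local_size": [32, 1, 1]}

src, runtime_args = render_kernel()

def launch(prg, rawbufs, global_size=None, local_size=None):
  lra = runtime_args.copy()
  if global_size: lra["global_size"] = global_size
  # a renderer-chosen local_size wins over the optimizer's suggestion, as in ops.py above
  if local_size and "local_size" not in lra: lra["local_size"] = local_size
  return prg(*rawbufs, **lra, wait=True)

# with a stand-in program that just echoes its keyword arguments:
launch(lambda *bufs, **kwargs: kwargs, ["out", "a"], global_size=[64,1,1], local_size=[8,1,1])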