From 43e60914f3bf12f8715ddcab5395ec0ded26704f Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Thu, 27 Feb 2025 19:36:55 +0300
Subject: [PATCH] init torch hooking (#9284)

* smth

* mv

* prof wk

* revert and move

* fix

* nvprof

* fix and no print much
---
 extra/torch_hook/hook_cuda.py  | 214 +++++++++++++++++++++++++++++++++
 extra/torch_hook/hook_torch.py | 134 +++++++++++++++++++++
 2 files changed, 348 insertions(+)
 create mode 100644 extra/torch_hook/hook_cuda.py
 create mode 100644 extra/torch_hook/hook_torch.py

diff --git a/extra/torch_hook/hook_cuda.py b/extra/torch_hook/hook_cuda.py
new file mode 100644
index 0000000000..2e5ad907a0
--- /dev/null
+++ b/extra/torch_hook/hook_cuda.py
@@ -0,0 +1,214 @@
+import ctypes, struct, platform, pathlib, os, binascii, itertools
+from hexdump import hexdump
+from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
+from tinygrad.runtime.autogen import libc, cuda
+from tinygrad.device import CPUProgram, Device
+from tinygrad.runtime.support.elf import elf_loader
+from tinygrad.runtime.ops_cuda import cu_time_execution
+
+print(f"hooking CUDA runtime, running with {Device.DEFAULT}")
+
+# TODO: regen and make cuda 12 default?
+cuda.cuFuncGetParamInfo = cuda._libraries['libcuda.so'].cuFuncGetParamInfo
+cuda.cuFuncGetParamInfo.restype = cuda.CUresult
+cuda.cuFuncGetParamInfo.argtypes = [cuda.CUfunction, cuda.size_t, ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_uint64)]
+
+ignore_dispatch = [False] # default value is False
+def push_ignore_dispatch(val):
+  global ignore_dispatch
+  ignore_dispatch.append(val)
+
+def pop_ignore_dispatch():
+  global ignore_dispatch
+  ignore_dispatch.pop()
+
+hooked = {}
+def _hook(fxn_address_value, tramp):
+  page_address = (fxn_address_value//0x1000)*0x1000
+  ret = libc.mprotect(page_address, 0x2000, 7)
+  assert ret == 0
+  libc.memcpy(fxn_address_value, tramp, len(tramp))
+  ret = libc.mprotect(page_address, 0x2000, 5)
+  assert ret == 0
+  CPUProgram.rt_lib["__clear_cache"](fxn_address_value, fxn_address_value + len(tramp))
+
+def install_hook(c_function, python_function):
+  python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
+  # AARCH64 trampoline to the python hook
+  if (processor:=platform.processor()) == "aarch64":
+    # 0x0000000000000000: 70 00 00 10   adr x16, #0xc
+    # 0x0000000000000004: 10 02 40 F9   ldr x16, [x16]
+    # 0x0000000000000008: 00 02 1F D6   br x16
+    tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
+    tramp += struct.pack("Q", python_function_addr)
+  elif processor == "x86_64":
+    # 0x0000000000000000: 49 BB aa aa aa aa aa aa aa aa   movabs r11, 0xaaaaaaaaaaaaaaaa
+    # 0x000000000000000a: 41 FF E3   jmp r11
+    tramp = b"\x49\xBB" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE3"
+  else:
+    raise Exception(f"processor {processor} not supported")
+  tramp = ctypes.create_string_buffer(tramp)
+
+  # get real function address
+  fxn_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
+  fxn_address_value = fxn_address.contents.value
+  #print(f"** hooking function at 0x{fxn_address_value}")
+
+  orig_save = (ctypes.c_char*len(tramp))()
+  libc.memcpy(orig_save, fxn_address_value, len(tramp))
+  _hook(fxn_address_value, tramp)
+
+  def original(*args):
+    _hook(fxn_address_value, orig_save)
+    ret = c_function(*args)
+    _hook(fxn_address_value, tramp)
+    return ret
+  return original
+
+allocated_memory_enum = 0
+allocated_memory = {}
+function_names = {}
+tiny_devs = {}
+
+seen_modules = set()
+
+global_events = []
+class HookEvent: pass
+class HookMemAllocEvent(HookEvent):
+  def __init__(self, cuda_address, bytesize, enum): self.cuda_address, self.bytesize, self.enum = cuda_address, bytesize, enum
+  def __repr__(self): return f"tensor alloc: {self.enum}: {self.cuda_address:#x} - {self.bytesize:#x} bytes"
+class HookConstParamEvent(HookEvent):
+  def __init__(self, value): self.value = value
+  def __repr__(self): return f"const({self.value:#x})"
+class HookTensorParamEvent(HookEvent):
+  def __init__(self, cuda_address, offset, enum): self.cuda_address, self.offset, self.enum = cuda_address, offset, enum
+  def __repr__(self): return f"tensor{self.enum}({self.cuda_address:#x}, {self.offset=:#x})"
+class HookKernelCallEvent(HookEvent):
+  def __init__(self, grid, block, tm, ptm, name, params): self.grid, self.block, self.tm, self.ptm, self.name, self.params = grid, block, tm, ptm, name, params
+  def __repr__(self): return f"kernel call <<{self.grid}>> <<{self.block}>> {self.ptm}\n | {self.params}\n | {self.name}"
+
+def collect_events(clear=False):
+  global global_events
+  x = global_events
+  if clear: global_events = []
+  return x
+
+@ctypes.CFUNCTYPE(*([cuda.cuDeviceGet.restype] + cuda.cuDeviceGet.argtypes))
+def cuDeviceGet(device, ordinal):
+  tiny_devs[ordinal] = Device[f"{Device.DEFAULT}:{ordinal}"]
+  device.contents.value = ordinal
+  return cuda.CUDA_SUCCESS
+
+@ctypes.CFUNCTYPE(*([cuda.cuMemHostAlloc.restype] + cuda.cuMemHostAlloc.argtypes))
+def cuMemHostAlloc(pp, bytesize, flags):
+  print(f"cuMemHostAlloc {bytesize}")
+  return hooked["cuMemHostAlloc"](pp, bytesize, flags)
+
+@ctypes.CFUNCTYPE(*([cuda.cuModuleLoadData.restype] + cuda.cuModuleLoadData.argtypes))
+def cuModuleLoadData(module, image):
+  ret = hooked["cuModuleLoadData"](module, image)
+  module_address = ctypes.addressof(module.contents.contents)
+  seen_modules.add(module_address)
+  return ret
+
+@ctypes.CFUNCTYPE(*([cuda.cuModuleGetFunction.restype] + cuda.cuModuleGetFunction.argtypes))
+def cuModuleGetFunction(hfunc, hmod, name):
+  ret = hooked["cuModuleGetFunction"](hfunc, hmod, name)
+  python_name = ctypes.string_at(name).decode()
+
+  # pip install git+https://github.com/wbenny/pydemangler.git
+  import pydemangler
+  demangled_name = pydemangler.demangle(python_name)
+  if demangled_name is not None: python_name = demangled_name
+
+  # print(f"called cuModuleGetFunction 0x{ctypes.addressof(hmod.contents):X} {python_name}")
+  function_names[ctypes.addressof(hfunc.contents.contents)] = python_name
+  return ret
+
+@ctypes.CFUNCTYPE(*([cuda.cuMemAlloc_v2.restype] + cuda.cuMemAlloc_v2.argtypes))
+def cuMemAlloc_v2(dptr, bytesize):
+  global allocated_memory_enum
+
+  ret = hooked["cuMemAlloc_v2"](dptr, bytesize)
+  cuda_address = dptr.contents.value
+  allocated_memory[cuda_address] = (bytesize, allocated_memory_enum)
+
+  global_events.append(HookMemAllocEvent(cuda_address, bytesize, allocated_memory_enum))
+  if DEBUG >= 3: print(global_events[-1])
+
+  allocated_memory_enum += 1
+  return ret
+
+@ctypes.CFUNCTYPE(*([cuda.cuLaunchKernel.restype] + cuda.cuLaunchKernel.argtypes))
+def cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra):
+  global ignore_dispatch
+
+  name = function_names[ctypes.addressof(f.contents)]
+  if ignore_dispatch[-1]:
+    if DEBUG >= 4: print(f"ignoring dispatch {name}")
+    return 0
+
+  tm = cu_time_execution(lambda:
+    hooked["cuLaunchKernel"](f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra), True)
+
+  ptm = colored(time_to_str(tm, w=9), "yellow" if tm > 0.01 else "green")
+
+  params = []
+  while True:
+    ret = cuda.cuFuncGetParamInfo(f, len(params), ctypes.byref(paramOffset:=ctypes.c_size_t()), ctypes.byref(paramSize:=ctypes.c_size_t()))
+    if ret != 0: break
+    params.append((paramOffset.value, paramSize.value))
+
+  ev_params = []
+  if extra: params_ptr = to_mv(extra, 5*8).cast("Q")
+  else: params_ptr = to_mv(kernelParams, len(params)*8).cast("Q")
+
+  for i,(off,sz) in enumerate(params):
+    sz_to_let = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'}
+    if sz >= 8:
+      for j in range(sz//8):
+        if extra: value = to_mv(params_ptr[1] + off + j*8, 8).cast("Q")[0]
+        else: value = to_mv(params_ptr[i] + j*8, 8).cast('Q')[0]
+
+        has_in_allocated_mem, lcoff, alnum = False, 0, -1
+        for taddr, (tsz, talnum) in allocated_memory.items():
+          if taddr <= value < taddr + tsz:
+            has_in_allocated_mem = True
+            lcoff = value - taddr
+            alnum = talnum
+            break
+
+        if has_in_allocated_mem: ev_params.append(HookTensorParamEvent(value, lcoff, alnum))
+        else: ev_params.append(HookConstParamEvent(value))
+    else:
+      if extra: value = to_mv(params_ptr[1] + off, sz).cast(sz_to_let[sz])[0]
+      else: value = to_mv(params_ptr[i], sz).cast(sz_to_let[sz])[0]
+      ev_params.append(HookConstParamEvent(value))
+
+  global_events.append(HookKernelCallEvent((gridDimX, gridDimY, gridDimZ), (blockDimX, blockDimY, blockDimZ), tm, ptm, name, ev_params))
+  if DEBUG >= 3: print(global_events[-1])
+
+  return 0
+
+def create_hook(func_name, restype, argtypes):
+  def hook_template(*args):
+    # print(func_name, flush=True)
+    return hooked[func_name](*args)
+  return ctypes.CFUNCTYPE(restype, *argtypes)(hook_template)
+
+def install_hooks():
+  hooked['cuModuleGetFunction'] = install_hook(cuda.cuModuleGetFunction, cuModuleGetFunction)
+  hooked['cuLaunchKernel'] = install_hook(cuda.cuLaunchKernel, cuLaunchKernel)
+
+  # memory stuff
+  hooked['cuMemAlloc_v2'] = install_hook(cuda.cuMemAlloc_v2, cuMemAlloc_v2)
+  hooked['cuMemHostAlloc'] = install_hook(cuda.cuMemHostAlloc, cuMemHostAlloc)
+
+  # module loading + not used module loading
+  hooked['cuModuleLoadData'] = install_hook(cuda.cuModuleLoadData, cuModuleLoadData)
+
+NVPROFILER = os.environ.get("NV_COMPUTE_PROFILER_PERFWORKS_DIR", None) # set when running under NVIDIA Nsight Compute
+if NVPROFILER is None: install_hooks()
+else:
+  print("Detected Nsight Compute profiler, hooking is not available.")
+  cuda._libraries['libcuda.so'] = ctypes.CDLL(NVPROFILER + "/libcuda-injection.so")
diff --git a/extra/torch_hook/hook_torch.py b/extra/torch_hook/hook_torch.py
new file mode 100644
index 0000000000..e1ca596c4b
--- /dev/null
+++ b/extra/torch_hook/hook_torch.py
@@ -0,0 +1,134 @@
+import ctypes, struct, platform, pathlib, os, binascii, itertools
+from hexdump import hexdump
+from tinygrad.device import Device
+from tinygrad import Tensor
+from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
+
+import extra.torch_hook.hook_cuda as hook_cuda
+
+# settings to profile gemm in the __main__ example: TINY_MIRROR=1;CUDA=1;RUN_ONLY=9
+# nvprof sample command (this will sample all kernels):
+# ncu --export ~/nvprof_data --force-overwrite --rule AchievedOccupancy --rule Compute --rule LaunchConfiguration --rule Memory --rule PMSamplingData --rule SOLBottleneck --rule TheoreticalOccupancy --rule WorkloadImbalance python3 extra/torch_hook/hook_torch.py
+# or just run nsight compute from the host to the machine.
+
+TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
+RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
+REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
+FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
+
+print("importing torch...")
+import torch
+print("importing torch done:", torch.__version__, torch.__file__)
+
+if TINY_MIRROR:
+  print("importing tiny torch")
+  import extra.torch_backend.backend as tiny_torch
+  print("importing tiny torch done")
+
+torch.set_default_device("cuda")
+
+cuda_to_tiny_mappings = {}
+
+enumerator_aten_calls = itertools.count(0)
+from torch.utils._python_dispatch import TorchDispatchMode
+class DispatchLog(TorchDispatchMode):
+  def __torch_dispatch__(self, func, types, args, kwargs=None):
+    txt_args = []
+    should_call_tiny = (kwargs or {}).get('device') is not None and kwargs['device'].type == "cuda"
+
+    def can_print_arg(arg):
+      return arg is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
+
+    for i,arg in enumerate(args):
+      if torch.is_tensor(arg):
+        if arg.device.type == "cuda": should_call_tiny = True
+        txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
+      elif can_print_arg(arg): txt_args.append(f'{arg}')
+      else: txt_args.append(f"{type(arg)}")
+    for k,v in (kwargs or {}).items():
+      if torch.is_tensor(v):
+        if v.device.type == "cuda": should_call_tiny = True
+        txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
+      elif can_print_arg(v): txt_args.append(f'{k}:{v}')
+      else: txt_args.append(f"{type(v)}")
+
+    # magenta-colored kernels are mirrored to the tiny backend.
+ aten_id = next(enumerator_aten_calls) + should_call_tiny = TINY_MIRROR and should_call_tiny + print(colored(f"#{aten_id} {func}", "magenta" if should_call_tiny else "cyan") + "("+", ".join(txt_args)+")", flush=True) + + # ignore dispatches if needed + hook_cuda.push_ignore_dispatch(RUN_ONLY >= 0 and RUN_ONLY != aten_id) + orig_x = func(*args, **(kwargs or {})) + + def print_events(evs, name, out_addr): + for ev in evs: + if isinstance(ev, hook_cuda.HookKernelCallEvent): + txt_params = [] + for param in ev.params: + if isinstance(param, hook_cuda.HookTensorParamEvent): + is_out = param.cuda_address == out_addr + txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"] + + just_kern_name = ev.name + if not FULL_KERN_NAME: + just_kern_name = ev.name.replace("(anonymous namespace)", "").replace("void ", "").split("<")[0].split("(")[0].split("::")[-1] + print(f"\t {name} kernel {just_kern_name} {ev.grid} {ev.block} {ev.ptm}\n\t\t({', '.join(txt_params)})") + else: print("\t", name, ev) + + if REALIZE: + torch.cuda.synchronize() + cuda_events = hook_cuda.collect_events(clear=True) + print_events(cuda_events, colored("cuda", "cyan"), orig_x.data_ptr() if torch.is_tensor(orig_x) else 0x0) + + if should_call_tiny: + # replace with tiny tensor + tiny_args, tiny_kwargs = [], {} + for arg in args: + if torch.is_tensor(arg): tiny_args.append(cuda_to_tiny_mappings[arg]) + else: tiny_args.append(arg) + + for k,v in (kwargs or {}).items(): + if torch.is_tensor(v): tiny_kwargs[k] = cuda_to_tiny_mappings[v] + else: tiny_kwargs[k] = v + if 'device' in tiny_kwargs and kwargs['device'].type == "cuda": + tiny_kwargs['device'] = torch.device("tiny") + + tiny_x = func(*tiny_args, **tiny_kwargs) + + # TODO: this is a hack, any way to do this better? + if REALIZE: + tiny_x.cpu() + tiny_events = hook_cuda.collect_events(clear=True) + print_events(tiny_events, colored("tiny", "magenta"), 0x0) + + cuda_to_tiny_mappings[orig_x] = tiny_x + + hook_cuda.pop_ignore_dispatch() + return orig_x +DispatchLog().__enter__() + +if __name__ == "__main__": + if getenv("RESNET"): + import torchvision.models as models + model = models.resnet18(pretrained=True) + model = model.cuda() + model.eval() + + if getenv("COMPILE"): model = torch.compile(model) + + X = torch.rand(getenv("BS", 1), 3, 288, 288, device='cuda') + model(X) + + print("\n\n\n****** second run ******\n") + model(X) + else: + a = torch.randn(64, 64) + b = torch.randn(64, 64) + a += 1 + b += 2 + a = a.exp2() + b = b.exp2() + a += b + c = a @ b + print("tensor math done", c.cpu().numpy())
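For a quick sanity check of the CUDA hook on its own (without the dispatch logger above), a minimal sketch along these lines should work. It assumes the extra/torch_hook/hook_cuda.py from this patch is importable from the repo root, and that importing the module installs the hooks at import time (which it does unless the Nsight injection library is detected):

  # minimal sketch: drain the recorded HookMemAllocEvent / HookKernelCallEvent objects for one matmul
  import extra.torch_hook.hook_cuda as hook_cuda  # importing installs the CUDA hooks
  import torch
  a = torch.randn(64, 64, device="cuda")
  b = torch.randn(64, 64, device="cuda")
  (a @ b).cpu()  # .cpu() forces the kernels to actually run
  for ev in hook_cuda.collect_events(clear=True): print(ev)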