mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
init torch hooking (#9284)
* smth * mv * prof wk * revert and move * fix * nvprof * fix and no print much
This commit is contained in:
214
extra/torch_hook/hook_cuda.py
Normal file
214
extra/torch_hook/hook_cuda.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import ctypes, struct, platform, pathlib, os, binascii, itertools
|
||||
from hexdump import hexdump
|
||||
from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
|
||||
from tinygrad.runtime.autogen import libc, cuda
|
||||
from tinygrad.device import CPUProgram, Device
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.ops_cuda import cu_time_execution
|
||||
|
||||
print(f"hooking CUDA runtime, running with {Device.DEFAULT}")
|
||||
|
||||
# TODO: regen and make cuda 12 default?
|
||||
cuda.cuFuncGetParamInfo = cuda._libraries['libcuda.so'].cuFuncGetParamInfo
|
||||
cuda.cuFuncGetParamInfo.restype = cuda.CUresult
|
||||
cuda.cuFuncGetParamInfo.argtypes = [cuda.CUfunction, cuda.size_t, ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_uint64)]
|
||||
|
||||
ignore_dispatch = [False] # default valus is False
|
||||
def push_ignore_dispatch(val):
|
||||
global ignore_dispatch
|
||||
ignore_dispatch.append(val)
|
||||
|
||||
def pop_ignore_dispatch():
|
||||
global ignore_dispatch
|
||||
ignore_dispatch.pop()
|
||||
|
||||
hooked = {}
|
||||
def _hook(fxn_address_value, tramp):
|
||||
page_address = (fxn_address_value//0x1000)*0x1000
|
||||
ret = libc.mprotect(page_address, 0x2000, 7)
|
||||
assert ret == 0
|
||||
libc.memcpy(fxn_address_value, tramp, len(tramp))
|
||||
ret = libc.mprotect(page_address, 0x2000, 5)
|
||||
assert ret == 0
|
||||
CPUProgram.rt_lib["__clear_cache"](fxn_address_value, fxn_address_value + len(tramp))
|
||||
|
||||
def install_hook(c_function, python_function):
|
||||
python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
|
||||
# AARCH64 trampoline to ioctl
|
||||
if (processor:=platform.processor()) == "aarch64":
|
||||
# 0x0000000000000000: 70 00 00 10 adr x16, #0xc
|
||||
# 0x0000000000000004: 10 02 40 F9 ldr x16, [x16]
|
||||
# 0x0000000000000008: 00 02 1F D6 br x16
|
||||
tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
|
||||
tramp += struct.pack("Q", python_function_addr)
|
||||
elif processor == "x86_64":
|
||||
# 0x0000000000000000: 49 BB aa aa aa aa aa aa aa aa movabs r11, <address>
|
||||
# 0x000000000000000a: 41 FF E3 jmp r11
|
||||
tramp = b"\x49\xBB" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE3"
|
||||
else:
|
||||
raise Exception(f"processor {processor} not supported")
|
||||
tramp = ctypes.create_string_buffer(tramp)
|
||||
|
||||
# get real function address
|
||||
fxn_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
|
||||
fxn_address_value = fxn_address.contents.value
|
||||
#print(f"** hooking function at 0x{fxn_address_value}")
|
||||
|
||||
orig_save = (ctypes.c_char*len(tramp))()
|
||||
libc.memcpy(orig_save, fxn_address_value, len(tramp))
|
||||
_hook(fxn_address_value, tramp)
|
||||
|
||||
def original(*args):
|
||||
_hook(fxn_address_value, orig_save)
|
||||
ret = c_function(*args)
|
||||
_hook(fxn_address_value, tramp)
|
||||
return ret
|
||||
return original
|
||||
|
||||
allocated_memory_enum = 0
|
||||
allocated_memory = {}
|
||||
function_names = {}
|
||||
tiny_devs = {}
|
||||
|
||||
seen_modules = set()
|
||||
|
||||
global_events = []
|
||||
class HookEvent: pass
|
||||
class HookMemAllocEvent(HookEvent):
|
||||
def __init__(self, cuda_address, bytesize, enum): self.cuda_address, self.bytesize, self.enum = cuda_address, bytesize, enum
|
||||
def __repr__(self): return f"tensor alloc: {self.enum}: {self.cuda_address:#x} - {self.bytesize:#x} bytes"
|
||||
class HookConstParamEvent(HookEvent):
|
||||
def __init__(self, value): self.value = value
|
||||
def __repr__(self): return f"const({self.value:#x})"
|
||||
class HookTensorParamEvent(HookEvent):
|
||||
def __init__(self, cuda_address, offset, enum): self.cuda_address, self.offset, self.enum = cuda_address, offset, enum
|
||||
def __repr__(self): return f"tensor{self.enum}({self.cuda_address:#x}, {self.offset=:#x})"
|
||||
class HookKernelCallEvent(HookEvent):
|
||||
def __init__(self, grid, block, tm, ptm, name, params): self.grid, self.block, self.tm, self.ptm, self.name, self.params = grid, block, tm, ptm, name, params
|
||||
def __repr__(self): return f"kernel call <<{self.grid}>> <<{self.block}>> {self.ptm}\n | {self.params}\n | {self.name}"
|
||||
|
||||
def collect_events(clear=False):
|
||||
global global_events
|
||||
x = global_events
|
||||
if clear: global_events = []
|
||||
return x
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuDeviceGet.restype] + cuda.cuDeviceGet.argtypes))
|
||||
def cuDeviceGet(device, ordinal):
|
||||
tiny_devs[ordinal] = Device[f"{Device.DEFAULT}:{ordinal}"]
|
||||
device.contents.value = ordinal
|
||||
return cuda.CUDA_SUCCESS
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuMemHostAlloc.restype] + cuda.cuMemHostAlloc.argtypes))
|
||||
def cuMemHostAlloc(pp, bytesize, flags):
|
||||
print(f"cuMemHostAlloc {bytesize}")
|
||||
return hooked["cuMemHostAlloc"](pp, bytesize, flags)
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuModuleLoadData.restype] + cuda.cuModuleLoadData.argtypes))
|
||||
def cuModuleLoadData(module, image):
|
||||
ret = hooked["cuModuleLoadData"](module, image)
|
||||
module_address = ctypes.addressof(module.contents.contents)
|
||||
seen_modules.add(module_address)
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuModuleGetFunction.restype] + cuda.cuModuleGetFunction.argtypes))
|
||||
def cuModuleGetFunction(hfunc, hmod, name):
|
||||
ret = hooked["cuModuleGetFunction"](hfunc, hmod, name)
|
||||
python_name = ctypes.string_at(name).decode()
|
||||
|
||||
# pip install git+https://github.com/wbenny/pydemangler.git
|
||||
import pydemangler
|
||||
demangled_name = pydemangler.demangle(python_name)
|
||||
if demangled_name is not None: python_name = demangled_name
|
||||
|
||||
# print(f"called cuModuleGetFunction 0x{ctypes.addressof(hmod.contents):X} {python_name}")
|
||||
function_names[ctypes.addressof(hfunc.contents.contents)] = python_name
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuMemAlloc_v2.restype] + cuda.cuMemAlloc_v2.argtypes))
|
||||
def cuMemAlloc_v2(dptr, bytesize):
|
||||
global allocated_memory_enum, text_prefix
|
||||
|
||||
ret = hooked["cuMemAlloc_v2"](dptr, bytesize)
|
||||
cuda_address = dptr.contents.value
|
||||
allocated_memory[cuda_address] = (bytesize, allocated_memory_enum)
|
||||
|
||||
global_events.append(HookMemAllocEvent(cuda_address, bytesize, allocated_memory_enum))
|
||||
if DEBUG >= 3: print(global_events[-1])
|
||||
|
||||
allocated_memory_enum += 1
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuLaunchKernel.restype] + cuda.cuLaunchKernel.argtypes))
|
||||
def cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra):
|
||||
global ignore_dispatch
|
||||
|
||||
name = function_names[ctypes.addressof(f.contents)]
|
||||
if ignore_dispatch[-1]:
|
||||
if DEBUG >= 4: print(f"ignoring dispatch {name}")
|
||||
return 0
|
||||
|
||||
tm = cu_time_execution(lambda:
|
||||
hooked["cuLaunchKernel"](f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra), True)
|
||||
|
||||
ptm = colored(time_to_str(tm, w=9), "yellow" if tm > 0.01 else "green")
|
||||
|
||||
params = []
|
||||
while True:
|
||||
ret = cuda.cuFuncGetParamInfo(f, len(params), ctypes.byref(paramOffset:=ctypes.c_size_t()), ctypes.byref(paramSize:=ctypes.c_size_t()))
|
||||
if ret != 0: break
|
||||
params.append((paramOffset.value, paramSize.value))
|
||||
|
||||
ev_params = []
|
||||
if extra: params_ptr = to_mv(extra, 5*8).cast("Q")
|
||||
else: params_ptr = to_mv(kernelParams, len(params)*8).cast("Q")
|
||||
|
||||
for i,(off,sz) in enumerate(params):
|
||||
sz_to_let = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'}
|
||||
if sz >= 8:
|
||||
for j in range(sz//8):
|
||||
if extra: value = to_mv(params_ptr[1] + off, sz).cast("Q")[0]
|
||||
else: value = to_mv(params_ptr[i] + j*8, 8).cast('Q')[0]
|
||||
|
||||
has_in_allocated_mem, lcoff, alnum = False, 0, -1
|
||||
for taddr, (tsz, talnum) in allocated_memory.items():
|
||||
if taddr <= value < taddr + tsz:
|
||||
has_in_allocated_mem = True
|
||||
lcoff = value - taddr
|
||||
alnum = talnum
|
||||
break
|
||||
|
||||
if has_in_allocated_mem: ev_params.append(HookTensorParamEvent(value, lcoff, alnum))
|
||||
else: ev_params.append(HookConstParamEvent(value))
|
||||
else:
|
||||
if extra: value = to_mv(params_ptr[1] + off, sz).cast(sz_to_let[sz])[0]
|
||||
else: value = to_mv(params_ptr[i], sz).cast(sz_to_let[sz])[0]
|
||||
ev_params.append(HookConstParamEvent(value))
|
||||
|
||||
global_events.append(HookKernelCallEvent((gridDimX, gridDimY, gridDimZ), (blockDimX, blockDimY, blockDimZ), tm, ptm, name, ev_params))
|
||||
if DEBUG >= 3: print(global_events[-1])
|
||||
|
||||
return 0
|
||||
|
||||
def create_hook(func_name, restype, argtypes):
|
||||
def hook_template(*args):
|
||||
# print(func_name, flush=True)
|
||||
return hooked[func_name](*args)
|
||||
return ctypes.CFUNCTYPE(restype, *argtypes)(hook_template)
|
||||
|
||||
def install_hooks():
|
||||
hooked['cuModuleGetFunction'] = install_hook(cuda.cuModuleGetFunction, cuModuleGetFunction)
|
||||
hooked['cuLaunchKernel'] = install_hook(cuda.cuLaunchKernel, cuLaunchKernel)
|
||||
|
||||
# memory stuff
|
||||
hooked['cuMemAlloc_v2'] = install_hook(cuda.cuMemAlloc_v2, cuMemAlloc_v2)
|
||||
hooked['cuMemHostAlloc'] = install_hook(cuda.cuMemHostAlloc, cuMemHostAlloc)
|
||||
|
||||
# module loading + not used module loading
|
||||
hooked['cuModuleLoadData'] = install_hook(cuda.cuModuleLoadData, cuModuleLoadData)
|
||||
|
||||
NVPROFILER = os.environ.get("NV_COMPUTE_PROFILER_PERFWORKS_DIR", None) # realize and wait each aten call
|
||||
if NVPROFILER is None: install_hooks()
|
||||
else:
|
||||
print("Detected NSIGHT Profiled, hooking not avail.")
|
||||
cuda._libraries['libcuda.so'] = ctypes.CDLL(NVPROFILER + "/libcuda-injection.so")
|
||||
134
extra/torch_hook/hook_torch.py
Normal file
134
extra/torch_hook/hook_torch.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import ctypes, struct, platform, pathlib, os, binascii, itertools
|
||||
from hexdump import hexdump
|
||||
from tinygrad.device import Device
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
|
||||
|
||||
import extra.torch_hook.hook_cuda as hook_cuda
|
||||
|
||||
# settings to profile gemm in the __main__ example: TINY_MIRROR=1;CUDA=1;RUN_ONLY=9
|
||||
# nvprof sample command (this will sample all kernels):
|
||||
# ncu --export ~/nvprof_data --force-overwrite --rule AchievedOccupancy --rule Compute --rule LaunchConfiguration --rule Memory --rule PMSamplingData --rule SOLBottleneck --rule TheoreticalOccupancy --rule WorkloadImbalance python3 extra/torch_hook/hook_torch.py
|
||||
# or just run nsight compute from the host to the machine.
|
||||
|
||||
TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
|
||||
RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
|
||||
REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
|
||||
FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
|
||||
|
||||
print("importing torch...")
|
||||
import torch
|
||||
print("importing torch done:", torch.__version__, torch.__file__)
|
||||
|
||||
if TINY_MIRROR:
|
||||
print("importing tiny torch")
|
||||
import extra.torch_backend.backend as tiny_torch
|
||||
print("importing tiny torch done")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
cuda_to_tiny_mappings = {}
|
||||
|
||||
enumerator_aten_calls = itertools.count(0)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
class DispatchLog(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args, kwargs=None):
|
||||
txt_args = []
|
||||
should_call_tiny = kwargs.get('device') is not None and kwargs['device'].type == "cuda"
|
||||
|
||||
def can_print_arg(arg):
|
||||
return args is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
|
||||
|
||||
for i,arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
if arg.device.type == "cuda": should_call_tiny = True
|
||||
txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
|
||||
elif can_print_arg(arg): txt_args.append(f'{arg}')
|
||||
else: txt_args.append(f"{type(arg)}")
|
||||
for k,v in (kwargs or {}).items():
|
||||
if torch.is_tensor(v):
|
||||
if arg.device.type == "cuda": should_call_tiny = True
|
||||
txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
|
||||
elif can_print_arg(arg): txt_args.append(f'{k}:{arg}"')
|
||||
else: txt_args.append(f"{type(arg)}")
|
||||
|
||||
# magenta-colored kerenls mirrored to tiny backend.
|
||||
aten_id = next(enumerator_aten_calls)
|
||||
should_call_tiny = TINY_MIRROR and should_call_tiny
|
||||
print(colored(f"#{aten_id} {func}", "magenta" if should_call_tiny else "cyan") + "("+", ".join(txt_args)+")", flush=True)
|
||||
|
||||
# ignore dispatches if needed
|
||||
hook_cuda.push_ignore_dispatch(RUN_ONLY >= 0 and RUN_ONLY != aten_id)
|
||||
orig_x = func(*args, **(kwargs or {}))
|
||||
|
||||
def print_events(evs, name, out_addr):
|
||||
for ev in evs:
|
||||
if isinstance(ev, hook_cuda.HookKernelCallEvent):
|
||||
txt_params = []
|
||||
for param in ev.params:
|
||||
if isinstance(param, hook_cuda.HookTensorParamEvent):
|
||||
is_out = param.cuda_address == out_addr
|
||||
txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]
|
||||
|
||||
just_kern_name = ev.name
|
||||
if not FULL_KERN_NAME:
|
||||
just_kern_name = ev.name.replace("(anonymous namespace)", "").replace("void ", "").split("<")[0].split("(")[0].split("::")[-1]
|
||||
print(f"\t {name} kernel {just_kern_name} {ev.grid} {ev.block} {ev.ptm}\n\t\t({', '.join(txt_params)})")
|
||||
else: print("\t", name, ev)
|
||||
|
||||
if REALIZE:
|
||||
torch.cuda.synchronize()
|
||||
cuda_events = hook_cuda.collect_events(clear=True)
|
||||
print_events(cuda_events, colored("cuda", "cyan"), orig_x.data_ptr() if torch.is_tensor(orig_x) else 0x0)
|
||||
|
||||
if should_call_tiny:
|
||||
# replace with tiny tensor
|
||||
tiny_args, tiny_kwargs = [], {}
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg): tiny_args.append(cuda_to_tiny_mappings[arg])
|
||||
else: tiny_args.append(arg)
|
||||
|
||||
for k,v in (kwargs or {}).items():
|
||||
if torch.is_tensor(v): tiny_kwargs[k] = cuda_to_tiny_mappings[v]
|
||||
else: tiny_kwargs[k] = v
|
||||
if 'device' in tiny_kwargs and kwargs['device'].type == "cuda":
|
||||
tiny_kwargs['device'] = torch.device("tiny")
|
||||
|
||||
tiny_x = func(*tiny_args, **tiny_kwargs)
|
||||
|
||||
# TODO: this is a hack, any way to do this better?
|
||||
if REALIZE:
|
||||
tiny_x.cpu()
|
||||
tiny_events = hook_cuda.collect_events(clear=True)
|
||||
print_events(tiny_events, colored("tiny", "magenta"), 0x0)
|
||||
|
||||
cuda_to_tiny_mappings[orig_x] = tiny_x
|
||||
|
||||
hook_cuda.pop_ignore_dispatch()
|
||||
return orig_x
|
||||
DispatchLog().__enter__()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("RESNET"):
|
||||
import torchvision.models as models
|
||||
model = models.resnet18(pretrained=True)
|
||||
model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
if getenv("COMPILE"): model = torch.compile(model)
|
||||
|
||||
X = torch.rand(getenv("BS", 1), 3, 288, 288, device='cuda')
|
||||
model(X)
|
||||
|
||||
print("\n\n\n****** second run ******\n")
|
||||
model(X)
|
||||
else:
|
||||
a = torch.randn(64, 64)
|
||||
b = torch.randn(64, 64)
|
||||
a += 1
|
||||
b += 2
|
||||
a = a.exp2()
|
||||
b = b.exp2()
|
||||
a += b
|
||||
c = a @ b
|
||||
print("tensor math done", c.cpu().numpy())
|
||||
Reference in New Issue
Block a user