mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
cuda hooking (#9180)
* cuda hooking * progress * more hook cuda * fix params * compile + cuMemHostAlloc hook * work * revert that
This commit is contained in:
@@ -211,6 +211,7 @@ generate_libc() {
|
||||
clang2py -k cdefstum \
|
||||
$(dpkg -L libc6-dev | grep sys/mman.h) \
|
||||
$(dpkg -L libc6-dev | grep sys/syscall.h) \
|
||||
/usr/include/string.h \
|
||||
/usr/include/elf.h \
|
||||
/usr/include/unistd.h \
|
||||
/usr/include/asm-generic/mman-common.h \
|
||||
|
||||
241
extra/hook_cuda.py
Normal file
241
extra/hook_cuda.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import ctypes, struct, platform, pathlib, os, binascii
|
||||
from hexdump import hexdump
|
||||
from tinygrad.helpers import to_mv, DEBUG, getenv
|
||||
from tinygrad.runtime.autogen import libc, cuda
|
||||
from tinygrad.device import CPUProgram
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.ops_cuda import cu_time_execution
|
||||
|
||||
def _hook(fxn_address_value, tramp):
|
||||
page_address = (fxn_address_value//0x1000)*0x1000
|
||||
ret = libc.mprotect(page_address, 0x2000, 7)
|
||||
assert ret == 0
|
||||
libc.memcpy(fxn_address_value, tramp, len(tramp))
|
||||
ret = libc.mprotect(page_address, 0x2000, 5)
|
||||
assert ret == 0
|
||||
CPUProgram.rt_lib["__clear_cache"](fxn_address_value, fxn_address_value + len(tramp))
|
||||
|
||||
def install_hook(c_function, python_function):
|
||||
python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
|
||||
# AARCH64 trampoline to ioctl
|
||||
if (processor:=platform.processor()) == "aarch64":
|
||||
# 0x0000000000000000: 70 00 00 10 adr x16, #0xc
|
||||
# 0x0000000000000004: 10 02 40 F9 ldr x16, [x16]
|
||||
# 0x0000000000000008: 00 02 1F D6 br x16
|
||||
tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
|
||||
tramp += struct.pack("Q", python_function_addr)
|
||||
elif processor == "x86_64":
|
||||
# 0x0000000000000000: 49 BB aa aa aa aa aa aa aa aa movabs r11, <address>
|
||||
# 0x000000000000000a: 41 FF E3 jmp r11
|
||||
tramp = b"\x49\xBB" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE3"
|
||||
else:
|
||||
raise Exception(f"processor {processor} not supported")
|
||||
tramp = ctypes.create_string_buffer(tramp)
|
||||
|
||||
# get real function address
|
||||
fxn_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
|
||||
fxn_address_value = fxn_address.contents.value
|
||||
#print(f"** hooking function at 0x{fxn_address_value}")
|
||||
|
||||
orig_save = (ctypes.c_char*len(tramp))()
|
||||
libc.memcpy(orig_save, fxn_address_value, len(tramp))
|
||||
_hook(fxn_address_value, tramp)
|
||||
|
||||
def original(*args):
|
||||
_hook(fxn_address_value, orig_save)
|
||||
ret = c_function(*args)
|
||||
_hook(fxn_address_value, tramp)
|
||||
return ret
|
||||
return original
|
||||
|
||||
hooked = {}
|
||||
|
||||
allocated_memory = {}
|
||||
function_names = {}
|
||||
|
||||
seen_modules = set()
|
||||
|
||||
@ctypes.CFUNCTYPE(ctypes.c_int)
|
||||
def dummy():
|
||||
print("**** dummy function hook ****")
|
||||
return -1
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuInit.restype] + cuda.cuInit.argtypes))
|
||||
def cuInit(flags):
|
||||
print("call cuInit", flags)
|
||||
return hooked["cuInit"](flags)
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuMemHostAlloc.restype] + cuda.cuMemHostAlloc.argtypes))
|
||||
def cuMemHostAlloc(pp, bytesize, flags):
|
||||
print(f"cuMemHostAlloc {bytesize}")
|
||||
return hooked["cuMemHostAlloc"](pp, bytesize, flags)
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuModuleLoadData.restype] + cuda.cuModuleLoadData.argtypes))
|
||||
def cuModuleLoadData(module, image):
|
||||
ret = hooked["cuModuleLoadData"](module, image)
|
||||
module_address = ctypes.addressof(module.contents.contents)
|
||||
print(f"cuModuleLoadData 0x{image:x} -> 0x{module_address:X}")
|
||||
seen_modules.add(module_address)
|
||||
|
||||
#images, sections, relocs = elf_loader(bytes(to_mv(image, 0x100000)))
|
||||
#for s in sections: print(s)
|
||||
|
||||
#print('\n'.join([x for x in maps.split("\n") if 'libcuda' in x]))
|
||||
|
||||
#hexdump(to_mv(image, 0x1000))
|
||||
#image, sections, relocs = elf_loader(to_mv(image))
|
||||
#print(sections)
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuModuleGetFunction.restype] + cuda.cuModuleGetFunction.argtypes))
|
||||
def cuModuleGetFunction(hfunc, hmod, name):
|
||||
ret = hooked["cuModuleGetFunction"](hfunc, hmod, name)
|
||||
python_name = ctypes.string_at(name).decode()
|
||||
|
||||
# pip install git+https://github.com/wbenny/pydemangler.git
|
||||
import pydemangler
|
||||
demangled_name = pydemangler.demangle(python_name)
|
||||
if demangled_name is not None: python_name = demangled_name
|
||||
|
||||
print(f"called cuModuleGetFunction 0x{ctypes.addressof(hmod.contents):X} {python_name}")
|
||||
function_names[ctypes.addressof(hfunc.contents.contents)] = python_name
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuMemAlloc_v2.restype] + cuda.cuMemAlloc_v2.argtypes))
|
||||
def cuMemAlloc_v2(dptr, bytesize):
|
||||
ret = hooked["cuMemAlloc_v2"](dptr, bytesize)
|
||||
cuda_address = ctypes.addressof(dptr.contents)
|
||||
allocated_memory[cuda_address] = bytesize
|
||||
print(f"cuMemAlloc_v2 {bytesize} 0x{cuda_address:X}")
|
||||
return ret
|
||||
|
||||
@ctypes.CFUNCTYPE(*([cuda.cuLaunchKernel.restype] + cuda.cuLaunchKernel.argtypes))
|
||||
def cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra):
|
||||
tm = cu_time_execution(lambda:
|
||||
hooked["cuLaunchKernel"](f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra), True)
|
||||
|
||||
name = function_names[ctypes.addressof(f.contents)]
|
||||
print(f"{tm*1e6:9.2f} us -- cuLaunchKernel <<{gridDimX:6d}, {gridDimY:5d}, {gridDimZ:5d}>>",
|
||||
f"<<{blockDimX:4d}, {blockDimY:4d}, {blockDimZ:4d}>> {sharedMemBytes} {name}")
|
||||
|
||||
if extra: hexdump(to_mv(extra, 0x100))
|
||||
|
||||
if getenv("PARAMS") and kernelParams:
|
||||
#print(f"params @ 0x{ctypes.addressof(kernelParams.contents):X}")
|
||||
params = []
|
||||
while True:
|
||||
ret = cuda.cuFuncGetParamInfo(f, len(params), ctypes.byref(paramOffset:=ctypes.c_size_t()), ctypes.byref(paramSize:=ctypes.c_size_t()))
|
||||
if ret != 0: break
|
||||
params.append((paramOffset.value, paramSize.value))
|
||||
#params_dat = to_mv(kernelParams.contents, params[-1][0] + params[-1][1])
|
||||
params_ptr = to_mv(kernelParams, len(params)*8).cast("Q")
|
||||
#params_dat = to_mv(kernelParams.contents, params[-1][0] + params[-1][1])
|
||||
for i,(off,sz) in enumerate(params):
|
||||
hexdump(to_mv(params_ptr[i], sz))
|
||||
|
||||
|
||||
#hexdump(params_dat)
|
||||
#for i,(off,sz) in enumerate(params):
|
||||
# print(f"{i}: offset:{off:3d} size:{sz:3d}") # --", binascii.hexlify(dat).decode())
|
||||
# hexdump(params_dat[off:off+sz])
|
||||
#if name == "exp2_kernel_vectorized4_kernel":
|
||||
# ptr_0 = struct.unpack("Q", params_dat[0x10:0x18])[0]
|
||||
# hexdump(to_mv(ptr_0, 0x80))
|
||||
#ptr_1 = struct.unpack("Q", to_mv(ptr_0, 8))[0]
|
||||
|
||||
#print(f"params 0x{ctypes.addressof(kernelParams):X}")
|
||||
#hexdump(to_mv(kernelParams, 0x100))
|
||||
#print(f"data 0x{to_mv(kernelParams, 8).cast('Q')[0]:X}")
|
||||
#hexdump(to_mv(kernelParams.contents, 0x80))
|
||||
#for i,addr in enumerate(to_mv(kernelParams.contents, 0x100).cast("Q")): print(f"{i*8:3d}: {addr:X}")
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
#out = cuda.CUmoduleLoadingMode()
|
||||
#print(cuda.cuModuleGetLoadingMode(ctypes.byref(out)))
|
||||
#print(out.value)
|
||||
|
||||
hooked['cuInit'] = install_hook(cuda.cuInit, cuInit)
|
||||
hooked['cuModuleGetFunction'] = install_hook(cuda.cuModuleGetFunction, cuModuleGetFunction)
|
||||
hooked['cuLaunchKernel'] = install_hook(cuda.cuLaunchKernel, cuLaunchKernel)
|
||||
|
||||
# memory stuff
|
||||
hooked['cuMemAlloc_v2'] = install_hook(cuda.cuMemAlloc_v2, cuMemAlloc_v2)
|
||||
hooked['cuMemHostAlloc'] = install_hook(cuda.cuMemHostAlloc, cuMemHostAlloc)
|
||||
|
||||
# module loading + not used module loading
|
||||
hooked['cuModuleLoadData'] = install_hook(cuda.cuModuleLoadData, cuModuleLoadData)
|
||||
install_hook(cuda.cuModuleLoad, dummy)
|
||||
install_hook(cuda.cuModuleLoadDataEx, dummy)
|
||||
install_hook(cuda.cuModuleLoadFatBinary, dummy)
|
||||
|
||||
# library stuff (doesn't seem used)
|
||||
#install_hook(cuda.cuLibraryLoadData, dummy)
|
||||
#install_hook(cuda.cuLibraryLoadFromFile, dummy)
|
||||
#install_hook(cuda.cuLibraryGetModule, dummy)
|
||||
|
||||
#install_hook(cuda.cuMemAllocManaged, dummy)
|
||||
|
||||
# unused
|
||||
#install_hook(cuda.cuFuncGetModule, dummy)
|
||||
#install_hook(cuda.cuModuleGetGlobal_v2, dummy)
|
||||
|
||||
# hook v1
|
||||
#install_hook(cuda._libraries['libcuda.so'].cuModuleGetGlobal, dummy)
|
||||
#install_hook(cuda._libraries['libcuda.so'].cuMemAlloc, dummy)
|
||||
#install_hook(cuda._libraries['libcuda.so'].cuLinkComplete, dummy)
|
||||
|
||||
#nvjitlink = ctypes.CDLL("/home/tiny/.local/lib/python3.10/site-packages/nvidia/nvjitlink/lib/libnvJitLink.so.12")
|
||||
#install_hook(nvjitlink.nvJitLinkCreate, dummy)
|
||||
#nvrtc = ctypes.CDLL("/home/tiny/.local/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2")
|
||||
#nvrtc = ctypes.CDLL("/usr/local/cuda-12.4/targets/x86_64-linux/lib/libnvrtc.so.12.4.127")
|
||||
#from tinygrad.runtime.autogen import nvrtc
|
||||
#install_hook(nvrtc.nvrtcCreateProgram, dummy)
|
||||
#install_hook(nvrtc.nvJitLinkCreate, dummy)
|
||||
|
||||
#import tinygrad.runtime.autogen.nvrtc as nvrtc
|
||||
#install_hook(nvrtc.nvJitLinkCreate, dummy)
|
||||
#install_hook(nvrtc.nvrtcCreateProgram, dummy)
|
||||
|
||||
#hooked['cuLinkCreate'] = install_hook(cuda.cuLinkCreate, dummy)
|
||||
|
||||
if getenv("TINYGRAD"):
|
||||
from tinygrad import Tensor
|
||||
(Tensor.zeros(6, device="CUDA").contiguous()*2).realize()
|
||||
exit(0)
|
||||
|
||||
print("importing torch...")
|
||||
import torch
|
||||
print("torch", torch.__version__, torch.__file__)
|
||||
|
||||
if getenv("RESNET"):
|
||||
import torchvision.models as models
|
||||
model = models.resnet18(pretrained=True)
|
||||
model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
if getenv("COMPILE"): model = torch.compile(model)
|
||||
|
||||
X = torch.rand(getenv("BS", 1), 3, 288, 288, device='cuda')
|
||||
model(X)
|
||||
|
||||
print("\n\n\n****** second run ******\n")
|
||||
model(X)
|
||||
else:
|
||||
a = torch.zeros(4, 4).cuda()
|
||||
b = torch.zeros(4, 4).cuda()
|
||||
print("tensor created")
|
||||
print(f"a: 0x{a.data_ptr():X}")
|
||||
print(f"b: 0x{b.data_ptr():X}")
|
||||
a += 1
|
||||
b += 2
|
||||
a = a.exp2()
|
||||
b = b.exp2()
|
||||
a += b
|
||||
#c = a @ b
|
||||
print("tensor math done", a.cpu().numpy())
|
||||
|
||||
# confirm cuda library is right
|
||||
#maps = pathlib.Path("/proc/self/maps").read_text()
|
||||
#print('\n'.join([x for x in maps.split("\n") if 'cuda' in x or 'nv' in x]))
|
||||
@@ -245,6 +245,265 @@ try:
|
||||
except AttributeError:
|
||||
pass
|
||||
_SYSCALL_H = 1 # macro
|
||||
_STRING_H = 1 # macro
|
||||
__GLIBC_INTERNAL_STARTING_HEADER_IMPLEMENTATION = True # macro
|
||||
__need_NULL = True # macro
|
||||
try:
|
||||
memcpy = _libraries['libc'].memcpy
|
||||
memcpy.restype = ctypes.POINTER(None)
|
||||
memcpy.argtypes = [ctypes.POINTER(None), ctypes.POINTER(None), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
memmove = _libraries['libc'].memmove
|
||||
memmove.restype = ctypes.POINTER(None)
|
||||
memmove.argtypes = [ctypes.POINTER(None), ctypes.POINTER(None), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
memccpy = _libraries['libc'].memccpy
|
||||
memccpy.restype = ctypes.POINTER(None)
|
||||
memccpy.argtypes = [ctypes.POINTER(None), ctypes.POINTER(None), ctypes.c_int32, size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
memset = _libraries['libc'].memset
|
||||
memset.restype = ctypes.POINTER(None)
|
||||
memset.argtypes = [ctypes.POINTER(None), ctypes.c_int32, size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
memcmp = _libraries['libc'].memcmp
|
||||
memcmp.restype = ctypes.c_int32
|
||||
memcmp.argtypes = [ctypes.POINTER(None), ctypes.POINTER(None), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
__memcmpeq = _libraries['libc'].__memcmpeq
|
||||
__memcmpeq.restype = ctypes.c_int32
|
||||
__memcmpeq.argtypes = [ctypes.POINTER(None), ctypes.POINTER(None), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
memchr = _libraries['libc'].memchr
|
||||
memchr.restype = ctypes.POINTER(None)
|
||||
memchr.argtypes = [ctypes.POINTER(None), ctypes.c_int32, size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strcpy = _libraries['libc'].strcpy
|
||||
strcpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strcpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strncpy = _libraries['libc'].strncpy
|
||||
strncpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strncpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strcat = _libraries['libc'].strcat
|
||||
strcat.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strcat.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strncat = _libraries['libc'].strncat
|
||||
strncat.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strncat.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strcmp = _libraries['libc'].strcmp
|
||||
strcmp.restype = ctypes.c_int32
|
||||
strcmp.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strncmp = _libraries['libc'].strncmp
|
||||
strncmp.restype = ctypes.c_int32
|
||||
strncmp.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strcoll = _libraries['libc'].strcoll
|
||||
strcoll.restype = ctypes.c_int32
|
||||
strcoll.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strxfrm = _libraries['libc'].strxfrm
|
||||
strxfrm.restype = ctypes.c_uint64
|
||||
strxfrm.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
class struct___locale_struct(Structure):
|
||||
pass
|
||||
|
||||
class struct___locale_data(Structure):
|
||||
pass
|
||||
|
||||
struct___locale_struct._pack_ = 1 # source:False
|
||||
struct___locale_struct._fields_ = [
|
||||
('__locales', ctypes.POINTER(struct___locale_data) * 13),
|
||||
('__ctype_b', ctypes.POINTER(ctypes.c_uint16)),
|
||||
('__ctype_tolower', ctypes.POINTER(ctypes.c_int32)),
|
||||
('__ctype_toupper', ctypes.POINTER(ctypes.c_int32)),
|
||||
('__names', ctypes.POINTER(ctypes.c_char) * 13),
|
||||
]
|
||||
|
||||
locale_t = ctypes.POINTER(struct___locale_struct)
|
||||
try:
|
||||
strcoll_l = _libraries['libc'].strcoll_l
|
||||
strcoll_l.restype = ctypes.c_int32
|
||||
strcoll_l.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), locale_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strxfrm_l = _libraries['libc'].strxfrm_l
|
||||
strxfrm_l.restype = size_t
|
||||
strxfrm_l.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t, locale_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strdup = _libraries['libc'].strdup
|
||||
strdup.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strdup.argtypes = [ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strndup = _libraries['libc'].strndup
|
||||
strndup.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strndup.argtypes = [ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strchr = _libraries['libc'].strchr
|
||||
strchr.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strchr.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.c_int32]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strrchr = _libraries['libc'].strrchr
|
||||
strrchr.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strrchr.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.c_int32]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strcspn = _libraries['libc'].strcspn
|
||||
strcspn.restype = ctypes.c_uint64
|
||||
strcspn.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strspn = _libraries['libc'].strspn
|
||||
strspn.restype = ctypes.c_uint64
|
||||
strspn.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strpbrk = _libraries['libc'].strpbrk
|
||||
strpbrk.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strpbrk.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strstr = _libraries['libc'].strstr
|
||||
strstr.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strstr.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strtok = _libraries['libc'].strtok
|
||||
strtok.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strtok.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
__strtok_r = _libraries['libc'].__strtok_r
|
||||
__strtok_r.restype = ctypes.POINTER(ctypes.c_char)
|
||||
__strtok_r.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strtok_r = _libraries['libc'].strtok_r
|
||||
strtok_r.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strtok_r.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strlen = _libraries['libc'].strlen
|
||||
strlen.restype = ctypes.c_uint64
|
||||
strlen.argtypes = [ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strnlen = _libraries['libc'].strnlen
|
||||
strnlen.restype = size_t
|
||||
strnlen.argtypes = [ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strerror = _libraries['libc'].strerror
|
||||
strerror.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strerror.argtypes = [ctypes.c_int32]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strerror_r = _libraries['libc'].strerror_r
|
||||
strerror_r.restype = ctypes.c_int32
|
||||
strerror_r.argtypes = [ctypes.c_int32, ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strerror_l = _libraries['libc'].strerror_l
|
||||
strerror_l.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strerror_l.argtypes = [ctypes.c_int32, locale_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
explicit_bzero = _libraries['libc'].explicit_bzero
|
||||
explicit_bzero.restype = None
|
||||
explicit_bzero.argtypes = [ctypes.POINTER(None), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strsep = _libraries['libc'].strsep
|
||||
strsep.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strsep.argtypes = [ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
strsignal = _libraries['libc'].strsignal
|
||||
strsignal.restype = ctypes.POINTER(ctypes.c_char)
|
||||
strsignal.argtypes = [ctypes.c_int32]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
__stpcpy = _libraries['libc'].__stpcpy
|
||||
__stpcpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
__stpcpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
stpcpy = _libraries['libc'].stpcpy
|
||||
stpcpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
stpcpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char)]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
__stpncpy = _libraries['libc'].__stpncpy
|
||||
__stpncpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
__stpncpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
stpncpy = _libraries['libc'].stpncpy
|
||||
stpncpy.restype = ctypes.POINTER(ctypes.c_char)
|
||||
stpncpy.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), size_t]
|
||||
except AttributeError:
|
||||
pass
|
||||
_ELF_H = 1 # macro
|
||||
EI_NIDENT = (16) # macro
|
||||
EI_MAG0 = 0 # macro
|
||||
@@ -3726,7 +3985,6 @@ STDIN_FILENO = 0 # macro
|
||||
STDOUT_FILENO = 1 # macro
|
||||
STDERR_FILENO = 2 # macro
|
||||
__ssize_t_defined = True # macro
|
||||
__need_NULL = True # macro
|
||||
__gid_t_defined = True # macro
|
||||
__uid_t_defined = True # macro
|
||||
__useconds_t_defined = True # macro
|
||||
@@ -5467,59 +5725,71 @@ __all__ = \
|
||||
'Val_GNU_MIPS_ABI_FP_SOFT', 'Val_GNU_MIPS_ABI_FP_XX', 'W_OK',
|
||||
'X_OK', '_ELF_H', '_POSIX2_C_BIND', '_POSIX2_C_DEV',
|
||||
'_POSIX2_C_VERSION', '_POSIX2_LOCALEDEF', '_POSIX2_SW_DEV',
|
||||
'_POSIX2_VERSION', '_POSIX_VERSION', '_SYSCALL_H', '_SYS_MMAN_H',
|
||||
'_UNISTD_H', '_XOPEN_ENH_I18N', '_XOPEN_LEGACY', '_XOPEN_UNIX',
|
||||
'_XOPEN_VERSION', '_XOPEN_XCU_VERSION', '_XOPEN_XPG2',
|
||||
'_XOPEN_XPG3', '_XOPEN_XPG4', '__ASM_GENERIC_MMAN_COMMON_H',
|
||||
'_POSIX2_VERSION', '_POSIX_VERSION', '_STRING_H', '_SYSCALL_H',
|
||||
'_SYS_MMAN_H', '_UNISTD_H', '_XOPEN_ENH_I18N', '_XOPEN_LEGACY',
|
||||
'_XOPEN_UNIX', '_XOPEN_VERSION', '_XOPEN_XCU_VERSION',
|
||||
'_XOPEN_XPG2', '_XOPEN_XPG3', '_XOPEN_XPG4',
|
||||
'__ASM_GENERIC_MMAN_COMMON_H',
|
||||
'__GLIBC_INTERNAL_STARTING_HEADER_IMPLEMENTATION',
|
||||
'__POSIX2_THIS_VERSION', '__environ', '__getpgid', '__gid_t',
|
||||
'__gid_t_defined', '__intptr_t_defined', '__mode_t_defined',
|
||||
'__need_NULL', '__need_size_t', '__off_t', '__off_t_defined',
|
||||
'__pid_t', '__pid_t_defined', '__socklen_t_defined',
|
||||
'__ssize_t_defined', '__uid_t', '__uid_t_defined', '__useconds_t',
|
||||
'__useconds_t_defined', '_exit', 'access', 'acct', 'alarm', 'brk',
|
||||
'c__Ea_Val_GNU_MIPS_ABI_FP_ANY', 'chdir', 'chown', 'chroot',
|
||||
'close', 'closefrom', 'confstr', 'crypt', 'daemon', 'dup', 'dup2',
|
||||
'endusershell', 'execl', 'execle', 'execlp', 'execv', 'execve',
|
||||
'execvp', 'faccessat', 'fchdir', 'fchown', 'fchownat',
|
||||
'fdatasync', 'fexecve', 'fork', 'fpathconf', 'fsync', 'ftruncate',
|
||||
'getcwd', 'getdomainname', 'getdtablesize', 'getegid',
|
||||
'getentropy', 'geteuid', 'getgid', 'getgroups', 'gethostid',
|
||||
'gethostname', 'getlogin', 'getlogin_r', 'getpagesize', 'getpass',
|
||||
'getpgid', 'getpgrp', 'getpid', 'getppid', 'getsid', 'getuid',
|
||||
'getusershell', 'getwd', 'gid_t', 'intptr_t', 'isatty', 'lchown',
|
||||
'link', 'linkat', 'lockf', 'lseek', 'madvise', 'mincore', 'mlock',
|
||||
'mlockall', 'mmap', 'mode_t', 'mprotect', 'msync', 'munlock',
|
||||
'munlockall', 'munmap', 'nice', 'off_t', 'pathconf', 'pause',
|
||||
'pid_t', 'pipe', 'posix_madvise', 'pread', 'profil', 'pwrite',
|
||||
'read', 'readlink', 'readlinkat', 'revoke', 'rmdir', 'sbrk',
|
||||
'setdomainname', 'setegid', 'seteuid', 'setgid', 'sethostid',
|
||||
'sethostname', 'setlogin', 'setpgid', 'setpgrp', 'setregid',
|
||||
'setreuid', 'setsid', 'setuid', 'setusershell', 'shm_open',
|
||||
'shm_unlink', 'size_t', 'sleep', 'socklen_t', 'ssize_t',
|
||||
'struct_c__SA_Elf32_Chdr', 'struct_c__SA_Elf32_Dyn',
|
||||
'struct_c__SA_Elf32_Ehdr', 'struct_c__SA_Elf32_Lib',
|
||||
'struct_c__SA_Elf32_Move', 'struct_c__SA_Elf32_Nhdr',
|
||||
'struct_c__SA_Elf32_Phdr', 'struct_c__SA_Elf32_RegInfo',
|
||||
'struct_c__SA_Elf32_Rel', 'struct_c__SA_Elf32_Rela',
|
||||
'struct_c__SA_Elf32_Shdr', 'struct_c__SA_Elf32_Sym',
|
||||
'struct_c__SA_Elf32_Syminfo', 'struct_c__SA_Elf32_Verdaux',
|
||||
'struct_c__SA_Elf32_Verdef', 'struct_c__SA_Elf32_Vernaux',
|
||||
'struct_c__SA_Elf32_Verneed', 'struct_c__SA_Elf32_auxv_t',
|
||||
'struct_c__SA_Elf64_Chdr', 'struct_c__SA_Elf64_Dyn',
|
||||
'struct_c__SA_Elf64_Ehdr', 'struct_c__SA_Elf64_Lib',
|
||||
'struct_c__SA_Elf64_Move', 'struct_c__SA_Elf64_Nhdr',
|
||||
'struct_c__SA_Elf64_Phdr', 'struct_c__SA_Elf64_Rel',
|
||||
'struct_c__SA_Elf64_Rela', 'struct_c__SA_Elf64_Shdr',
|
||||
'struct_c__SA_Elf64_Sym', 'struct_c__SA_Elf64_Syminfo',
|
||||
'struct_c__SA_Elf64_Verdaux', 'struct_c__SA_Elf64_Verdef',
|
||||
'struct_c__SA_Elf64_Vernaux', 'struct_c__SA_Elf64_Verneed',
|
||||
'struct_c__SA_Elf64_auxv_t', 'struct_c__SA_Elf_MIPS_ABIFlags_v0',
|
||||
'struct_c__SA_Elf_Options', 'struct_c__SA_Elf_Options_Hw',
|
||||
'__gid_t_defined', '__intptr_t_defined', '__memcmpeq',
|
||||
'__mode_t_defined', '__need_NULL', '__need_size_t', '__off_t',
|
||||
'__off_t_defined', '__pid_t', '__pid_t_defined',
|
||||
'__socklen_t_defined', '__ssize_t_defined', '__stpcpy',
|
||||
'__stpncpy', '__strtok_r', '__uid_t', '__uid_t_defined',
|
||||
'__useconds_t', '__useconds_t_defined', '_exit', 'access', 'acct',
|
||||
'alarm', 'brk', 'c__Ea_Val_GNU_MIPS_ABI_FP_ANY', 'chdir', 'chown',
|
||||
'chroot', 'close', 'closefrom', 'confstr', 'crypt', 'daemon',
|
||||
'dup', 'dup2', 'endusershell', 'execl', 'execle', 'execlp',
|
||||
'execv', 'execve', 'execvp', 'explicit_bzero', 'faccessat',
|
||||
'fchdir', 'fchown', 'fchownat', 'fdatasync', 'fexecve', 'fork',
|
||||
'fpathconf', 'fsync', 'ftruncate', 'getcwd', 'getdomainname',
|
||||
'getdtablesize', 'getegid', 'getentropy', 'geteuid', 'getgid',
|
||||
'getgroups', 'gethostid', 'gethostname', 'getlogin', 'getlogin_r',
|
||||
'getpagesize', 'getpass', 'getpgid', 'getpgrp', 'getpid',
|
||||
'getppid', 'getsid', 'getuid', 'getusershell', 'getwd', 'gid_t',
|
||||
'intptr_t', 'isatty', 'lchown', 'link', 'linkat', 'locale_t',
|
||||
'lockf', 'lseek', 'madvise', 'memccpy', 'memchr', 'memcmp',
|
||||
'memcpy', 'memmove', 'memset', 'mincore', 'mlock', 'mlockall',
|
||||
'mmap', 'mode_t', 'mprotect', 'msync', 'munlock', 'munlockall',
|
||||
'munmap', 'nice', 'off_t', 'pathconf', 'pause', 'pid_t', 'pipe',
|
||||
'posix_madvise', 'pread', 'profil', 'pwrite', 'read', 'readlink',
|
||||
'readlinkat', 'revoke', 'rmdir', 'sbrk', 'setdomainname',
|
||||
'setegid', 'seteuid', 'setgid', 'sethostid', 'sethostname',
|
||||
'setlogin', 'setpgid', 'setpgrp', 'setregid', 'setreuid',
|
||||
'setsid', 'setuid', 'setusershell', 'shm_open', 'shm_unlink',
|
||||
'size_t', 'sleep', 'socklen_t', 'ssize_t', 'stpcpy', 'stpncpy',
|
||||
'strcat', 'strchr', 'strcmp', 'strcoll', 'strcoll_l', 'strcpy',
|
||||
'strcspn', 'strdup', 'strerror', 'strerror_l', 'strerror_r',
|
||||
'strlen', 'strncat', 'strncmp', 'strncpy', 'strndup', 'strnlen',
|
||||
'strpbrk', 'strrchr', 'strsep', 'strsignal', 'strspn', 'strstr',
|
||||
'strtok', 'strtok_r', 'struct___locale_data',
|
||||
'struct___locale_struct', 'struct_c__SA_Elf32_Chdr',
|
||||
'struct_c__SA_Elf32_Dyn', 'struct_c__SA_Elf32_Ehdr',
|
||||
'struct_c__SA_Elf32_Lib', 'struct_c__SA_Elf32_Move',
|
||||
'struct_c__SA_Elf32_Nhdr', 'struct_c__SA_Elf32_Phdr',
|
||||
'struct_c__SA_Elf32_RegInfo', 'struct_c__SA_Elf32_Rel',
|
||||
'struct_c__SA_Elf32_Rela', 'struct_c__SA_Elf32_Shdr',
|
||||
'struct_c__SA_Elf32_Sym', 'struct_c__SA_Elf32_Syminfo',
|
||||
'struct_c__SA_Elf32_Verdaux', 'struct_c__SA_Elf32_Verdef',
|
||||
'struct_c__SA_Elf32_Vernaux', 'struct_c__SA_Elf32_Verneed',
|
||||
'struct_c__SA_Elf32_auxv_t', 'struct_c__SA_Elf64_Chdr',
|
||||
'struct_c__SA_Elf64_Dyn', 'struct_c__SA_Elf64_Ehdr',
|
||||
'struct_c__SA_Elf64_Lib', 'struct_c__SA_Elf64_Move',
|
||||
'struct_c__SA_Elf64_Nhdr', 'struct_c__SA_Elf64_Phdr',
|
||||
'struct_c__SA_Elf64_Rel', 'struct_c__SA_Elf64_Rela',
|
||||
'struct_c__SA_Elf64_Shdr', 'struct_c__SA_Elf64_Sym',
|
||||
'struct_c__SA_Elf64_Syminfo', 'struct_c__SA_Elf64_Verdaux',
|
||||
'struct_c__SA_Elf64_Verdef', 'struct_c__SA_Elf64_Vernaux',
|
||||
'struct_c__SA_Elf64_Verneed', 'struct_c__SA_Elf64_auxv_t',
|
||||
'struct_c__SA_Elf_MIPS_ABIFlags_v0', 'struct_c__SA_Elf_Options',
|
||||
'struct_c__SA_Elf_Options_Hw',
|
||||
'struct_c__UA_Elf32_gptab_gt_entry',
|
||||
'struct_c__UA_Elf32_gptab_gt_header', 'symlink', 'symlinkat',
|
||||
'sync', 'syscall', 'sysconf', 'tcgetpgrp', 'tcsetpgrp',
|
||||
'truncate', 'ttyname', 'ttyname_r', 'ttyslot', 'ualarm', 'uid_t',
|
||||
'union_c__SA_Elf32_Dyn_d_un', 'union_c__SA_Elf32_auxv_t_a_un',
|
||||
'union_c__SA_Elf64_Dyn_d_un', 'union_c__SA_Elf64_auxv_t_a_un',
|
||||
'union_c__UA_Elf32_gptab', 'unlink', 'unlinkat', 'useconds_t',
|
||||
'usleep', 'vfork', 'vhangup', 'write']
|
||||
'struct_c__UA_Elf32_gptab_gt_header', 'strxfrm', 'strxfrm_l',
|
||||
'symlink', 'symlinkat', 'sync', 'syscall', 'sysconf', 'tcgetpgrp',
|
||||
'tcsetpgrp', 'truncate', 'ttyname', 'ttyname_r', 'ttyslot',
|
||||
'ualarm', 'uid_t', 'union_c__SA_Elf32_Dyn_d_un',
|
||||
'union_c__SA_Elf32_auxv_t_a_un', 'union_c__SA_Elf64_Dyn_d_un',
|
||||
'union_c__SA_Elf64_auxv_t_a_un', 'union_c__UA_Elf32_gptab',
|
||||
'unlink', 'unlinkat', 'useconds_t', 'usleep', 'vfork', 'vhangup',
|
||||
'write']
|
||||
|
||||
Reference in New Issue
Block a user