torch hook: address comments (#9295)

* torch hook: address comments

* failed test
Authored by nimlgen on 2025-02-28 11:51:52 +03:00; committed by GitHub
parent d657d5f754
commit 052722a7bc
2 changed files with 28 additions and 6 deletions


@@ -87,5 +87,12 @@ class TestTorchBackend(unittest.TestCase):
    a = torch.ones(4, device=device)
    print(str(a))

  @unittest.skip("failed")
  def test_floor_div(self):
    a = torch.tensor([10., 7., 5.], device=device)
    b = torch.tensor([3., 2., 2.], device=device)
    result = a // b
    np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])

if __name__ == "__main__":
  unittest.main()
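
As a reference for the expected semantics (a minimal NumPy-only sketch, independent of the torch backend): the // operator on float tensors lowers to aten::floor_divide, which rounds toward negative infinity.

import numpy as np

a = np.array([10., 7., 5.])
b = np.array([3., 2., 2.])
# floor(10/3)=3, floor(7/2)=3, floor(5/2)=2
np.testing.assert_equal(np.floor(a / b), [3., 3., 2.])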


@@ -2,6 +2,7 @@ import ctypes, struct, platform, pathlib, os, binascii, itertools
from hexdump import hexdump
from tinygrad.device import Device
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype
from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
import extra.torch_hook.hook_cuda as hook_cuda
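
The newly imported _from_torch_dtype is the dtype bridge used by the wrapping path further down. A quick sanity check of the assumed mapping (torch dtype in, tinygrad dtype out):

import torch
from tinygrad import dtypes
from tinygrad.dtype import _from_torch_dtype

# assumed behavior: common dtypes map one-to-one
assert _from_torch_dtype(torch.float32) == dtypes.float32
assert _from_torch_dtype(torch.int64) == dtypes.int64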
@@ -14,6 +15,7 @@ import extra.torch_hook.hook_cuda as hook_cuda
TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
WRAP_TINY = getenv("WRAP_TINY", 1) # reuse cuda tensors
FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
print("importing torch...")
@@ -39,15 +41,24 @@ class DispatchLog(TorchDispatchMode):
def can_print_arg(arg):
return arg is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
def create_tiny_mapping(arg):
if WRAP_TINY:
tt = Tensor.from_blob(arg.data_ptr(), arg.shape, dtype=_from_torch_dtype(arg.dtype))
cuda_to_tiny_mappings[arg] = tiny_torch.wrap(tt)
for i,arg in enumerate(args):
if torch.is_tensor(arg):
if arg.device.type == "cuda": should_call_tiny = True
if arg.device.type == "cuda":
should_call_tiny = True
if WRAP_TINY: create_tiny_mapping(arg)
txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
elif can_print_arg(arg): txt_args.append(f'{arg}')
else: txt_args.append(f"{type(arg)}")
for k,v in (kwargs or {}).items():
if torch.is_tensor(v):
if v.device.type == "cuda": should_call_tiny = True
if v.device.type == "cuda":
should_call_tiny = True
if WRAP_TINY: create_tiny_mapping(v)
txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
elif can_print_arg(v): txt_args.append(f'{k}:{v}')
else: txt_args.append(f"{type(v)}")
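
create_tiny_mapping relies on Tensor.from_blob, which builds a tinygrad tensor over an existing device pointer instead of copying. A minimal standalone sketch (assumes a CUDA-capable torch and a tinygrad default device that accepts the raw CUDA pointer; the torch tensor must stay alive while the view is used):

import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

x = torch.ones(4, device="cuda")
# zero-copy: the tinygrad tensor aliases x's CUDA allocation through its raw data pointer
tt = Tensor.from_blob(x.data_ptr(), x.shape, dtype=_from_torch_dtype(x.dtype))

Reusing the same CUDA buffer is what the "reuse cuda tensors" comment on WRAP_TINY refers to: the tiny mirror operates on torch's memory rather than on its own copy.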
@@ -68,7 +79,7 @@ class DispatchLog(TorchDispatchMode):
for param in ev.params:
if isinstance(param, hook_cuda.HookTensorParamEvent):
is_out = param.cuda_address == out_addr
txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]
txt_params += [f"{'result ' if is_out else ''}Tensor{param.enum}({param.cuda_address:#x})"]
just_kern_name = ev.name
if not FULL_KERN_NAME:
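
For illustration, the new parameter label prints like this (hypothetical values; enum and cuda_address come from the hook's tensor-param event):

enum, cuda_address, is_out = 0, 0x7f81aa000000, True
print(f"{'result ' if is_out else ''}Tensor{enum}({cuda_address:#x})")  # -> result Tensor0(0x7f81aa000000)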
@@ -98,11 +109,15 @@ class DispatchLog(TorchDispatchMode):
# TODO: this is a hack, any way to do this better?
if REALIZE:
tiny_x.cpu()
out_addr = 0x0
if torch.is_tensor(tiny_x):
tt = tiny_torch.unwrap(tiny_x).realize()
try: out_addr = tt.lazydata.buffer._buf.value
except Exception: pass
tiny_events = hook_cuda.collect_events(clear=True)
print_events(tiny_events, colored("tiny", "magenta"), 0x0)
print_events(tiny_events, colored("tiny", "magenta"), out_addr)
cuda_to_tiny_mappings[orig_x] = tiny_x
if not WRAP_TINY: cuda_to_tiny_mappings[orig_x] = tiny_x
hook_cuda.pop_ignore_dispatch()
return orig_x
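
The address probe added above can be read as a small helper (a hypothetical refactor for clarity, not part of the commit); lazydata.buffer._buf is backend-internal, hence the broad try/except:

def result_address(tiny_x) -> int:
  # raw device address of a realized tinygrad-backed torch tensor, or 0x0 if the backend does not expose one
  if not torch.is_tensor(tiny_x): return 0x0
  tt = tiny_torch.unwrap(tiny_x).realize()
  try: return tt.lazydata.buffer._buf.value
  except Exception: return 0x0

Passing this address to print_events is what lets the tiny-side log tag the output buffer as "result" instead of always passing 0x0.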