Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
torch hook: address comments (#9295)

* torch hook: address comments
* failed test
```diff
@@ -87,5 +87,12 @@ class TestTorchBackend(unittest.TestCase):
     a = torch.ones(4, device=device)
     print(str(a))
 
+  @unittest.skip("failed")
+  def test_floor_div(self):
+    a = torch.tensor([10., 7., 5.], device=device)
+    b = torch.tensor([3., 2., 2.], device=device)
+    result = a // b
+    np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])
+
 if __name__ == "__main__":
   unittest.main()
```
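The new test is checked in skipped: floor division does not yet produce correct results on the backend. Below is a minimal sketch of what it exercises as a standalone repro; the `extra.torch_backend.backend` import and the `"tiny"` device name are assumptions based on this repo's layout, not part of the commit.

```python
# Standalone sketch of the skipped test; assumes importing
# extra.torch_backend.backend registers the "tiny" device as a side effect.
import numpy as np
import torch
import extra.torch_backend.backend  # noqa: F401  (assumed registration)

device = torch.device("tiny")
a = torch.tensor([10., 7., 5.], device=device)
b = torch.tensor([3., 2., 2.], device=device)
result = a // b  # tensor floordiv dispatches to aten::floor_divide
np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])  # 10//3, 7//2, 5//2
```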
```diff
@@ -2,6 +2,7 @@ import ctypes, struct, platform, pathlib, os, binascii, itertools
 from hexdump import hexdump
 from tinygrad.device import Device
+from tinygrad import Tensor
 from tinygrad.dtype import _from_torch_dtype
 from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
 
 import extra.torch_hook.hook_cuda as hook_cuda
```
```diff
@@ -14,6 +15,7 @@ import extra.torch_hook.hook_cuda as hook_cuda
 TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
 RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
 REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
+WRAP_TINY = getenv("WRAP_TINY", 1) # reuse cuda tensors
 FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
 
 print("importing torch...")
```
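These knobs are plain environment variables read through tinygrad's `getenv`, which falls back to the default and casts to the default's type. A sketch of how the new `WRAP_TINY` flag behaves; the invocation in the comment is illustrative, not prescribed by the commit.

```python
# getenv reads an environment variable, casting to the type of the default
# (here int), so WRAP_TINY=0 in the environment disables wrapping.
from tinygrad.helpers import getenv

WRAP_TINY = getenv("WRAP_TINY", 1)  # 1: wrap CUDA tensors zero-copy; 0: copy as before

# Illustrative invocation (assuming the hook script is run directly):
#   WRAP_TINY=0 TINY_MIRROR=1 python extra/torch_hook/hook_torch.py
```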
```diff
@@ -39,15 +41,24 @@ class DispatchLog(TorchDispatchMode):
     def can_print_arg(arg):
       return args is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
 
+    def create_tiny_mapping(arg):
+      if WRAP_TINY:
+        tt = Tensor.from_blob(arg.data_ptr(), arg.shape, dtype=_from_torch_dtype(arg.dtype))
+        cuda_to_tiny_mappings[arg] = tiny_torch.wrap(tt)
+
     for i,arg in enumerate(args):
       if torch.is_tensor(arg):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if arg.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(arg)
         txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
       elif can_print_arg(arg): txt_args.append(f'{arg}')
       else: txt_args.append(f"{type(arg)}")
     for k,v in (kwargs or {}).items():
       if torch.is_tensor(v):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if arg.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(arg)
         txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
       elif can_print_arg(arg): txt_args.append(f'{k}:{arg}"')
       else: txt_args.append(f"{type(arg)}")
```
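With `WRAP_TINY` on, the hook no longer copies CUDA tensors over to the tiny backend: it wraps the existing allocation with `Tensor.from_blob`, so both sides read the same device memory. (Note the kwargs loop still tests `arg.device` and wraps `arg`, variables left over from the positional loop, rather than `v`; this looks like a latent bug rather than intent.) A minimal sketch of the wrapping idea, assuming tinygrad's default device is the same CUDA GPU torch is using:

```python
# Zero-copy wrap of an existing torch CUDA allocation (sketch; assumes
# tinygrad's default device is the same CUDA GPU, e.g. CUDA=1).
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

x = torch.arange(4, dtype=torch.float32, device="cuda")
# from_blob builds a Tensor over the raw device pointer: no copy is made,
# so tinygrad sees whatever torch last wrote into this memory.
tt = Tensor.from_blob(x.data_ptr(), tuple(x.shape), dtype=_from_torch_dtype(x.dtype))
print(tt.numpy())  # [0. 1. 2. 3.]
```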
```diff
@@ -68,7 +79,7 @@ class DispatchLog(TorchDispatchMode):
           for param in ev.params:
             if isinstance(param, hook_cuda.HookTensorParamEvent):
               is_out = param.cuda_address == out_addr
-              txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]
+              txt_params += [f"{'result ' if is_out else ''}Tensor{param.enum}({param.cuda_address:#x})"]
 
         just_kern_name = ev.name
         if not FULL_KERN_NAME:
```
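The printout now tags only the kernel parameter whose device address matches the realized output; inputs print bare, and the offset is dropped. The labeling logic, isolated into a hypothetical helper for illustration:

```python
# Hypothetical helper isolating the labeling logic from print_events.
def label(param_addr: int, out_addr: int, enum: int) -> str:
  is_out = param_addr == out_addr  # the result is whichever param lives at out_addr
  return f"{'result ' if is_out else ''}Tensor{enum}({param_addr:#x})"

print(label(0x7f00_0000, 0x7f00_0000, 0))  # -> "result Tensor0(0x7f000000)"
print(label(0x7f00_1000, 0x7f00_0000, 1))  # -> "Tensor1(0x7f001000)"
```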
```diff
@@ -98,11 +109,15 @@ class DispatchLog(TorchDispatchMode):
 
       # TODO: this is a hack, any way to do this better?
       if REALIZE:
-        tiny_x.cpu()
+        out_addr = 0x0
+        if torch.is_tensor(tiny_x):
+          tt = tiny_torch.unwrap(tiny_x).realize()
+          try: out_addr = tt.lazydata.buffer._buf.value
+          except Exception: pass
         tiny_events = hook_cuda.collect_events(clear=True)
-        print_events(tiny_events, colored("tiny", "magenta"), 0x0)
+        print_events(tiny_events, colored("tiny", "magenta"), out_addr)
 
-      cuda_to_tiny_mappings[orig_x] = tiny_x
+      if not WRAP_TINY: cuda_to_tiny_mappings[orig_x] = tiny_x
 
       hook_cuda.pop_ignore_dispatch()
       return orig_x
```
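Instead of forcing a host round-trip with `.cpu()`, the mirrored tensor is now realized in place and its device address recovered from tinygrad internals, which is what lets `print_events` tag the result parameter above. A sketch of that address recovery; `lazydata`, `buffer`, and `_buf` are private, backend-specific attributes and may change:

```python
# Recovering a realized tensor's device address (sketch; _buf is a ctypes
# object on CUDA-like backends, hence .value and the broad except).
from tinygrad import Tensor

t = (Tensor.ones(4) + 1).realize()
out_addr = 0x0
try: out_addr = t.lazydata.buffer._buf.value  # raw device pointer
except Exception: pass  # e.g. CPU backend, where _buf is not a ctypes object
print(hex(out_addr))
```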