From d65bd669f89eb8a486da1873a8b17432961fa4d7 Mon Sep 17 00:00:00 2001
From: Daniel <81985269+0xbeedee@users.noreply.github.com>
Date: Wed, 15 Oct 2025 20:02:33 +0200
Subject: [PATCH] update tiny torch backend hook (#12575)

* update the backend to fix torch deprecation warning

* use param_hook to avoid full backward hook needlessly firing on inputs
  which do not require gradients

* fix indentation

---------

Co-authored-by: chenyu
---
 extra/torch_backend/backend.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 04668ee3ef..37a58efc5a 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -642,10 +642,11 @@ def get_real_tinygrad_buffers():
 torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
 
 from torch.nn.modules import Module
-def backward_hook(model:Module, _grad_input, _grad_out):
-  grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None]
-  if len(grads_to_realize): Tensor.realize(*grads_to_realize)
-def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook)
+def param_hook(_grad):
+  if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad))
+def module_hook(module:Module, _name, _submodule):
+  for param in _submodule.parameters(recurse=False):
+    if param.requires_grad: param.register_hook(param_hook)
 torch.nn.modules.module.register_module_module_registration_hook(module_hook)
 
 def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
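
Note: the sketch below illustrates, in plain PyTorch and outside the tinygrad backend, the
per-parameter hook pattern this patch switches to: instead of a (deprecated) module-level
backward hook, a hook is attached to each parameter that requires gradients, so it fires
only when that parameter's gradient is produced. The Net model and the print in param_hook
are illustrative assumptions, not part of the patch.

  import torch
  from torch.nn.modules import Module

  def param_hook(grad):
    # called once the gradient for a single parameter has been computed;
    # the real backend realizes the wrapped tinygrad tensor at this point
    print("param grad ready:", tuple(grad.shape))
    return grad

  def module_hook(module:Module, _name, _submodule):
    # mirrors the patched module_hook: attach param_hook to every parameter of the
    # newly registered submodule that actually requires gradients
    for param in _submodule.parameters(recurse=False):
      if param.requires_grad: param.register_hook(param_hook)

  handle = torch.nn.modules.module.register_module_module_registration_hook(module_hook)

  class Net(torch.nn.Module):  # hypothetical model for demonstration
    def __init__(self):
      super().__init__()
      self.fc = torch.nn.Linear(4, 2)  # assignment triggers module_hook on registration
    def forward(self, x): return self.fc(x)

  net = Net()
  net(torch.randn(3, 4)).sum().backward()  # param_hook fires for fc.weight and fc.bias
  handle.remove()

Unlike register_backward_hook / register_full_backward_hook, these per-parameter hooks do
not run for module inputs that carry no gradient, which is the behavior the PR description
calls out.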