hotfix: group cpu functions in torch backend

This commit is contained in:
George Hotz
2025-02-28 10:38:13 +08:00
parent b32595dbbc
commit ac40316692

View File

@@ -22,6 +22,33 @@ torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend())
torch.utils.generate_methods_for_privateuse1_backend()
# *** bad functions on CPU ***
@torch.library.impl("aten::masked_select", "privateuseone")
def masked_select(self, mask):
  """Return the entries of ``self`` selected by ``mask``.

  Both tensors are copied to the CPU and converted to NumPy arrays; the
  selection is plain NumPy indexing of the data array by the mask array.
  The selected values are rebuilt with ``wrap(Tensor(...))``.
  """
  # err, bad: this round-trips through CPU/NumPy
  data = self.cpu().numpy()
  selector = mask.cpu().numpy()
  return wrap(Tensor(data[selector]))
@torch.library.impl("aten::topk", "privateuseone")
def topk(self, k, dim=-1, largest=True, sorted=True):
  """Run ``torch.topk`` on a CPU copy of ``self`` and move the results back.

  ``k``, ``dim``, ``largest`` and ``sorted`` are forwarded positionally to
  ``torch.topk``. Each element of the CPU result is converted with
  ``.tiny()`` and the pair is repackaged as ``torch.return_types.topk``.
  """
  # TODO: move to tinygrad
  cpu_result = torch.topk(self.cpu(), k, dim, largest, sorted)
  return torch.return_types.topk(tuple(part.tiny() for part in cpu_result))
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
  """Run ``aten._index_put_impl_`` on CPU copies and return the result via ``.tiny()``.

  Args:
    self: tensor being indexed into; a CPU copy is passed to the aten op.
    indices: sequence of index entries. Tensor entries are copied to the
      CPU; any non-tensor entry (e.g. ``None``) is forwarded as ``None``.
    values: tensor of values to put; copied to the CPU.
    accumulate: forwarded unchanged to the aten op.
    unsafe: forwarded unchanged to the aten op.

  Returns:
    The CPU op's result converted with ``.tiny()``.
  """
  # TODO: move to tinygrad
  # Non-tensor entries are forwarded as None, the same way index_tensor
  # treats its index list; calling .cpu() on None would raise AttributeError.
  cpu_indices = [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices]
  # NOTE(review): the op name ends in "_" (in-place), but this returns a new
  # tensor built from a CPU copy rather than writing into `self` -- confirm
  # callers only rely on the return value.
  return aten._index_put_impl_(self.cpu(), cpu_indices, values.cpu(), accumulate, unsafe).tiny()
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
  """Index ``x`` by the entries of ``y`` using ``aten.index`` on the CPU.

  ``x`` is copied to the CPU. Every entry of ``y`` that is a
  ``torch.Tensor`` is copied to the CPU as well; every other entry is
  replaced by ``None``. The CPU result is converted with ``.tiny()``.
  """
  source = x.cpu()
  cpu_indices = []
  for entry in y:
    cpu_indices.append(entry.cpu() if isinstance(entry, torch.Tensor) else None)
  return aten.index(source, cpu_indices).tiny()
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None):
  """Fill ``out`` with a permutation produced by ``torch.randperm`` on the CPU.

  Args:
    n: forwarded to ``torch.randperm``.
    generator: optional generator forwarded to ``torch.randperm``.
    out: destination tensor; receives the permutation through ``copy_``.

  Returns:
    ``out``, as returned by ``out.copy_(...)``.
  """
  perm = torch.randperm(n, generator=generator, device="cpu")
  # Tensor.copy_ returns the destination tensor, so returning it hands `out`
  # back to the caller instead of None, matching the out-variant convention.
  return out.copy_(perm.tiny())
# *** end bad functions on CPU ***
@torch.library.impl("aten::zero_", "privateuseone")
def zero_(x):
tt = unwrap(x)
@@ -35,11 +62,6 @@ def fill_scalar(x, y):
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
def _local_scalar_dense(tensor):
  """Return the result of ``.item()`` called on the unwrapped ``tensor``."""
  inner = unwrap(tensor)
  return inner.item()
@torch.library.impl("aten::masked_select", "privateuseone")
def masked_select(self, mask):
  """Pick the elements of ``self`` indicated by ``mask``.

  The work is done with NumPy: ``self`` and ``mask`` are each moved to the
  CPU and turned into arrays, the data array is indexed by the mask array,
  and the outcome is passed through ``Tensor`` and then ``wrap``.
  """
  # err, bad: goes through the CPU and NumPy
  host_values = self.cpu().numpy()
  host_mask = mask.cpu().numpy()
  picked = host_values[host_mask]
  return wrap(Tensor(picked))
@functools.lru_cache(None)
def cached_to_movement_ops(shape, st) -> list:
mops = to_movement_ops(st)
@@ -99,24 +121,6 @@ def arange_start(start, end, dtype=None, device=None, pin_memory=None):
def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
  """Build ``Tensor.arange(start, end, step)`` and return it through ``wrap``.

  When ``dtype`` is falsy, ``torch.get_default_dtype()`` is used instead; the
  chosen torch dtype is translated with ``_from_torch_dtype``. ``device`` and
  ``pin_memory`` are accepted but not used by this body.
  """
  torch_dtype = dtype or torch.get_default_dtype()
  values = Tensor.arange(start, end, step, dtype=_from_torch_dtype(torch_dtype))
  return wrap(values)
@torch.library.impl("aten::topk", "privateuseone")
def topk(self, k, dim=-1, largest=True, sorted=True):
  """Evaluate ``torch.topk`` on the CPU and return both outputs via ``.tiny()``.

  The input is copied to the CPU first; ``k``, ``dim``, ``largest`` and
  ``sorted`` are passed through positionally. The two outputs are wrapped in
  ``torch.return_types.topk``.
  """
  # TODO: move to tinygrad
  host_input = self.cpu()
  first, second = torch.topk(host_input, k, dim, largest, sorted)
  converted = (first.tiny(), second.tiny())
  return torch.return_types.topk(converted)
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
  """Evaluate ``aten._index_put_impl_`` on the CPU and convert the result with ``.tiny()``.

  Args:
    self: tensor being indexed into; copied to the CPU for the call.
    indices: sequence of index entries. Tensors are copied to the CPU;
      anything that is not a ``torch.Tensor`` (e.g. ``None``) becomes ``None``.
    values: tensor of values to put; copied to the CPU for the call.
    accumulate: passed through to the aten op.
    unsafe: passed through to the aten op.

  Returns:
    The tensor returned by the CPU op, converted with ``.tiny()``.
  """
  # TODO: move to tinygrad
  # Guard non-tensor entries the same way index_tensor does; an unguarded
  # .cpu() call on a None entry would raise AttributeError.
  host_indices = [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices]
  host_result = aten._index_put_impl_(self.cpu(), host_indices, values.cpu(), accumulate, unsafe)
  # NOTE(review): `self` itself is not written to here, only a CPU copy is --
  # confirm that is acceptable for this in-place-named op.
  return host_result.tiny()
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None):
  """Generate a permutation with ``torch.randperm`` on the CPU and copy it into ``out``.

  Args:
    n: forwarded to ``torch.randperm``.
    generator: optional generator forwarded to ``torch.randperm``.
    out: destination tensor that receives the permutation via ``copy_``.

  Returns:
    ``out``, as returned by ``out.copy_(...)``.
  """
  # Pin generation to the CPU so it does not follow the process-wide default
  # device; the result is then converted with .tiny() before the copy.
  perm = torch.randperm(n, generator=generator, device="cpu")
  # Tensor.copy_ returns the destination, so the caller gets `out` back
  # rather than None.
  return out.copy_(perm.tiny())
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
  """Apply ``aten.index`` to CPU copies of ``x`` and the index list ``y``.

  Tensor entries of ``y`` are copied to the CPU; non-tensor entries are
  passed as ``None``. The result is converted with ``.tiny()``.
  """
  def _to_cpu_or_none(entry):
    # Only real tensors can be copied to the CPU; everything else maps to None.
    if isinstance(entry, torch.Tensor):
      return entry.cpu()
    return None

  host_x = x.cpu()
  host_indices = list(map(_to_cpu_or_none, y))
  return aten.index(host_x, host_indices).tiny()
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
if TORCH_DEBUG >= 1: