mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
hotfix: group cpu functions in torch backend
@@ -22,6 +22,33 @@ torch.utils.rename_privateuse1_backend("tiny")
 torch._register_device_module("tiny", TinyBackend())
 torch.utils.generate_methods_for_privateuse1_backend()
 
+# *** bad functions on CPU ***
+
+@torch.library.impl("aten::masked_select", "privateuseone")
+def masked_select(self, mask):
+  # err, bad
+  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
+
+@torch.library.impl("aten::topk", "privateuseone")
+def topk(self, k, dim=-1, largest=True, sorted=True):
+  # TODO: move to tinygrad
+  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
+  return torch.return_types.topk((t1.tiny(), t2.tiny()))
+
+@torch.library.impl("aten::_index_put_impl_", "privateuseone")
+def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
+  # TODO: move to tinygrad
+  return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
+
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y):
+  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
+
+@torch.library.impl("aten::randperm.generator_out", "privateuseone")
+def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
+
+# *** end bad functions on CPU ***
+
 @torch.library.impl("aten::zero_", "privateuseone")
 def zero_(x):
   tt = unwrap(x)
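The block added above is the whole point of the grouped section: each op tinygrad does not yet implement natively is registered for the renamed PrivateUse1 ("tiny") dispatch key via torch.library.impl, computed with stock PyTorch on CPU, and copied back (the .tiny() convenience method is generated by generate_methods_for_privateuse1_backend()). A minimal sketch of the same fallback shape, with a hypothetical choice of op for illustration:

# Sketch only: assumes the backend registration at the top of this file
# (rename to "tiny", _register_device_module, generated methods) has run.
import torch

@torch.library.impl("aten::isin.Tensor_Tensor", "privateuseone")  # hypothetical op choice
def isin_fallback(elements, test_elements, assume_unique=False, invert=False):
  # round-trip through CPU: correct but slow, hence "bad functions on CPU"
  out = torch.isin(elements.cpu(), test_elements.cpu(), assume_unique=assume_unique, invert=invert)
  return out.tiny()  # generated method that copies the result back to the "tiny" device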
@@ -35,11 +62,6 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
-@torch.library.impl("aten::masked_select", "privateuseone")
-def masked_select(self, mask):
-  # err, bad
-  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
-
 @functools.lru_cache(None)
 def cached_to_movement_ops(shape, st) -> list:
   mops = to_movement_ops(st)
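masked_select is the clearest candidate for this grouping: its output length depends on the mask's values, which a lazy tensor graph cannot know ahead of time, so the body round-trips through numpy. The indexing it relies on is plain numpy boolean masking:

import numpy as np

x = np.array([[1, 2], [3, 4]])
m = np.array([[True, False], [False, True]])
print(x[m])  # [1 4] -- boolean indexing flattens; the length depends on the data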
@@ -99,24 +121,6 @@ def arange_start(start, end, dtype=None, device=None, pin_memory=None):
 def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
   return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
 
-@torch.library.impl("aten::topk", "privateuseone")
-def topk(self, k, dim=-1, largest=True, sorted=True):
-  # TODO: move to tinygrad
-  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
-  return torch.return_types.topk((t1.tiny(), t2.tiny()))
-
-@torch.library.impl("aten::_index_put_impl_", "privateuseone")
-def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
-  # TODO: move to tinygrad
-  return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
-
-@torch.library.impl("aten::randperm.generator_out", "privateuseone")
-def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator).tiny())
-
-@torch.library.impl("aten::index.Tensor", "privateuseone")
-def index_tensor(x, y):
-  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
-
 @torch.library.impl("aten::convolution_overrideable", "privateuseone")
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1:
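Note the one behavioral change hidden in the move: randperm_generator now builds its scratch tensor explicitly with device="cpu" before copying into out. Dispatch is otherwise unchanged; calling any of these ops on a "tiny" tensor still lands in the Python fallback. A usage sketch (the import path is an assumption based on where this file lives in the repo):

import torch
import extra.torch_backend.backend  # assumed module path for the file in this diff

t = torch.arange(6.0, device="tiny").reshape(2, 3)
vals, idxs = torch.topk(t, k=2, dim=1)  # dispatches to the grouped topk fallback
print(vals.cpu(), idxs.cpu())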