increase speed of torch mnist: use gradient api (#9282)

George Hotz authored 2025-02-27 11:57:41 +08:00, committed by GitHub
parent a0764f0dc0
commit 387ea41e99
2 changed files with 17 additions and 21 deletions

View File

@@ -170,7 +170,7 @@ jobs:
     - name: Test one op in torch tests
       run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
     - name: Test beautiful_mnist in torch with TINY_BACKEND
-      run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
+      run: PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
     - name: Test Ops with TINY_BACKEND (expect failure)
       run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
     - name: Test some torch tests (expect failure)
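Note: the only change on the CI side is adding LLVM=1 to the beautiful_mnist run, which pins the tiny backend to the LLVM device. A quick way to see that selection mechanism locally (a sketch, assuming a tinygrad checkout on PYTHONPATH; the output depends on which backends your machine has):

import os
os.environ["LLVM"] = "1"        # same variable the CI line sets; must be set before importing tinygrad

from tinygrad import Device
print(Device.DEFAULT)           # expected to print LLVM when the LLVM backend is available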

View File

@@ -2,7 +2,7 @@ from tinygrad import Tensor, dtypes
 from tinygrad.helpers import DEBUG, getenv, prod
 import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
-import torch, pathlib, math, operator
+import torch, pathlib, math, operator, functools
 torch.autograd.grad_mode.set_multithreading_enabled(False)
 from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
@@ -40,18 +40,22 @@ def masked_select(self, mask):
   # err, bad
   return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
 
+@functools.lru_cache(None)
+def cached_to_movement_ops(shape, st) -> list:
+  mops = to_movement_ops(st)
+  if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:]
+  return mops
+
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
 @torch.library.impl("aten::as_strided", "privateuseone")
 def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
   # TODO: this is heavyweight
-  st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)])
+  st = ShapeTracker((View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)))
   ret = unwrap(tensor)
   if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
   if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
-  mops = to_movement_ops(st)
-  if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:]
-  for mo in mops: ret = apply_mop(ret, mo)
+  for mo in cached_to_movement_ops(tuple(tensor.shape), st): ret = apply_mop(ret, mo)
   return wrap(ret)
 
 @torch.library.impl("aten::empty_strided", "privateuseone")
@@ -75,11 +79,10 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
 @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
 def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
-  grad_out, self, indices = unwrap(grad_out), unwrap(self), unwrap(indices)
   # TODO: utilize input indices once they are correct
-  self = self.detach().clone().requires_grad_(True)
-  Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode).backward(grad_out)
-  return wrap(self.grad)
+  grad_out, self = unwrap(grad_out), unwrap(self)
+  out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode)
+  return wrap(out.gradient(self, gradient=grad_out)[0])
 
 @torch.library.impl("aten::arange", "privateuseone")
 def arange(end, dtype=None, device=None, pin_memory=None):
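Note: the max_pool2d hunk above is the change the commit title refers to. Instead of cloning the input, flagging it requires_grad_, calling .backward() and reading .grad, the backward hook recomputes the pooling output and asks it directly for the gradient with respect to the input via Tensor.gradient, passing the upstream grad_out. A small sketch of the pattern in plain tinygrad (shapes and pooling parameters are illustrative, not taken from the commit):

from tinygrad import Tensor

x = Tensor.randn(1, 1, 8, 8)
out = x.max_pool2d(kernel_size=2)
dout = Tensor.ones(*out.shape)   # upstream gradient flowing into the pooling output

# old pattern: x.requires_grad_(True); out.backward(dout); dx = x.grad
# new pattern: ask the output for d(out)/d(x) in one call
dx = out.gradient(x, gradient=dout)[0]
print(dx.shape)                  # (1, 1, 8, 8), same shape as x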
@@ -117,19 +120,15 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
     print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
   return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
                                    groups=groups, stride=stride, dilation=dilation, padding=padding))
-  #raise NotImplementedError("need convolution")
 
 @torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
 def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
   if TORCH_DEBUG >= 1:
     print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
-  grad_out, input, weight = unwrap(grad_out), unwrap(input), unwrap(weight)
-  input = input.detach().clone().requires_grad_(output_mask[0])
-  weight = weight.detach().clone().requires_grad_(output_mask[1])
-  bias = Tensor.zeros(weight.shape[0]).requires_grad_(output_mask[2])
-  Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).backward(grad_out)
-  return tuple(wrap(x.grad) if x.grad is not None else None for x in [input, weight, bias])
-  #raise NotImplementedError("need convolution")
+  grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0])
+  out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
+  grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
+  return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
 
 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src, dest, non_blocking=False):
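Note: convolution backward gets the same rewrite, with one extra wrinkle: PyTorch passes output_mask, a 3-tuple of bools saying which of (grad_input, grad_weight, grad_bias) the caller actually wants. The new code requests gradients only for the masked tensors and then expands the result back into the fixed three-slot return value, with None in the unrequested positions. A hedged sketch of that bookkeeping (grads_for_mask is a made-up helper name mirroring the inline logic, not part of the backend):

from tinygrad import Tensor

def grads_for_mask(out, tensors, output_mask, grad_out):
  # ask gradient() only for the tensors whose mask bit is set...
  wanted = [t for t, m in zip(tensors, output_mask) if m]
  grads = out.gradient(*wanted, gradient=grad_out)
  # ...then pop the results back into their original slots, None elsewhere
  return tuple(grads.pop(0) if m else None for m in output_mask)

x, w, b = Tensor.randn(1, 3, 8, 8), Tensor.randn(4, 3, 3, 3), Tensor.zeros(4)
out = x.conv2d(w, b)                    # output shape (1, 4, 6, 6)
dx, dw, db = grads_for_mask(out, [x, w, b], (True, True, False), Tensor.ones(*out.shape))
print(dx.shape, dw.shape, db)           # (1, 3, 8, 8) (4, 3, 3, 3) None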
@@ -141,9 +140,6 @@ def _copy_from(src, dest, non_blocking=False):
     dest.copy_(torch.from_numpy(unwrap(src).numpy()))
   elif str(src.device) == "cpu" and str(dest.device) == "tiny":
     unwrap(dest).assign(Tensor(src.numpy()))
-    #if 0 in dest.stride():
-    #  print(dest.shape, dest.stride())
-    #  exit(0)
   else:
     raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")