From 796c3bbb235e5e36d512521d56b8a56db94d1781 Mon Sep 17 00:00:00 2001 From: Priyank Patel Date: Mon, 10 Mar 2025 08:29:00 -0700 Subject: [PATCH] torch: support in-place operations on views (#9371) * add torch inplace tests * first set of tests passing * wrap all inplace funcs, add more tests * fixes and wrap more functions * fix all uint8 tests to avoid slow tests * fix the one test * another test, another fix * and one more, works for ddp now * something on contiguous, cleanup --------- Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> --- .github/workflows/test.yml | 2 + extra/torch_backend/backend.py | 103 +++++++++++++++++++++------- extra/torch_backend/test_inplace.py | 65 ++++++++++++++++++ 3 files changed, 144 insertions(+), 26 deletions(-) create mode 100644 extra/torch_backend/test_inplace.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 82120bd264..839c9f67fe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -175,6 +175,8 @@ jobs: run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32 - name: Test Ops with TINY_BACKEND (expect failure) run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20 || true + - name: Test in-place operations on views + run: PYTHONPATH=. TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py torchbackendmore: name: Torch Backend Tests More diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 1066a9f284..1afb3a10f9 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -6,7 +6,7 @@ from tinygrad import Tensor, dtypes from tinygrad.helpers import getenv, prod import torch.lib TORCH_DEBUG = getenv("TORCH_DEBUG") -import torch, pathlib, math, operator, functools +import torch, pathlib, math, operator, functools, inspect torch.autograd.grad_mode.set_multithreading_enabled(False) from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype @@ -28,6 +28,36 @@ torch.utils.rename_privateuse1_backend("tiny") torch._register_device_module("tiny", TinyBackend()) torch.utils.generate_methods_for_privateuse1_backend() +# in place operations with views +def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None +def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]): + assert self.device.type == "tiny" + self = unwrap(self) + if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous") # TODO: support? + self.replace(self.clone().realize()) + for v in views: + v = unwrap(v) + ret = self + st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right? 
+ for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo) + v.replace(ret) +def maybe_realize_storage(self: torch.Tensor) -> bool: + if realize:=is_view(self): realize_with_views(self._base, [self]) # TODO: other views could exist + return realize +def inplace_fn(outvars: str|list[str]): + if type(outvars) is str: outvars = [outvars] + def decorator(fn): + sig = inspect.signature(fn) + def wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + outs = [kwargs.get(v, bound.arguments.get(v)) for v in outvars] + realize = any(maybe_realize_storage(o) for o in outs) + ret = fn(*args, **kwargs) + if realize: Tensor.realize(*(unwrap(o) for o in outs)) + return ret + return wrapper + return decorator + # *** bad functions on CPU *** @torch.library.impl("aten::masked_select", "privateuseone") @@ -36,9 +66,11 @@ def masked_select(self, mask): return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()])) @torch.library.impl("aten::_index_put_impl_", "privateuseone") +@inplace_fn("self") def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False): # TODO: move to tinygrad - return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny() + ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).tiny() + return wrap(unwrap(self).assign(unwrap(ret))) @torch.library.impl("aten::index.Tensor", "privateuseone") def index_tensor(x, y): @@ -50,14 +82,20 @@ def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, # *** end bad functions on CPU *** @torch.library.impl("aten::zero_", "privateuseone") +@inplace_fn("x") def zero_(x): + if TORCH_DEBUG: print(f"zero_ {x.shape}") tt = unwrap(x) - tt.replace(tt.zeros_like()) + # NOTE: unconditional contiguous covers if x is contiguous (match it) or if x is view (realize for inplace) + # TODO: consolidate + tt.assign(tt.zeros_like().contiguous()) @torch.library.impl("aten::fill_.Scalar", "privateuseone") +@inplace_fn("x") def fill_scalar(x, y): + if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}") tt = unwrap(x) - tt.replace(tt.full_like(y)) + tt.assign(tt.full_like(y).contiguous()) @torch.library.impl("aten::_local_scalar_dense", "privateuseone") def _local_scalar_dense(tensor): return unwrap(tensor).item() @@ -89,7 +127,7 @@ def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=Fals @torch.library.impl("aten::empty.memory_format", "privateuseone") def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None): if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}") - ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())) + ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())).contiguous() return wrap(ret) @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone") @@ -137,23 +175,33 @@ def convolution_backward_overrideable(grad_out, input, weight, stride, padding, grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out) return tuple([wrap(grads.pop(0)) if m else None for m in output_mask]) +def upsample(self, size, align_corners=False, mode=None): return wrap(Tensor.interpolate(unwrap(self), size, mode=mode, align_corners=align_corners)) +for i,pre in enumerate(["", "bi", "tri"]): + 
torch.library.impl(f"aten::upsample_{pre}linear{i+1}d", "privateuseone")(functools.partial(upsample, mode="linear")) + torch.library.impl(f"aten::upsample_nearest{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest")) + torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact")) + @torch.library.impl("aten::_copy_from", "privateuseone") def _copy_from(src: torch.Tensor, dest, non_blocking=False): + realize = str(dest.device) == "tiny" and maybe_realize_storage(dest) cast_dtype = _from_torch_dtype(dest.dtype) if str(src.device) == "tiny" and str(dest.device) == "tiny": - unwrap(dest).replace(unwrap(src).cast(cast_dtype), allow_shape_mismatch=True) + unwrap(dest).assign(unwrap(src).cast(cast_dtype)) + if realize: Tensor.realize(unwrap(dest)) elif str(src.device) == "tiny" and str(dest.device) == "cpu": # TODO: is there a better way? dest.resize_(src.numel()).resize_(src.shape) dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy())) elif str(src.device) == "cpu" and str(dest.device) == "tiny": unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype)) + if realize: Tensor.realize(unwrap(dest)) else: raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}") @torch.library.impl("aten::cat.out", "privateuseone") +@inplace_fn("out") def cat_out(tensors, dim=0, out=None): - unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True) + unwrap(out).assign(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim)) # register some decompositions from torch._decomp import get_decompositions @@ -280,6 +328,10 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_ "aten.where.self_out": Tensor.where, "aten.prod.int_out": Tensor.prod, "aten.scatter_add.out": functools.partial(Tensor.scatter_reduce, reduce='sum'), + # NOTE: axis=[] in torch means all, change tinygrad? + "aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None: + self.sum(axis if axis is None or len(axis) else None, keepdim, + acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None), }} # we add the "out" here @@ -290,7 +342,8 @@ def wrap_out(f): if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype) assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}" assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}" - return out.replace(assigned) + if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics + return out.assign(assigned) return _wrap_out tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ @@ -298,15 +351,15 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care? 
"aten.remainder.Scalar_Tensor": lambda x,y: x%y, "aten.floor_divide": lambda x,y: x//y, - "aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y), + "aten.floor_divide_.Tensor": inplace_fn("x")(lambda x,y: x.assign(x//y)), # TODO: use tinygrad methods, but they require x to be unsigned "aten.__lshift__.Scalar": lambda x,y: x*(2**y), - "aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)), + "aten.__ilshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x*(2**y))), "aten.__rshift__.Scalar": lambda x,y: x//(2**y), - "aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)), + "aten.__irshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x//(2**y))), # relu doesn't have an out form? "aten.relu": Tensor.relu, - "aten.relu_": lambda x: x.assign(x.relu()), + "aten.relu_": inplace_fn("x")(lambda x: x.assign(x.relu())), "aten.mean": Tensor.mean, "aten.mean.dim": Tensor.mean, "aten.min": Tensor.min, @@ -320,30 +373,26 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.std_mean.correction": Tensor.std_mean, "aten.var.correction": Tensor.var, "aten.var_mean.correction": Tensor.var_mean, - # NOTE: axis=[] in torch means all, change tinygrad? - "aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None,out=None: - out.replace(self.sum(axis if axis is None or len(axis) else None, keepdim, - acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None), allow_shape_mismatch=True), "aten.scatter.value": Tensor.scatter, "aten.scatter.value_reduce": Tensor.scatter, "aten.gather": lambda self, dim, index: self.gather(dim, index.cast(dtypes.int)), "aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type "aten._softmax": lambda self,dim,half_to_float: self.softmax(dim), "aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim), - "aten.random_": lambda self: - self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype)), - "aten.random_.from": lambda self, from_, to: - self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype)), - "aten.uniform_": lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high)), - "aten.normal_": lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std)), + "aten.random_": inplace_fn("self")(lambda self: + self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))), + "aten.random_.from": inplace_fn("self")(lambda self, from_, to: + self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))), + "aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high))), + "aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std))), # these don't work in out form, they have size 0 "aten.abs": Tensor.abs, "aten.logical_not": Tensor.logical_not, - "aten.logical_or_": lambda x, y: x.assign(x | y), + "aten.logical_or_": inplace_fn("x")(lambda x, y: x.assign(x | y)), "aten.multinomial": Tensor.multinomial, "aten.pad": Tensor.pad, "aten.reflection_pad2d": functools.partial(Tensor.pad, mode="reflect"), - "aten.masked_fill_.Scalar": lambda self, mask, value: self.assign(self.masked_fill(mask, value)), + "aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: 
self.assign(self.masked_fill(mask, value))), "aten.masked_fill.Scalar": Tensor.masked_fill, "aten.masked_fill.Tensor": Tensor.masked_fill, "aten.all": Tensor.all, @@ -356,7 +405,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.asinh": Tensor.asinh, "aten.mul": Tensor.mul, "aten.atanh": Tensor.atanh, - "aten.fill_.Tensor": Tensor.full, + "aten.fill_.Tensor": Tensor.full, # TODO: looks wrong "aten.flip": Tensor.flip, "aten.scatter_reduce.two": Tensor.scatter_reduce, "aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), @@ -390,7 +439,9 @@ def wrap_fxn(k,f): if isinstance(out, Tensor): return wrap(out) elif isinstance(out, tuple): return tuple(wrap(x) for x in out) else: raise RuntimeError(f"unknown output type {type(out)}") - return nf + def nf2(*args, **kwargs): + return inplace_fn("out")(nf)(*args, **kwargs) if "out" in kwargs else nf(*args, **kwargs) + return nf2 for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v)) diff --git a/extra/torch_backend/test_inplace.py b/extra/torch_backend/test_inplace.py new file mode 100644 index 0000000000..99dfdef4a8 --- /dev/null +++ b/extra/torch_backend/test_inplace.py @@ -0,0 +1,65 @@ +import unittest +import torch +import tinygrad.frontend.torch +torch.set_default_device("tiny") +import numpy as np + +class TestTorchBackendInplace(unittest.TestCase): + def test_zero(self): + a = torch.ones(4) + a.zero_() + np.testing.assert_equal(a.cpu().numpy(), [0,0,0,0]) + + def test_view_zero(self): + a = torch.ones(4) + a.view((2, 2)).zero_() + np.testing.assert_equal(a.cpu().numpy(), [0,0,0,0]) + + def test_slice_zero(self): + a = torch.ones(4) + a[2:].zero_() + np.testing.assert_equal(a.cpu().numpy(), [1,1,0,0]) + + def test_slice_permute_zero(self): + a = torch.ones((3,2)) + a.permute(1,0)[1:].zero_() + np.testing.assert_equal(a.cpu().numpy(), [[1,0],[1,0],[1,0]]) + + def test_slice_fill(self): + a = torch.zeros(4) + a[2:].fill_(2) + np.testing.assert_equal(a.cpu().numpy(), [0,0,2,2]) + + def test_slice_mul(self): + a = torch.ones(4) + a[:2] *= 3 + a[2:] *= 2 + np.testing.assert_equal(a.cpu().numpy(), [3,3,2,2]) + + def test_stacked_mul(self): + a = torch.ones((3,3)) + b = a[1:,1:].permute(1,0) + c = b[1:,:] + b *= 2 + c *= 3 + np.testing.assert_equal(a.cpu().numpy(), [[1,1,1],[1,2,6],[1,2,6]]) + + def test_flatten_reshape_add(self): + a = torch.zeros((2,2,12,32)) + b = a.flatten() + c = b.reshape((48,32)) + a += 1 + b += 1 + c += 1 + np.testing.assert_equal(c.cpu().numpy(), torch.full((48,32),3).cpu().numpy()) + + def test_noncontig(self): + a = torch.empty_strided((4,4),(1,4), dtype=torch.int64) + # self.assertFalse(a.is_contiguous()) # TODO: we are contiguous when it's not required + a.zero_() + b = a.view((4,4)) + b[1:3,:] += 1 + np.testing.assert_equal(a.cpu().numpy(), [[0]*4,[1]*4,[1]*4,[0]*4]) + +if __name__ == "__main__": + unittest.main()
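
A minimal usage sketch of the behavior this patch enables (assuming the patch above is applied and the backend is importable as in extra/torch_backend/test_inplace.py): an in-place op on a view has to write through to the base tensor's storage. The inplace_fn decorator arranges this by binding the wrapped op's signature with inspect, realizing the base of any view argument named in outvars before the op runs (maybe_realize_storage -> realize_with_views), and realizing the assigned outputs afterwards. The snippet below is illustrative only and mirrors test_slice_fill; the variable names are not from the patch.

  # illustrative sketch, not part of the patch
  import numpy as np
  import torch
  import tinygrad.frontend.torch  # registers the "tiny" privateuse1 backend
  torch.set_default_device("tiny")

  a = torch.ones(6)          # base tensor on the tiny backend
  v = a.view(2, 3)[1:]       # view of a: shares storage through a non-trivial ShapeTracker
  v.fill_(5)                 # aten::fill_.Scalar is wrapped with @inplace_fn("x")
  # the in-place write on the view is visible through the base tensor
  np.testing.assert_equal(a.cpu().numpy(), [1, 1, 1, 5, 5, 5])

The same path covers the wrapped in-place lambdas in tiny_backend (aten.relu_, aten.uniform_, aten.masked_fill_.Scalar, ...), and wrap_fxn routes any call that passes an out= keyword through inplace_fn("out") as well.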