diff --git a/examples/other_mnist/beautiful_mnist_torch.py b/examples/other_mnist/beautiful_mnist_torch.py
index bee27d2bc2..715cf983ea 100644
--- a/examples/other_mnist/beautiful_mnist_torch.py
+++ b/examples/other_mnist/beautiful_mnist_torch.py
@@ -26,7 +26,10 @@ class Model(nn.Module):
     return self.lin(torch.flatten(x, 1))
 
 if __name__ == "__main__":
-  if getenv("TINY_BACKEND"):
+  if getenv("TINY_BACKEND2"):
+    import extra.torch_backend.backend2
+    device = torch.device("cpu")
+  elif getenv("TINY_BACKEND"):
     import extra.torch_backend.backend
     device = torch.device("tiny")
   else:
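Usage sketch for the switch above (hedged; assumes only ops that backend.py registers get hit): once the backend module is imported, "tiny" works like any other torch device string.

    import torch
    import extra.torch_backend.backend  # noqa: F401  # registers the "tiny" PrivateUse1 device

    x = torch.ones(4, device="tiny")  # storage is a tinygrad Tensor under the hood
    y = (x + x).cpu()                 # aten::add.Tensor runs in tinygrad; aten::_copy_from moves it back
    print(y.numpy())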
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 46d41f2d28..3c9d111e58 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -1,11 +1,16 @@
 from tinygrad import Tensor, dtypes
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, getenv
 import torch, pathlib
+torch.autograd.grad_mode.set_multithreading_enabled(False)
+
+# https://pytorch.org/docs/stable/torch.compiler_ir.html
 
 # TODO: don't replicate this in cpp
 torch_to_tiny_dtype = {
   torch.float32: dtypes.float32,
   torch.float64: dtypes.float64,
+  torch.uint8: dtypes.uint8,
+  torch.int8: dtypes.int8,
   torch.int32: dtypes.int32,
   torch.int64: dtypes.int64,
   torch.bool: dtypes.bool,
@@ -13,24 +18,18 @@ torch_to_tiny_dtype = {
 
 import torch.utils.cpp_extension
 mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
-wrap, unwrap = mod.wrap, mod.unwrap
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x)
+def unwrap(x:torch.Tensor) -> Tensor:
+  assert isinstance(x, torch.Tensor), f"x must be a torch.Tensor, got {type(x)}"
+  return mod.unwrap(x)
 
 class TinyBackend: pass
 torch.utils.rename_privateuse1_backend("tiny")
 torch._register_device_module("tiny", TinyBackend)
 torch.utils.generate_methods_for_privateuse1_backend()
 
-@torch.library.impl("aten::view", "privateuseone")
-def view(x, sz): return mod.wrap(mod.unwrap(x).reshape(sz))
-
-@torch.library.impl("aten::min", "privateuseone")
-def min(x): return mod.wrap(mod.unwrap(x).min())
-
-@torch.library.impl("aten::max", "privateuseone")
-def max(x): return mod.wrap(mod.unwrap(x).max())
-
 @torch.library.impl("aten::zero_", "privateuseone")
 def zero_(x):
-  tt = mod.unwrap(x)
+  tt = unwrap(x)
   tt.replace(tt.zeros_like())
 
 @torch.library.impl("aten::fill_.Scalar", "privateuseone")
@@ -51,11 +50,14 @@ def as_strided(tensor, size, stride, storage_offset=None):
   if size == [] and storage_offset is not None:
     # TODO: is this right?
     return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
-  print(tensor.shape, size, stride, storage_offset)
+  # broadcast
+  if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
+  print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
+  return wrap(Tensor.zeros(*size))
   raise NotImplementedError("fix as_strided")
 
 @torch.library.impl("aten::empty_strided", "privateuseone")
-def empty_strided(size, stride, dtype, layout, device, pin_memory):
+def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
   if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
   ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
   return wrap(ret)
@@ -68,49 +70,73 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
 
 @torch.library.impl("aten::convolution_overrideable", "privateuseone")
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-  print(input, weight, bias)
-  raise NotImplementedError
+  #print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
+  return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
+                                   groups=groups, stride=stride, dilation=dilation, padding=padding))
+  #raise NotImplementedError("need convolution")
 
 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src, dest):
   if str(src.device) == "tiny" and str(dest.device) == "tiny":
     unwrap(dest).replace(unwrap(src), allow_shape_mismatch=True)
   elif str(src.device) == "tiny" and str(dest.device) == "cpu":
-    dest[:] = torch.from_numpy(unwrap(src).numpy())
+    # TODO: is there a better way?
+    dest.resize_(src.numel()).resize_(src.shape)
+    dest.copy_(torch.from_numpy(unwrap(src).numpy()))
   elif str(src.device) == "cpu" and str(dest.device) == "tiny":
     unwrap(dest).assign(Tensor(src.numpy()))
   else:
     raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
 
-@torch.library.impl("aten::exp2.out", "privateuseone")
-def exp2_out(x, out): unwrap(out).replace(unwrap(x).exp2(), allow_shape_mismatch=True)
+@torch.library.impl("aten::cat.out", "privateuseone")
+def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
 
-@torch.library.impl("aten::ceil.out", "privateuseone")
-def ceil_out(x, out): unwrap(out).replace(unwrap(x).ceil(), allow_shape_mismatch=True)
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
 
-@torch.library.impl("aten::abs.out", "privateuseone")
-def abs_out(x, out): unwrap(out).replace(unwrap(x).abs(), allow_shape_mismatch=True)
+tiny_backend = {
+  "aten.view": Tensor.reshape,
+  "aten.add.Tensor": Tensor.add,
+  "aten.sub.Tensor": Tensor.sub,
+  "aten.mul.Tensor": Tensor.mul,
+  "aten.div.Tensor": Tensor.div,
+  "aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
+  "aten.pow.Tensor_Scalar": Tensor.pow,
+  "aten.bitwise_and.Tensor": Tensor.bitwise_and,
+  "aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
+  "aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
+  "aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
+  "aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
+  "aten.exp2": Tensor.exp2,
+  "aten.min": Tensor.min,
+  "aten.max": Tensor.max,
+  "aten.relu": Tensor.relu,
+  "aten.mean": Tensor.mean,
+  "aten.neg": Tensor.neg,
+  "aten.mm": Tensor.matmul,
+}
 
-@torch.library.impl("aten::bitwise_and.Tensor", "privateuseone")
-def bitwise_and_tensor(x, y): return wrap(unwrap(x) & unwrap(y))
+# there's earlier things to hook here
+#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
+#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
+#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
+#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),
 
-@torch.library.impl("aten::add.Tensor", "privateuseone")
-def add_tensor(x, y): return wrap(unwrap(x) + unwrap(y))
+def wrap_fxn(k,f):
+  def nf(*args, **kwargs):
+    #print(k, len(args), kwargs.keys())
+    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
+    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
+    return wrap(f(*args, **kwargs))
+  return nf
 
-@torch.library.impl("aten::mul.Tensor", "privateuseone")
-def mul_tensor(x, y): return wrap(unwrap(x) * unwrap(y))
+for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
 
-@torch.library.impl("aten::div.Tensor", "privateuseone")
-def div_tensor(x, y): return wrap(unwrap(x) / unwrap(y))
-
-@torch.library.impl("aten::eq.Tensor", "privateuseone")
-def eq_tensor(x, y): return wrap(unwrap(x).eq(unwrap(y)))
-
-@torch.library.impl("aten::ne.Tensor", "privateuseone")
-def ne_tensor(x, y): return wrap(unwrap(x).ne(unwrap(y)))
-
-@torch.library.impl("aten::ne.Scalar", "privateuseone")
-def ne_scalar(x, y): return wrap(unwrap(x).ne(y))
-
-@torch.library.impl("aten::gt.Scalar", "privateuseone")
-def gt_scalar(x, y): return wrap(unwrap(x) > y)
+if getenv("TORCH_DEBUG"):
+  from torch.utils._python_dispatch import TorchDispatchMode
+  class DispatchLog(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args, kwargs=None):
+      #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
+      print(f"Dispatch Log: {func}")
+      return func(*args, **(kwargs or {}))
+  DispatchLog().__enter__()
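The table-driven registration above makes simple ops one-liners. A hedged sketch of extending it (aten::sqrt is not part of this diff; Tensor.sqrt is shown only to illustrate the unwrap-call-wrap pattern):

    import torch
    from tinygrad import Tensor
    from extra.torch_backend.backend import wrap_fxn  # importing also registers the backend

    # hypothetical extra op: wrap_fxn unwraps torch args, calls the tinygrad fxn, wraps the result
    torch.library.impl("aten::sqrt", "privateuseone")(wrap_fxn("aten.sqrt", Tensor.sqrt))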
diff --git a/extra/torch_backend/backend2.py b/extra/torch_backend/backend2.py
new file mode 100644
index 0000000000..aa3c7838ba
--- /dev/null
+++ b/extra/torch_backend/backend2.py
@@ -0,0 +1,49 @@
+from tinygrad import Tensor, dtypes
+import torch, contextlib
+from torch.utils._python_dispatch import TorchDispatchMode
+
+torch_to_tiny_dtype = {
+  torch.float32: dtypes.float32,
+  torch.float64: dtypes.float64,
+  torch.int32: dtypes.int32,
+  torch.int64: dtypes.int64,
+  torch.bool: dtypes.bool,
+}
+
+def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
+  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
+
+tiny_backend = {
+  "aten.empty.memory_format": empty_memory_format,
+  "aten.view.default": lambda x,sz: TTensor(x.tiny.reshape(sz)),
+  "aten.abs.default": lambda x: TTensor(x.tiny.abs()),
+  "aten.eq.Tensor": lambda x,y: TTensor(x.tiny == y.tiny),
+  "aten.bitwise_and.Tensor": lambda x,y: TTensor(x.tiny & y.tiny),
+  "aten.ne.Scalar": lambda x,y: TTensor(x.tiny != y),
+  "aten.mul.Tensor": lambda x,y: TTensor(x.tiny * y.tiny),
+  "aten.masked_select.default": lambda x,y: TTensor(Tensor(x.tiny.numpy()[y.tiny.numpy()])),
+}
+
+class TTensor(torch.Tensor):
+  tiny: Tensor
+  context = contextlib.nullcontext
+
+  @staticmethod
+  def __new__(cls, tiny, *args, **kwargs):
+    out = torch.Tensor._make_wrapper_subclass(cls, tiny.shape)
+    torch._C._set_throw_on_mutable_data_ptr(out)
+    out.tiny = tiny
+    return out
+  def __repr__(self): return super().__repr__(tensor_contents=f"{self.tiny}")
+  def __torch_dispatch__(cls, func, types, args, kwargs=None):
+    print(f"Dispatch Log: {func}(*{[type(x) for x in args]}, **{kwargs.keys()})")
+    #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
+    new_func = tiny_backend.get(str(func), None)
+    if new_func is None: raise NotImplementedError(f"add support for {func}")
+    return new_func(*args, **(kwargs or {}))
+
+class Dispatcher(TorchDispatchMode): __torch_dispatch__ = TTensor.__torch_dispatch__
+Dispatcher().__enter__()
+
+if __name__ == "__main__":
+  a = torch.empty((4,), dtype=torch.int)
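backend2 takes the wrapper-subclass route instead of registering PrivateUse1 kernels: every aten call hits __torch_dispatch__ and is looked up by its str(func) name. A hedged sketch of what user code sees (only ops in backend2's tiny_backend table succeed; anything else raises NotImplementedError):

    import torch
    import extra.torch_backend.backend2  # noqa: F401  # entering the Dispatcher mode happens at import

    a = torch.empty((4,), dtype=torch.int)  # dispatches "aten.empty.memory_format" -> TTensor
    b = a.view([2, 2])                      # dispatches "aten.view.default"
    print(b)                                # TTensor.__repr__ shows the wrapped tinygrad Tensor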
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index 72c168a45f..24227a57cf 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -2,37 +2,48 @@ import unittest
 import torch
 import numpy as np
-import extra.torch_backend.backend # "tiny" backend is installed
+from tinygrad.helpers import getenv
+if getenv("TINY_BACKEND2"):
+  import extra.torch_backend.backend2
+  device = "cpu"
+else:
+  import extra.torch_backend.backend
+  device = "tiny"
 
 class TestTorchBackend(unittest.TestCase):
   def test_numpy_ones(self):
-    a = torch.ones(4, device="tiny")
+    a = torch.ones(4, device=device)
     np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
 
   def test_numpy_ones(self):
-    a = torch.ones(4, dtype=torch.int32, device="tiny")
+    a = torch.ones(4, dtype=torch.int32, device=device)
     assert a.dtype == torch.int32
     np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
 
   def test_plus(self):
-    a = torch.ones(4, device="tiny")
-    b = torch.ones(4, device="tiny")
+    a = torch.ones(4, device=device)
+    b = torch.ones(4, device=device)
     c = a+b
     np.testing.assert_equal(c.cpu().numpy(), [2,2,2,2])
 
+  def test_exp2(self):
+    a = torch.ones(4, device=device)
+    b = a.exp2()
+    print(b)
+
   def test_eq(self):
-    a = torch.ones(4, device="tiny")
-    b = torch.ones(4, device="tiny")
+    a = torch.ones(4, device=device)
+    b = torch.ones(4, device=device)
     c = a == b
     print(c.cpu().numpy())
 
   def test_isfinite(self):
-    a = torch.ones(4, device="tiny")
-    np.testing.assert_equal(torch.isfinite(a), [True, True, True, True])
+    a = torch.ones(4, device=device)
+    np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True]) # TODO: why
 
   def test_str(self):
-    a = torch.ones(4, device="tiny")
+    a = torch.ones(4, device=device)
     print(str(a))
 
 if __name__ == "__main__":
diff --git a/extra/torch_backend/wrapped_tensor.cpp b/extra/torch_backend/wrapped_tensor.cpp
index fd20913682..5380dda934 100644
--- a/extra/torch_backend/wrapped_tensor.cpp
+++ b/extra/torch_backend/wrapped_tensor.cpp
@@ -10,6 +10,17 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::PrivateUse1>);
   } else if (dtype_name == "long") { return caffe2::TypeMeta::Make<int64_t>();
   } else if (dtype_name == "bool") { return caffe2::TypeMeta::Make<bool>();
+  } else if (dtype_name == "char") { return caffe2::TypeMeta::Make<int8_t>();
+  } else if (dtype_name == "unsigned char") { return caffe2::TypeMeta::Make<uint8_t>();
   }
   throw std::runtime_error("Unsupported dtype: " + dtype_name);
 }
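The two new C++ branches line up with the torch.int8/torch.uint8 entries added to torch_to_tiny_dtype in backend.py: tinygrad dtype names are C-style strings, which is presumably what the dtype lookup in wrapped_tensor.cpp matches on. A quick sanity check (assuming tinygrad's dtype naming):

    from tinygrad import dtypes
    assert dtypes.int8.name == "char"
    assert dtypes.uint8.name == "unsigned char"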
diff --git a/test/test_ops.py b/test/test_ops.py
index 64fff504e0..fc338dd7ac 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -8,6 +8,10 @@ from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
 
+if getenv("TINY_BACKEND"):
+  import extra.torch_backend.backend  # noqa: F401 # pylint: disable=unused-import
+  torch.set_default_device("tiny")
+
 if CI:
   warnings.filterwarnings("ignore", message="Non-empty compiler output encountered")
 
@@ -46,8 +50,8 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
   if DEBUG >= 6:
     np.set_printoptions(linewidth=200, suppress=True)
     print(ret.numpy())
-    print(out.detach().numpy())
-  compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
+    print(out.detach().cpu().numpy())
+  compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)
 
   torch_fbp, tinygrad_fbp = np.nan, np.nan
   if not forward_only and not FORWARD_ONLY:
@@ -65,7 +69,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
     tinygrad_fbp = time.monotonic() - st
 
   for i, (t, tt_grad) in enumerate(zip(ts, tst_grads)):
-    compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
+    compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
 
   """
   (ret+1).square().mean().backward()
@@ -90,7 +94,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
   for i in range(len(ts)):
     # NOTE: torch default int64 for python ints input
     if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
-  tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
+  tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst
 
 class TestOps(unittest.TestCase):
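With the TINY_BACKEND hook above, the torch half of test_ops runs on tinygrad too, which is why the .detach().numpy() calls gain a .cpu(): tensors on a PrivateUse1 device can't be converted to numpy directly. Assuming the usual pytest invocation, something like `TINY_BACKEND=1 python -m pytest test/test_ops.py` exercises both frameworks on the same tinygrad code.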