different way to write torch backend (#9197)

* different way to write torch backend

* both backends

* more work

* simpler code

* more work

* test both

* imply unwrap/wrap

* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works

* ready to start making test_ops work in torch backend

* backward pass, TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works

* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_simple_conv2d works

* matmul backward is broken with as_strided
This commit is contained in:
George Hotz
2025-02-22 14:42:26 +08:00
committed by GitHub
parent 041b6d5678
commit 4e6665bda5
6 changed files with 164 additions and 58 deletions

View File

@@ -26,7 +26,10 @@ class Model(nn.Module):
return self.lin(torch.flatten(x, 1))
if __name__ == "__main__":
if getenv("TINY_BACKEND"):
if getenv("TINY_BACKEND2"):
import extra.torch_backend.backend2
device = torch.device("cpu")
elif getenv("TINY_BACKEND"):
import extra.torch_backend.backend
device = torch.device("tiny")
else:

View File

@@ -1,11 +1,16 @@
from tinygrad import Tensor, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, getenv
import torch, pathlib
torch.autograd.grad_mode.set_multithreading_enabled(False)
# https://pytorch.org/docs/stable/torch.compiler_ir.html
# TODO: don't replicate this in cpp
torch_to_tiny_dtype = {
torch.float32: dtypes.float32,
torch.float64: dtypes.float64,
torch.uint8: dtypes.uint8,
torch.int8: dtypes.int8,
torch.int32: dtypes.int32,
torch.int64: dtypes.int64,
torch.bool: dtypes.bool,
@@ -13,24 +18,18 @@ torch_to_tiny_dtype = {
import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
wrap, unwrap = mod.wrap, mod.unwrap
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x)
def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
return mod.unwrap(x)
class TinyBackend: pass
torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend)
torch.utils.generate_methods_for_privateuse1_backend()
@torch.library.impl("aten::view", "privateuseone")
def view(x, sz): return mod.wrap(mod.unwrap(x).reshape(sz))
@torch.library.impl("aten::min", "privateuseone")
def min(x): return mod.wrap(mod.unwrap(x).min())
@torch.library.impl("aten::max", "privateuseone")
def max(x): return mod.wrap(mod.unwrap(x).max())
@torch.library.impl("aten::zero_", "privateuseone")
def zero_(x):
tt = mod.unwrap(x)
tt = unwrap(x)
tt.replace(tt.zeros_like())
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
@@ -51,11 +50,14 @@ def as_strided(tensor, size, stride, storage_offset=None):
if size == [] and storage_offset is not None:
# TODO: is this right?
return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
print(tensor.shape, size, stride, storage_offset)
# broadcast
if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
return wrap(Tensor.zeros(*size))
raise NotImplementedError("fix as_strided")
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout, device, pin_memory):
def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
return wrap(ret)
@@ -68,49 +70,73 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
print(input, weight, bias)
raise NotImplementedError
#print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
groups=groups, stride=stride, dilation=dilation, padding=padding))
#raise NotImplementedError("need convolution")
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest):
if str(src.device) == "tiny" and str(dest.device) == "tiny":
unwrap(dest).replace(unwrap(src), allow_shape_mismatch=True)
elif str(src.device) == "tiny" and str(dest.device) == "cpu":
dest[:] = torch.from_numpy(unwrap(src).numpy())
# TODO: is there a better way?
dest.resize_(src.numel()).resize_(src.shape)
dest.copy_(torch.from_numpy(unwrap(src).numpy()))
elif str(src.device) == "cpu" and str(dest.device) == "tiny":
unwrap(dest).assign(Tensor(src.numpy()))
else:
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
@torch.library.impl("aten::exp2.out", "privateuseone")
def exp2_out(x, out): unwrap(out).replace(unwrap(x).exp2(), allow_shape_mismatch=True)
@torch.library.impl("aten::cat.out", "privateuseone")
def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
@torch.library.impl("aten::ceil.out", "privateuseone")
def ceil_out(x, out): unwrap(out).replace(unwrap(x).ceil(), allow_shape_mismatch=True)
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
@torch.library.impl("aten::abs.out", "privateuseone")
def abs_out(x, out): unwrap(out).replace(unwrap(x).abs(), allow_shape_mismatch=True)
tiny_backend = {
"aten.view": Tensor.reshape,
"aten.add.Tensor": Tensor.add,
"aten.sub.Tensor": Tensor.sub,
"aten.mul.Tensor": Tensor.mul,
"aten.div.Tensor": Tensor.div,
"aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
"aten.pow.Tensor_Scalar": Tensor.pow,
"aten.bitwise_and.Tensor": Tensor.bitwise_and,
"aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
"aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
"aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
"aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
"aten.exp2": Tensor.exp2,
"aten.min": Tensor.min,
"aten.max": Tensor.max,
"aten.relu": Tensor.relu,
"aten.mean": Tensor.mean,
"aten.neg": Tensor.neg,
"aten.mm": Tensor.matmul,
}
@torch.library.impl("aten::bitwise_and.Tensor", "privateuseone")
def bitwise_and_tensor(x, y): return wrap(unwrap(x) & unwrap(y))
# there's earlier things to hook here
#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add_tensor(x, y): return wrap(unwrap(x) + unwrap(y))
def wrap_fxn(k,f):
def nf(*args, **kwargs):
#print(k, len(args), kwargs.keys())
args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
return wrap(f(*args, **kwargs))
return nf
@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul_tensor(x, y): return wrap(unwrap(x) * unwrap(y))
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
@torch.library.impl("aten::div.Tensor", "privateuseone")
def div_tensor(x, y): return wrap(unwrap(x) / unwrap(y))
@torch.library.impl("aten::eq.Tensor", "privateuseone")
def eq_tensor(x, y): return wrap(unwrap(x).eq(unwrap(y)))
@torch.library.impl("aten::ne.Tensor", "privateuseone")
def ne_tensor(x, y): return wrap(unwrap(x).ne(unwrap(y)))
@torch.library.impl("aten::ne.Scalar", "privateuseone")
def ne_scalar(x, y): return wrap(unwrap(x).ne(y))
@torch.library.impl("aten::gt.Scalar", "privateuseone")
def gt_scalar(x, y): return wrap(unwrap(x) > y)
if getenv("TORCH_DEBUG"):
from torch.utils._python_dispatch import TorchDispatchMode
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
#print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
print(f"Dispatch Log: {func}")
return func(*args, **(kwargs or {}))
DispatchLog().__enter__()

View File

@@ -0,0 +1,49 @@
from tinygrad import Tensor, dtypes
import torch, contextlib
from torch.utils._python_dispatch import TorchDispatchMode
torch_to_tiny_dtype = {
torch.float32: dtypes.float32,
torch.float64: dtypes.float64,
torch.int32: dtypes.int32,
torch.int64: dtypes.int64,
torch.bool: dtypes.bool,
}
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
tiny_backend = {
"aten.empty.memory_format": empty_memory_format,
"aten.view.default": lambda x,sz: TTensor(x.tiny.reshape(sz)),
"aten.abs.default": lambda x: TTensor(x.tiny.abs()),
"aten.eq.Tensor": lambda x,y: TTensor(x.tiny == y.tiny),
"aten.bitwise_and.Tensor": lambda x,y: TTensor(x.tiny & y.tiny),
"aten.ne.Scalar": lambda x,y: TTensor(x.tiny != y),
"aten.mul.Tensor": lambda x,y: TTensor(x.tiny * y.tiny),
"aten.masked_select.default": lambda x,y: TTensor(Tensor(x.tiny.numpy()[y.tiny.numpy()])),
}
class TTensor(torch.Tensor):
tiny: Tensor
context = contextlib.nullcontext
@staticmethod
def __new__(cls, tiny, *args, **kwargs):
out = torch.Tensor._make_wrapper_subclass(cls, tiny.shape)
torch._C._set_throw_on_mutable_data_ptr(out)
out.tiny = tiny
return out
def __repr__(self): return super().__repr__(tensor_contents=f"{self.tiny}")
def __torch_dispatch__(cls, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{[type(x) for x in args]}, **{kwargs.keys()})")
#print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
new_func = tiny_backend.get(str(func), None)
if new_func is None: raise NotImplementedError(f"add support for {func}")
return new_func(*args, **(kwargs or {}))
class Dispatcher(TorchDispatchMode): __torch_dispatch__ = TTensor.__torch_dispatch__
Dispatcher().__enter__()
if __name__ == "__main__":
a = torch.empty((4,), dtype=torch.int)

View File

@@ -2,37 +2,48 @@
import unittest
import torch
import numpy as np
import extra.torch_backend.backend # "tiny" backend is installed
from tinygrad.helpers import getenv
if getenv("TINY_BACKEND2"):
import extra.torch_backend.backend2
device = "cpu"
else:
import extra.torch_backend.backend
device = "tiny"
class TestTorchBackend(unittest.TestCase):
def test_numpy_ones(self):
a = torch.ones(4, device="tiny")
a = torch.ones(4, device=device)
np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
def test_numpy_ones(self):
a = torch.ones(4, dtype=torch.int32, device="tiny")
a = torch.ones(4, dtype=torch.int32, device=device)
assert a.dtype == torch.int32
np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
def test_plus(self):
a = torch.ones(4, device="tiny")
b = torch.ones(4, device="tiny")
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
c = a+b
np.testing.assert_equal(c.cpu().numpy(), [2,2,2,2])
def test_exp2(qself):
a = torch.ones(4, device=device)
b = a.exp2()
print(b)
def test_eq(self):
a = torch.ones(4, device="tiny")
b = torch.ones(4, device="tiny")
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
c = a == b
print(c.cpu().numpy())
def test_isfinite(self):
a = torch.ones(4, device="tiny")
np.testing.assert_equal(torch.isfinite(a), [True, True, True, True])
a = torch.ones(4, device=device)
np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])
# TODO: why
def test_str(self):
a = torch.ones(4, device="tiny")
a = torch.ones(4, device=device)
print(str(a))
if __name__ == "__main__":

View File

@@ -10,6 +10,17 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::
}
}
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
// NOTE: no idea what this is
bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};
int register_hook() {
at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
return 0;
}
int temp_register_hook = register_hook();
// code from chatgpt
struct GILSafeDeleter {
void operator()(PyObject* ptr) const {
@@ -56,6 +67,8 @@ static caffe2::TypeMeta dtypeFromName(const std::string &dtype_name) {
} else if (dtype_name == "int") { return caffe2::TypeMeta::Make<int32_t>();
} else if (dtype_name == "long") { return caffe2::TypeMeta::Make<int64_t>();
} else if (dtype_name == "bool") { return caffe2::TypeMeta::Make<bool>();
} else if (dtype_name == "char") { return caffe2::TypeMeta::Make<char>();
} else if (dtype_name == "unsigned char") { return caffe2::TypeMeta::Make<unsigned char>();
}
throw std::runtime_error("Unsupported dtype: " + dtype_name);
}

View File

@@ -8,6 +8,10 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
if getenv("TINY_BACKEND"):
import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
torch.set_default_device("tiny")
if CI:
warnings.filterwarnings("ignore", message="Non-empty compiler output encountered")
@@ -46,8 +50,8 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
if DEBUG >= 6:
np.set_printoptions(linewidth=200, suppress=True)
print(ret.numpy())
print(out.detach().numpy())
compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
print(out.detach().cpu().numpy())
compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)
torch_fbp, tinygrad_fbp = np.nan, np.nan
if not forward_only and not FORWARD_ONLY:
@@ -65,7 +69,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
tinygrad_fbp = time.monotonic() - st
for i, (t, tt_grad) in enumerate(zip(ts, tst_grads)):
compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
"""
(ret+1).square().mean().backward()
@@ -90,7 +94,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
for i in range(len(ts)):
# NOTE: torch default int64 for python ints input
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst
class TestOps(unittest.TestCase):