Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
different way to write torch backend (#9197)
* different way to write torch backend
* both backends
* more work
* simpler code
* more work
* test both
* imply unwrap/wrap
* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works
* ready to start making test_ops work in torch backend
* backward pass, TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add works
* FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_simple_conv2d works
* matmul backward is broken with as_strided
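For orientation, both approaches appear in the diff below: extra/torch_backend/backend.py registers aten kernels for a renamed PrivateUse1 device ("tiny"), while the new extra/torch_backend/backend2.py wraps tinygrad Tensors in a torch.Tensor subclass and reroutes ops through __torch_dispatch__. A minimal sketch of the PrivateUse1 registration flow that backend.py builds on (op body elided, not the PR's full code):

import torch

class TinyBackendModule: pass                            # placeholder device module
torch.utils.rename_privateuse1_backend("tiny")            # the backend appears as device "tiny"
torch._register_device_module("tiny", TinyBackendModule)
torch.utils.generate_methods_for_privateuse1_backend()    # helper methods for the renamed backend

@torch.library.impl("aten::add.Tensor", "privateuseone")
def add_tensor(x, y):
  # a real kernel would unwrap x and y to tinygrad Tensors, add them, and wrap the result
  raise NotImplementedError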
@@ -26,7 +26,10 @@ class Model(nn.Module):
    return self.lin(torch.flatten(x, 1))

if __name__ == "__main__":
  if getenv("TINY_BACKEND"):
  if getenv("TINY_BACKEND2"):
    import extra.torch_backend.backend2
    device = torch.device("cpu")
  elif getenv("TINY_BACKEND"):
    import extra.torch_backend.backend
    device = torch.device("tiny")
  else:

@@ -1,11 +1,16 @@
from tinygrad import Tensor, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, getenv
import torch, pathlib
torch.autograd.grad_mode.set_multithreading_enabled(False)

# https://pytorch.org/docs/stable/torch.compiler_ir.html

# TODO: don't replicate this in cpp
torch_to_tiny_dtype = {
  torch.float32: dtypes.float32,
  torch.float64: dtypes.float64,
  torch.uint8: dtypes.uint8,
  torch.int8: dtypes.int8,
  torch.int32: dtypes.int32,
  torch.int64: dtypes.int64,
  torch.bool: dtypes.bool,
@@ -13,24 +18,18 @@ torch_to_tiny_dtype = {

import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
wrap, unwrap = mod.wrap, mod.unwrap
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x)
def unwrap(x:torch.Tensor) -> Tensor:
  assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
  return mod.unwrap(x)
class TinyBackend: pass
torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend)
torch.utils.generate_methods_for_privateuse1_backend()

@torch.library.impl("aten::view", "privateuseone")
def view(x, sz): return mod.wrap(mod.unwrap(x).reshape(sz))

@torch.library.impl("aten::min", "privateuseone")
def min(x): return mod.wrap(mod.unwrap(x).min())

@torch.library.impl("aten::max", "privateuseone")
def max(x): return mod.wrap(mod.unwrap(x).max())

@torch.library.impl("aten::zero_", "privateuseone")
def zero_(x):
  tt = mod.unwrap(x)
  tt = unwrap(x)
  tt.replace(tt.zeros_like())

@torch.library.impl("aten::fill_.Scalar", "privateuseone")
@@ -51,11 +50,14 @@ def as_strided(tensor, size, stride, storage_offset=None):
  if size == [] and storage_offset is not None:
    # TODO: is this right?
    return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
  print(tensor.shape, size, stride, storage_offset)
  # broadcast
  if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
  print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
  return wrap(Tensor.zeros(*size))
  raise NotImplementedError("fix as_strided")

@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout, device, pin_memory):
def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
  if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
  return wrap(ret)
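For context on why the as_strided fallback above is marked wrong (and why the commit message notes matmul backward breaks with it): torch's as_strided reinterprets the same storage with arbitrary strides, which a plain reshape/expand cannot express. A small standalone illustration, independent of the PR:

import torch
x = torch.arange(6.).reshape(2, 3)                    # contiguous, strides (3, 1)
y = torch.as_strided(x, size=(3, 2), stride=(1, 3))   # same storage, read with transposed strides
print(torch.equal(y, x.t()))                          # True, no data was copied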
@@ -68,49 +70,73 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F

@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
  print(input, weight, bias)
  raise NotImplementedError
  #print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
  return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
                                   groups=groups, stride=stride, dilation=dilation, padding=padding))
  #raise NotImplementedError("need convolution")

@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest):
  if str(src.device) == "tiny" and str(dest.device) == "tiny":
    unwrap(dest).replace(unwrap(src), allow_shape_mismatch=True)
  elif str(src.device) == "tiny" and str(dest.device) == "cpu":
    dest[:] = torch.from_numpy(unwrap(src).numpy())
    # TODO: is there a better way?
    dest.resize_(src.numel()).resize_(src.shape)
    dest.copy_(torch.from_numpy(unwrap(src).numpy()))
  elif str(src.device) == "cpu" and str(dest.device) == "tiny":
    unwrap(dest).assign(Tensor(src.numpy()))
  else:
    raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")

@torch.library.impl("aten::exp2.out", "privateuseone")
def exp2_out(x, out): unwrap(out).replace(unwrap(x).exp2(), allow_shape_mismatch=True)
@torch.library.impl("aten::cat.out", "privateuseone")
def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)

@torch.library.impl("aten::ceil.out", "privateuseone")
def ceil_out(x, out): unwrap(out).replace(unwrap(x).ceil(), allow_shape_mismatch=True)
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])

@torch.library.impl("aten::abs.out", "privateuseone")
def abs_out(x, out): unwrap(out).replace(unwrap(x).abs(), allow_shape_mismatch=True)
tiny_backend = {
  "aten.view": Tensor.reshape,
  "aten.add.Tensor": Tensor.add,
  "aten.sub.Tensor": Tensor.sub,
  "aten.mul.Tensor": Tensor.mul,
  "aten.div.Tensor": Tensor.div,
  "aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
  "aten.pow.Tensor_Scalar": Tensor.pow,
  "aten.bitwise_and.Tensor": Tensor.bitwise_and,
  "aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
  "aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
  "aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
  "aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
  "aten.exp2": Tensor.exp2,
  "aten.min": Tensor.min,
  "aten.max": Tensor.max,
  "aten.relu": Tensor.relu,
  "aten.mean": Tensor.mean,
  "aten.neg": Tensor.neg,
  "aten.mm": Tensor.matmul,
}

@torch.library.impl("aten::bitwise_and.Tensor", "privateuseone")
def bitwise_and_tensor(x, y): return wrap(unwrap(x) & unwrap(y))
# there's earlier things to hook here
#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),

@torch.library.impl("aten::add.Tensor", "privateuseone")
def add_tensor(x, y): return wrap(unwrap(x) + unwrap(y))
def wrap_fxn(k,f):
  def nf(*args, **kwargs):
    #print(k, len(args), kwargs.keys())
    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
    return wrap(f(*args, **kwargs))
  return nf

@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul_tensor(x, y): return wrap(unwrap(x) * unwrap(y))
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))

@torch.library.impl("aten::div.Tensor", "privateuseone")
def div_tensor(x, y): return wrap(unwrap(x) / unwrap(y))

@torch.library.impl("aten::eq.Tensor", "privateuseone")
def eq_tensor(x, y): return wrap(unwrap(x).eq(unwrap(y)))

@torch.library.impl("aten::ne.Tensor", "privateuseone")
def ne_tensor(x, y): return wrap(unwrap(x).ne(unwrap(y)))

@torch.library.impl("aten::ne.Scalar", "privateuseone")
def ne_scalar(x, y): return wrap(unwrap(x).ne(y))

@torch.library.impl("aten::gt.Scalar", "privateuseone")
def gt_scalar(x, y): return wrap(unwrap(x) > y)
if getenv("TORCH_DEBUG"):
  from torch.utils._python_dispatch import TorchDispatchMode
  class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
      #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
      print(f"Dispatch Log: {func}")
      return func(*args, **(kwargs or {}))
  DispatchLog().__enter__()
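Taken together, the registrations above make the "tiny" device usable from ordinary torch code. A usage sketch that mirrors test.py further down in this diff (the printed value is what the tests assert):

import torch
import extra.torch_backend.backend  # noqa: F401, registers the "tiny" device on import

a = torch.ones(4, device="tiny")
b = torch.ones(4, device="tiny")
c = a + b                  # aten::add.Tensor -> tinygrad Tensor.add via the tiny_backend table
print(c.cpu().numpy())     # [2. 2. 2. 2.], copied back through aten::_copy_from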
extra/torch_backend/backend2.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from tinygrad import Tensor, dtypes
import torch, contextlib
from torch.utils._python_dispatch import TorchDispatchMode

torch_to_tiny_dtype = {
  torch.float32: dtypes.float32,
  torch.float64: dtypes.float64,
  torch.int32: dtypes.int32,
  torch.int64: dtypes.int64,
  torch.bool: dtypes.bool,
}

def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))

tiny_backend = {
  "aten.empty.memory_format": empty_memory_format,
  "aten.view.default": lambda x,sz: TTensor(x.tiny.reshape(sz)),
  "aten.abs.default": lambda x: TTensor(x.tiny.abs()),
  "aten.eq.Tensor": lambda x,y: TTensor(x.tiny == y.tiny),
  "aten.bitwise_and.Tensor": lambda x,y: TTensor(x.tiny & y.tiny),
  "aten.ne.Scalar": lambda x,y: TTensor(x.tiny != y),
  "aten.mul.Tensor": lambda x,y: TTensor(x.tiny * y.tiny),
  "aten.masked_select.default": lambda x,y: TTensor(Tensor(x.tiny.numpy()[y.tiny.numpy()])),
}

class TTensor(torch.Tensor):
  tiny: Tensor
  context = contextlib.nullcontext

  @staticmethod
  def __new__(cls, tiny, *args, **kwargs):
    out = torch.Tensor._make_wrapper_subclass(cls, tiny.shape)
    torch._C._set_throw_on_mutable_data_ptr(out)
    out.tiny = tiny
    return out
  def __repr__(self): return super().__repr__(tensor_contents=f"{self.tiny}")
  def __torch_dispatch__(cls, func, types, args, kwargs=None):
    print(f"Dispatch Log: {func}(*{[type(x) for x in args]}, **{kwargs.keys()})")
    #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
    new_func = tiny_backend.get(str(func), None)
    if new_func is None: raise NotImplementedError(f"add support for {func}")
    return new_func(*args, **(kwargs or {}))

class Dispatcher(TorchDispatchMode): __torch_dispatch__ = TTensor.__torch_dispatch__
Dispatcher().__enter__()

if __name__ == "__main__":
  a = torch.empty((4,), dtype=torch.int)
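The backend2 approach never registers a new device: TTensor is a wrapper subclass holding a tinygrad Tensor, and the TorchDispatchMode entered at import looks each aten op up by name in the tiny_backend table. A usage sketch mirroring the __main__ block above and the TINY_BACKEND2 path in the tests:

import torch
import extra.torch_backend.backend2  # enters Dispatcher (a TorchDispatchMode) on import

a = torch.empty((4,), dtype=torch.int)  # routed to aten.empty.memory_format -> TTensor(Tensor.empty(...))
b = a.abs()                             # routed to aten.abs.default on the wrapped tinygrad Tensor
print(type(b))                          # the TTensor wrapper, still on the default "cpu" device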
@@ -2,37 +2,48 @@
import unittest
import torch
import numpy as np
import extra.torch_backend.backend # "tiny" backend is installed
from tinygrad.helpers import getenv
if getenv("TINY_BACKEND2"):
  import extra.torch_backend.backend2
  device = "cpu"
else:
  import extra.torch_backend.backend
  device = "tiny"

class TestTorchBackend(unittest.TestCase):
  def test_numpy_ones(self):
    a = torch.ones(4, device="tiny")
    a = torch.ones(4, device=device)
    np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])

  def test_numpy_ones(self):
    a = torch.ones(4, dtype=torch.int32, device="tiny")
    a = torch.ones(4, dtype=torch.int32, device=device)
    assert a.dtype == torch.int32
    np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])

  def test_plus(self):
    a = torch.ones(4, device="tiny")
    b = torch.ones(4, device="tiny")
    a = torch.ones(4, device=device)
    b = torch.ones(4, device=device)
    c = a+b
    np.testing.assert_equal(c.cpu().numpy(), [2,2,2,2])

  def test_exp2(qself):
    a = torch.ones(4, device=device)
    b = a.exp2()
    print(b)

  def test_eq(self):
    a = torch.ones(4, device="tiny")
    b = torch.ones(4, device="tiny")
    a = torch.ones(4, device=device)
    b = torch.ones(4, device=device)
    c = a == b
    print(c.cpu().numpy())

  def test_isfinite(self):
    a = torch.ones(4, device="tiny")
    np.testing.assert_equal(torch.isfinite(a), [True, True, True, True])
    a = torch.ones(4, device=device)
    np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])

  # TODO: why
  def test_str(self):
    a = torch.ones(4, device="tiny")
    a = torch.ones(4, device=device)
    print(str(a))

if __name__ == "__main__":

@@ -10,6 +10,17 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::
  }
}

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // NOTE: no idea what this is
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};

int register_hook() {
  at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
  return 0;
}
int temp_register_hook = register_hook();

// code from chatgpt
struct GILSafeDeleter {
  void operator()(PyObject* ptr) const {
@@ -56,6 +67,8 @@ static caffe2::TypeMeta dtypeFromName(const std::string &dtype_name) {
  } else if (dtype_name == "int") { return caffe2::TypeMeta::Make<int32_t>();
  } else if (dtype_name == "long") { return caffe2::TypeMeta::Make<int64_t>();
  } else if (dtype_name == "bool") { return caffe2::TypeMeta::Make<bool>();
  } else if (dtype_name == "char") { return caffe2::TypeMeta::Make<char>();
  } else if (dtype_name == "unsigned char") { return caffe2::TypeMeta::Make<unsigned char>();
  }
  throw std::runtime_error("Unsupported dtype: " + dtype_name);
}

@@ -8,6 +8,10 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported

if getenv("TINY_BACKEND"):
  import extra.torch_backend.backend  # noqa: F401 # pylint: disable=unused-import
  torch.set_default_device("tiny")

if CI:
  warnings.filterwarnings("ignore", message="Non-empty compiler output encountered")

@@ -46,8 +50,8 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
  if DEBUG >= 6:
    np.set_printoptions(linewidth=200, suppress=True)
    print(ret.numpy())
    print(out.detach().numpy())
  compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
    print(out.detach().cpu().numpy())
  compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)

  torch_fbp, tinygrad_fbp = np.nan, np.nan
  if not forward_only and not FORWARD_ONLY:
@@ -65,7 +69,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
    tinygrad_fbp = time.monotonic() - st

    for i, (t, tt_grad) in enumerate(zip(ts, tst_grads)):
      compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
      compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)

    """
    (ret+1).square().mean().backward()
@@ -90,7 +94,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
  for i in range(len(ts)):
    # NOTE: torch default int64 for python ints input
    if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
  tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
  tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
  return ts, tst

class TestOps(unittest.TestCase):
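The test_ops.py changes above reduce to one rule: with TINY_BACKEND=1, torch.set_default_device("tiny") puts results on the tiny device, and .numpy() is only defined for CPU tensors, so every comparison copies back first. A short sketch of the failure mode, assuming the backend is installed as above:

import torch
out = torch.ones(3, device="tiny")
# out.detach().numpy()              # would raise: .numpy() requires a CPU tensor
print(out.detach().cpu().numpy())   # works, the copy goes through aten::_copy_from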