mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
full fix for as_strided in torch backend (#9257)
* fixes from chatgpt for torch backend * shrink support * add stride support * comment cleanup * a few more * work * import the stream hack * llvm multi auto
This commit is contained in:
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
@@ -155,7 +155,8 @@ jobs:
|
||||
with:
|
||||
key: torch-backend-pillow-torchvision-et-pt
|
||||
deps: testing_minimal
|
||||
pydeps: "pillow torchvision expecttest pytest"
|
||||
pydeps: "pillow torchvision expecttest"
|
||||
llvm: 'true'
|
||||
- name: Install ninja
|
||||
run: |
|
||||
sudo apt update || true
|
||||
@@ -169,11 +170,11 @@ jobs:
|
||||
- name: Test one op in torch tests
|
||||
run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
|
||||
- name: Test Ops with TINY_BACKEND (expect failure)
|
||||
run: PYTHONPATH=. TINY_BACKEND=1 python3 -m pytest test/test_ops.py || true
|
||||
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
|
||||
- name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
|
||||
run: PYTHONPATH=. TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
|
||||
- name: Test some torch tests (expect failure)
|
||||
run: PYTHONPATH=. pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
|
||||
tc:
|
||||
name: Tensor Core tests
|
||||
|
||||
@@ -6,15 +6,16 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad.helpers import prod, tqdm
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.ops import sym_infer, Node
|
||||
from tinygrad.ops import sym_infer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
|
||||
|
||||
def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
|
||||
def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
|
||||
mop, arg = mop_arg
|
||||
if mop == MovementOps.RESHAPE:
|
||||
# shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
|
||||
if arg == (-1,): return st.reshape((prod(st.views[-1].shape),))
|
||||
if arg == (-1,): return st.reshape((prod(st.shape),))
|
||||
return st.reshape(arg)
|
||||
if mop == MovementOps.PERMUTE: return st.permute(arg)
|
||||
if mop == MovementOps.EXPAND:
|
||||
@@ -22,7 +23,9 @@ def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTrac
|
||||
return st.expand(arg)
|
||||
if mop == MovementOps.PAD: return st.pad(arg)
|
||||
if mop == MovementOps.SHRINK: return st.shrink(arg)
|
||||
if mop == MovementOps.STRIDE: return st.stride(arg)
|
||||
if mop == MovementOps.STRIDE:
|
||||
assert all(x in [-1, 1] for x in arg)
|
||||
return st.flip(tuple(i for i,x in enumerate(arg) if x == -1))
|
||||
raise ValueError("invalid mop")
|
||||
|
||||
def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
|
||||
@@ -36,7 +39,7 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
|
||||
offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
|
||||
real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
|
||||
strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
|
||||
strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
|
||||
buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
|
||||
if i: buffer_size = prod(st.views[i-1].shape) - real_offset
|
||||
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
|
||||
@@ -80,9 +83,12 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
|
||||
if scratch_st in seen:
|
||||
ret = seen[scratch_st][:]
|
||||
else:
|
||||
ret.append(mop_arg)
|
||||
if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE:
|
||||
ret[-1] = mop_arg
|
||||
else:
|
||||
if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),))
|
||||
ret.append(mop_arg)
|
||||
seen[scratch_st] = ret[:]
|
||||
|
||||
return ret
|
||||
|
||||
def get_real_view(shape, strides, offset, mask):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import DEBUG, getenv, prod
|
||||
import torch.lib
|
||||
TORCH_DEBUG = getenv("TORCH_DEBUG")
|
||||
import torch, pathlib, math, operator
|
||||
torch.autograd.grad_mode.set_multithreading_enabled(False)
|
||||
@@ -38,33 +39,19 @@ def masked_select(self, mask):
|
||||
# err, bad
|
||||
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
|
||||
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
|
||||
@torch.library.impl("aten::as_strided", "privateuseone")
|
||||
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
#return tensor.cpu().as_strided(size, stride).tiny()
|
||||
if TORCH_DEBUG >= 1: print("** NOTE: this as_strided might be wrong", tensor.shape, size, stride, storage_offset)
|
||||
|
||||
nz_strides = [st for s,st in zip(size, stride) if s != 1]
|
||||
decending_strides = all(x>=y for x,y in zip(nz_strides[:-1], nz_strides[1:]))
|
||||
|
||||
# this is reshape (squeeze/unsqueeze), strides must be in decending order
|
||||
if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1) and decending_strides:
|
||||
return tensor.reshape(size)
|
||||
|
||||
# this is also expand, hit?
|
||||
if tensor.numel() == 1:
|
||||
assert all(x == 0 for x in stride)
|
||||
return wrap(unwrap(tensor).reshape([1]*len(size)).expand(size))
|
||||
|
||||
# this is expand
|
||||
if len(tensor.shape) == len(size) and all(x == y or x == 1 for x,y in zip(tensor.shape, size)) and decending_strides:
|
||||
return wrap(unwrap(tensor).expand(size))
|
||||
|
||||
# this is permute because we are flipping strides
|
||||
if len(tensor.shape) == 2 and tuple(tensor.shape)[::-1] == tuple(size) and stride == [0, 1]:
|
||||
return wrap(unwrap(tensor).permute(1,0))
|
||||
|
||||
#print(tensor.cpu().numpy())
|
||||
raise NotImplementedError(f"fix as_strided {tensor.shape} -> {size} {stride} {storage_offset}")
|
||||
# TODO: this is heavyweight
|
||||
st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)])
|
||||
ret = unwrap(tensor)
|
||||
if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
|
||||
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
|
||||
mops = to_movement_ops(st)
|
||||
if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:]
|
||||
for mo in mops: ret = apply_mop(ret, mo)
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::empty_strided", "privateuseone")
|
||||
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
|
||||
@@ -85,6 +72,33 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
|
||||
# TODO: this is wrong
|
||||
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
|
||||
|
||||
@torch.library.impl("aten::arange", "privateuseone")
|
||||
def arange(end, dtype=None, device=None, pin_memory=None):
|
||||
return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
|
||||
|
||||
@torch.library.impl("aten::arange.start", "privateuseone")
|
||||
def arange_start(start, end, dtype=None, device=None, pin_memory=None):
|
||||
return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
|
||||
|
||||
@torch.library.impl("aten::arange.start_step", "privateuseone")
|
||||
def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
|
||||
return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
|
||||
|
||||
@torch.library.impl("aten::topk", "privateuseone")
|
||||
def topk(self, k, dim=-1, largest=True, sorted=True):
|
||||
# TODO: move to tinygrad
|
||||
t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
|
||||
return torch.return_types.topk((t1.tiny(), t2.tiny()))
|
||||
|
||||
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
|
||||
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
|
||||
# TODO: move to tinygrad
|
||||
return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
|
||||
|
||||
@torch.library.impl("aten::index.Tensor", "privateuseone")
|
||||
def index_tensor(x, y):
|
||||
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
|
||||
|
||||
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
|
||||
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
|
||||
if TORCH_DEBUG >= 1:
|
||||
@@ -103,14 +117,15 @@ def _copy_from(src, dest):
|
||||
dest.copy_(torch.from_numpy(unwrap(src).numpy()))
|
||||
elif str(src.device) == "cpu" and str(dest.device) == "tiny":
|
||||
unwrap(dest).assign(Tensor(src.numpy()))
|
||||
#if 0 in dest.stride():
|
||||
# print(dest.shape, dest.stride())
|
||||
# exit(0)
|
||||
else:
|
||||
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
|
||||
|
||||
@torch.library.impl("aten::cat.out", "privateuseone")
|
||||
def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
|
||||
|
||||
@torch.library.impl("aten::index.Tensor", "privateuseone")
|
||||
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
|
||||
def cat_out(tensors, dim=0, out=None):
|
||||
unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
|
||||
|
||||
# register some decompositions
|
||||
from torch._decomp import get_decompositions
|
||||
@@ -118,6 +133,7 @@ aten = torch.ops.aten
|
||||
decomps = {
|
||||
"post_autograd": [
|
||||
aten.native_batch_norm, aten.native_batch_norm_backward,
|
||||
aten.native_layer_norm_backward,
|
||||
aten.addmm,
|
||||
aten.addcmul,
|
||||
aten.addcdiv,
|
||||
@@ -142,6 +158,10 @@ decomps = {
|
||||
aten.nan_to_num,
|
||||
aten.logit,
|
||||
aten.rsub,
|
||||
aten.index_select,
|
||||
aten.native_dropout, aten.native_dropout_backward,
|
||||
aten._softmax_backward_data, aten.embedding_dense_backward,
|
||||
aten.linalg_vector_norm,
|
||||
# activations
|
||||
aten.hardswish, aten.hardswish_backward,
|
||||
aten.hardtanh, aten.hardtanh_backward,
|
||||
@@ -194,11 +214,13 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
"aten.add.out": lambda input,other,alpha=1: input+alpha*other,
|
||||
"aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
|
||||
"aten.mul.out": operator.mul,
|
||||
"aten.bmm.out": operator.matmul,
|
||||
"aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
|
||||
# NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
|
||||
"aten.remainder.Tensor_out": Tensor.mod,
|
||||
"aten.pow.Tensor_Tensor_out": Tensor.pow,
|
||||
"aten.pow.Tensor_Scalar_out": Tensor.pow,
|
||||
"aten.pow.Scalar_out": lambda x,y: x**y,
|
||||
"aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
|
||||
"aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
|
||||
"aten.bitwise_xor.Tensor_out": lambda x,y: x^y, # TODO: tinygrad lacks bitwise_xor, add it
|
||||
@@ -229,10 +251,15 @@ def wrap_out(f):
|
||||
|
||||
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.view": Tensor.reshape,
|
||||
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
|
||||
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
|
||||
"aten.floor_divide": lambda x,y: x//y,
|
||||
"aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y),
|
||||
# TODO: use tinygrad methods, but they require x to be unsigned
|
||||
"aten.__lshift__.Scalar": lambda x,y: x*(2**y),
|
||||
"aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)),
|
||||
"aten.__rshift__.Scalar": lambda x,y: x//(2**y),
|
||||
"aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)),
|
||||
# relu doesn't have an out form?
|
||||
"aten.relu": Tensor.relu,
|
||||
"aten.relu_": lambda x: x.assign(x.relu()),
|
||||
@@ -251,7 +278,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)),
|
||||
# NOTE: axis=[] in torch means all, change tinygrad?
|
||||
"aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
|
||||
out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True),
|
||||
out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
|
||||
"aten.scatter.value": Tensor.scatter,
|
||||
"aten.gather": Tensor.gather,
|
||||
"aten.where.self": Tensor.where,
|
||||
@@ -266,11 +293,14 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
# these don't work in out form, they have size 0
|
||||
"aten.abs": Tensor.abs,
|
||||
"aten.logical_not": Tensor.logical_not,
|
||||
"aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
|
||||
"aten.multinomial": Tensor.multinomial,
|
||||
}}
|
||||
|
||||
def wrap_fxn(k,f):
|
||||
def nf(*args, **kwargs):
|
||||
if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
|
||||
if TORCH_DEBUG:
|
||||
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
|
||||
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
|
||||
args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
|
||||
kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
|
||||
@@ -300,8 +330,9 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
tinygrad_tensors.append(param.data)
|
||||
for state_dict in optimizer.state.values():
|
||||
for key, value in state_dict.items():
|
||||
if torch.is_tensor(value) and str(value.device) == "tiny": tinygrad_tensors.append(value)
|
||||
Tensor.realize(*[unwrap(x) for x in tinygrad_tensors])
|
||||
if torch.is_tensor(value): tinygrad_tensors.append(value)
|
||||
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
|
||||
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
|
||||
|
||||
_optimizer_init = torch.optim.Optimizer.__init__
|
||||
def _optimizer_patched_init(self, *args, **kwargs):
|
||||
|
||||
@@ -50,6 +50,11 @@ class TestTorchBackend(unittest.TestCase):
|
||||
np.testing.assert_equal(perm.cpu().numpy(), [[1,3],[2,4]])
|
||||
np.testing.assert_equal(back.cpu().numpy(), [[1,2],[3,4]])
|
||||
|
||||
def test_shrink(self):
|
||||
a = torch.Tensor([1,2,3,4]).to(device)
|
||||
np.testing.assert_equal(a[:3].cpu().numpy(), [1,2,3])
|
||||
np.testing.assert_equal(a[1:].cpu().numpy(), [2,3,4])
|
||||
|
||||
def test_plus_inplace(self):
|
||||
a = torch.ones(4, device=device)
|
||||
b = torch.ones(4, device=device)
|
||||
@@ -66,15 +71,13 @@ class TestTorchBackend(unittest.TestCase):
|
||||
a = torch.ones(4, device=device)
|
||||
np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])
|
||||
|
||||
@unittest.skip("broken")
|
||||
def test_eq(self):
|
||||
a = torch.ones(4, device=device)
|
||||
b = torch.ones(4, device=device)
|
||||
c = a == b
|
||||
print(c.cpu().numpy())
|
||||
|
||||
# TODO: why
|
||||
@unittest.skip("broken")
|
||||
@unittest.skip("meh")
|
||||
def test_str(self):
|
||||
a = torch.ones(4, device=device)
|
||||
print(str(a))
|
||||
|
||||
@@ -7,8 +7,95 @@
|
||||
// register guard
|
||||
namespace at {
|
||||
namespace detail {
|
||||
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
|
||||
//C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
|
||||
// NOTE: pytorch's no-op class throws error on backwards with events/streams
|
||||
// TODO: why are there events in autograd?
|
||||
struct CustomNoOpDeviceGuardImpl : public c10::impl::DeviceGuardImplInterface
|
||||
{
|
||||
static const DeviceType D = DeviceType::PrivateUse1;
|
||||
CustomNoOpDeviceGuardImpl() = default;
|
||||
DeviceType type() const override {
|
||||
return D;
|
||||
}
|
||||
Device exchangeDevice(Device) const override {
|
||||
return Device(D, 0); // no-op
|
||||
}
|
||||
Device getDevice() const override {
|
||||
return Device(D, 0);
|
||||
}
|
||||
void setDevice(Device) const override {
|
||||
// no-op
|
||||
}
|
||||
void uncheckedSetDevice(Device) const noexcept override {
|
||||
// no-op
|
||||
}
|
||||
Stream getStream(Device) const noexcept override {
|
||||
// no-op
|
||||
return Stream(Stream::DEFAULT, Device(D, 0));
|
||||
}
|
||||
Stream getDefaultStream(Device) const override {
|
||||
// no-op
|
||||
return Stream(Stream::DEFAULT, Device(D, 0));
|
||||
}
|
||||
Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
|
||||
const override {
|
||||
// no-op
|
||||
(void)isHighPriority;
|
||||
return Stream(Stream::DEFAULT, Device(D, 0));
|
||||
}
|
||||
Stream getNewStream(Device, int priority = 0) const override {
|
||||
// no-op
|
||||
(void)priority;
|
||||
return Stream(Stream::DEFAULT, Device(D, 0));
|
||||
}
|
||||
// NB: These do NOT set the current device
|
||||
Stream exchangeStream(Stream) const noexcept override {
|
||||
// no-op
|
||||
return Stream(Stream::DEFAULT, Device(D, 0));
|
||||
}
|
||||
DeviceIndex deviceCount() const noexcept override {
|
||||
return 1;
|
||||
}
|
||||
// Event-related functions
|
||||
void record(
|
||||
void** /*event*/,
|
||||
const Stream& /*stream*/,
|
||||
const DeviceIndex /*device_index*/,
|
||||
const EventFlag /*flag*/) const override {
|
||||
//TORCH_CHECK(false, D, " backend doesn't support events.");
|
||||
}
|
||||
void block(void* /*event*/, const Stream& /*stream*/) const override {
|
||||
//TORCH_CHECK(false, D, " backend doesn't support events.")
|
||||
}
|
||||
bool queryEvent(void* /*event*/) const override {
|
||||
//TORCH_CHECK(false, D, " backend doesn't support events.")
|
||||
return true;
|
||||
}
|
||||
void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
|
||||
const noexcept override {}
|
||||
// Stream-related functions
|
||||
bool queryStream(const Stream& /*stream*/) const override {
|
||||
return true;
|
||||
}
|
||||
void synchronizeStream(const Stream& /*stream*/) const override {
|
||||
// Don't wait for anything.
|
||||
}
|
||||
};
|
||||
C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomNoOpDeviceGuardImpl);
|
||||
}
|
||||
|
||||
template <typename OpaqueHandle>
|
||||
struct TinyOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
|
||||
TinyOpaqueTensorImpl(
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta data_type,
|
||||
c10::Device device,
|
||||
OpaqueHandle opaque_handle,
|
||||
c10::IntArrayRef sizes,
|
||||
c10::IntArrayRef strides)
|
||||
: OpaqueTensorImpl<OpaqueHandle>(key_set, data_type, device, opaque_handle, sizes)
|
||||
{ this->sizes_and_strides_.set_strides(strides); }
|
||||
};
|
||||
}
|
||||
|
||||
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
|
||||
@@ -26,17 +113,28 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
|
||||
// TODO: we have to get the dtype and the shape from the tinygrad Tensor
|
||||
std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
|
||||
|
||||
return at::detail::make_tensor<at::OpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
|
||||
// Last dimension stride is 1 for contiguous row-major layout
|
||||
std::vector<int64_t> strides(sizes.size());
|
||||
if (sizes.size() >= 1) {
|
||||
strides[sizes.size() - 1] = 1;
|
||||
|
||||
// Compute strides from right to left
|
||||
for (int64_t i = sizes.size() - 2; i >= 0; --i) {
|
||||
strides[i] = strides[i + 1] * sizes[i + 1];
|
||||
}
|
||||
}
|
||||
|
||||
return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
|
||||
at::DispatchKeySet(at::DispatchKey::PrivateUse1),
|
||||
c10::scalarTypeToTypeMeta(dtype),
|
||||
at::Device(at::kPrivateUse1),
|
||||
std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
|
||||
sizes);
|
||||
sizes, strides);
|
||||
}
|
||||
|
||||
py::object unwrap_tensor(const at::Tensor &tensor) {
|
||||
auto* impl = tensor.unsafeGetTensorImpl();
|
||||
auto* opaque_impl = static_cast<at::OpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
|
||||
auto* opaque_impl = static_cast<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
|
||||
std::shared_ptr<c10::SafePyObject> tiny = opaque_impl->opaque_handle();
|
||||
return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user