full fix for as_strided in torch backend (#9257)

* fixes from chatgpt for torch backend

* shrink support

* add stride support

* comment cleanup

* a few more

* work

* import the stream hack

* llvm multi auto
George Hotz
2025-02-26 22:34:05 +08:00
committed by GitHub
parent f60f997bf7
commit 2158dc4849
5 changed files with 189 additions and 50 deletions
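
For orientation, a minimal sketch (not part of the commit) of the general as_strided path this change introduces: instead of pattern-matching reshape/expand/permute special cases, the backend now builds a two-view ShapeTracker and replays it as movement ops. The shape, stride, and offset values below are made-up examples; the ShapeTracker/View/to_movement_ops API is used as it appears in the diff.

from tinygrad import Tensor
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps

# example as_strided call: a (2,2) window with stride (4,2) and offset 1 over a (4,4) buffer
src_shape, size, stride, storage_offset = (4, 4), (2, 2), (4, 2), 1
st = ShapeTracker((View.create(src_shape), View.create(size, stride, storage_offset)))
t = Tensor.arange(16).reshape(src_shape)
mops = to_movement_ops(st)
# same leading-reshape drop as the backend implementation below
if mops[0] == (MovementOps.RESHAPE, tuple(t.shape)): mops = mops[1:]
for mo in mops: t = apply_mop(t, mo)
print(t.tolist())  # expected [[1, 3], [5, 7]]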


@@ -155,7 +155,8 @@ jobs:
with:
key: torch-backend-pillow-torchvision-et-pt
deps: testing_minimal
pydeps: "pillow torchvision expecttest pytest"
pydeps: "pillow torchvision expecttest"
llvm: 'true'
- name: Install ninja
run: |
sudo apt update || true
@@ -169,11 +170,11 @@ jobs:
- name: Test one op in torch tests
run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
- name: Test Ops with TINY_BACKEND (expect failure)
run: PYTHONPATH=. TINY_BACKEND=1 python3 -m pytest test/test_ops.py || true
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
- name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
run: PYTHONPATH=. TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
- name: Test some torch tests (expect failure)
run: PYTHONPATH=. pytest extra/torch_backend/torch_tests.py -v --tb=no || true
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
tc:
name: Tensor Core tests


@@ -6,15 +6,16 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad.helpers import prod, tqdm
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import sym_infer, Node
from tinygrad.ops import sym_infer
from tinygrad.tensor import Tensor
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
mop, arg = mop_arg
if mop == MovementOps.RESHAPE:
# shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
if arg == (-1,): return st.reshape((prod(st.views[-1].shape),))
if arg == (-1,): return st.reshape((prod(st.shape),))
return st.reshape(arg)
if mop == MovementOps.PERMUTE: return st.permute(arg)
if mop == MovementOps.EXPAND:
@@ -22,7 +23,9 @@ def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTrac
return st.expand(arg)
if mop == MovementOps.PAD: return st.pad(arg)
if mop == MovementOps.SHRINK: return st.shrink(arg)
if mop == MovementOps.STRIDE: return st.stride(arg)
if mop == MovementOps.STRIDE:
assert all(x in [-1, 1] for x in arg)
return st.flip(tuple(i for i,x in enumerate(arg) if x == -1))
raise ValueError("invalid mop")
def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
@@ -36,7 +39,7 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
if i: buffer_size = prod(st.views[i-1].shape) - real_offset
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
@@ -80,9 +83,12 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
if scratch_st in seen:
ret = seen[scratch_st][:]
else:
ret.append(mop_arg)
if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE:
ret[-1] = mop_arg
else:
if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),))
ret.append(mop_arg)
seen[scratch_st] = ret[:]
return ret
def get_real_view(shape, strides, offset, mask):
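
The STRIDE case in apply_mop above now only accepts steps of ±1 and rewrites them as a flip of the axes that carry -1. A minimal sketch of that equivalence on a plain Tensor (example values, not from the commit):

from tinygrad import Tensor

t = Tensor([[1, 2, 3], [4, 5, 6]])
arg = (1, -1)  # STRIDE arg: keep axis 0, reverse axis 1
flipped = t.flip(tuple(i for i, x in enumerate(arg) if x == -1))
print(flipped.tolist())  # [[3, 2, 1], [6, 5, 4]]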


@@ -1,5 +1,6 @@
from tinygrad import Tensor, dtypes
from tinygrad.helpers import DEBUG, getenv, prod
import torch.lib
TORCH_DEBUG = getenv("TORCH_DEBUG")
import torch, pathlib, math, operator
torch.autograd.grad_mode.set_multithreading_enabled(False)
@@ -38,33 +39,19 @@ def masked_select(self, mask):
# err, bad
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
#return tensor.cpu().as_strided(size, stride).tiny()
if TORCH_DEBUG >= 1: print("** NOTE: this as_strided might be wrong", tensor.shape, size, stride, storage_offset)
nz_strides = [st for s,st in zip(size, stride) if s != 1]
decending_strides = all(x>=y for x,y in zip(nz_strides[:-1], nz_strides[1:]))
# this is reshape (squeeze/unsqueeze), strides must be in decending order
if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1) and decending_strides:
return tensor.reshape(size)
# this is also expand, hit?
if tensor.numel() == 1:
assert all(x == 0 for x in stride)
return wrap(unwrap(tensor).reshape([1]*len(size)).expand(size))
# this is expand
if len(tensor.shape) == len(size) and all(x == y or x == 1 for x,y in zip(tensor.shape, size)) and decending_strides:
return wrap(unwrap(tensor).expand(size))
# this is permute because we are flipping strides
if len(tensor.shape) == 2 and tuple(tensor.shape)[::-1] == tuple(size) and stride == [0, 1]:
return wrap(unwrap(tensor).permute(1,0))
#print(tensor.cpu().numpy())
raise NotImplementedError(f"fix as_strided {tensor.shape} -> {size} {stride} {storage_offset}")
# TODO: this is heavyweight
st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)])
ret = unwrap(tensor)
if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
mops = to_movement_ops(st)
if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:]
for mo in mops: ret = apply_mop(ret, mo)
return wrap(ret)
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
@@ -85,6 +72,33 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
# TODO: this is wrong
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
@torch.library.impl("aten::arange", "privateuseone")
def arange(end, dtype=None, device=None, pin_memory=None):
return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::arange.start", "privateuseone")
def arange_start(start, end, dtype=None, device=None, pin_memory=None):
return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::arange.start_step", "privateuseone")
def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::topk", "privateuseone")
def topk(self, k, dim=-1, largest=True, sorted=True):
# TODO: move to tinygrad
t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
return torch.return_types.topk((t1.tiny(), t2.tiny()))
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
# TODO: move to tinygrad
return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
if TORCH_DEBUG >= 1:
@@ -103,14 +117,15 @@ def _copy_from(src, dest):
dest.copy_(torch.from_numpy(unwrap(src).numpy()))
elif str(src.device) == "cpu" and str(dest.device) == "tiny":
unwrap(dest).assign(Tensor(src.numpy()))
#if 0 in dest.stride():
# print(dest.shape, dest.stride())
# exit(0)
else:
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
@torch.library.impl("aten::cat.out", "privateuseone")
def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
def cat_out(tensors, dim=0, out=None):
unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
# register some decompositions
from torch._decomp import get_decompositions
@@ -118,6 +133,7 @@ aten = torch.ops.aten
decomps = {
"post_autograd": [
aten.native_batch_norm, aten.native_batch_norm_backward,
aten.native_layer_norm_backward,
aten.addmm,
aten.addcmul,
aten.addcdiv,
@@ -142,6 +158,10 @@ decomps = {
aten.nan_to_num,
aten.logit,
aten.rsub,
aten.index_select,
aten.native_dropout, aten.native_dropout_backward,
aten._softmax_backward_data, aten.embedding_dense_backward,
aten.linalg_vector_norm,
# activations
aten.hardswish, aten.hardswish_backward,
aten.hardtanh, aten.hardtanh_backward,
@@ -194,11 +214,13 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.add.out": lambda input,other,alpha=1: input+alpha*other,
"aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
"aten.mul.out": operator.mul,
"aten.bmm.out": operator.matmul,
"aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
# NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
"aten.remainder.Tensor_out": Tensor.mod,
"aten.pow.Tensor_Tensor_out": Tensor.pow,
"aten.pow.Tensor_Scalar_out": Tensor.pow,
"aten.pow.Scalar_out": lambda x,y: x**y,
"aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
"aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
"aten.bitwise_xor.Tensor_out": lambda x,y: x^y, # TODO: tinygrad lacks bitwise_xor, add it
@@ -229,10 +251,15 @@ def wrap_out(f):
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.view": Tensor.reshape,
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
"aten.floor_divide": lambda x,y: x//y,
"aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y),
# TODO: use tinygrad methods, but they require x to be unsigned
"aten.__lshift__.Scalar": lambda x,y: x*(2**y),
"aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)),
"aten.__rshift__.Scalar": lambda x,y: x//(2**y),
"aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)),
# relu doesn't have an out form?
"aten.relu": Tensor.relu,
"aten.relu_": lambda x: x.assign(x.relu()),
@@ -251,7 +278,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)),
# NOTE: axis=[] in torch means all, change tinygrad?
"aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True),
out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
"aten.scatter.value": Tensor.scatter,
"aten.gather": Tensor.gather,
"aten.where.self": Tensor.where,
@@ -266,11 +293,14 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
# these don't work in out form, they have size 0
"aten.abs": Tensor.abs,
"aten.logical_not": Tensor.logical_not,
"aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
"aten.multinomial": Tensor.multinomial,
}}
def wrap_fxn(k,f):
def nf(*args, **kwargs):
if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
if TORCH_DEBUG:
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
@@ -300,8 +330,9 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
tinygrad_tensors.append(param.data)
for state_dict in optimizer.state.values():
for key, value in state_dict.items():
if torch.is_tensor(value) and str(value.device) == "tiny": tinygrad_tensors.append(value)
Tensor.realize(*[unwrap(x) for x in tinygrad_tensors])
if torch.is_tensor(value): tinygrad_tensors.append(value)
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
_optimizer_init = torch.optim.Optimizer.__init__
def _optimizer_patched_init(self, *args, **kwargs):
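
Among the table entries above, aten.__lshift__/__rshift__ (and their in-place forms) emulate bit shifts arithmetically because, per the TODO, tinygrad's own shift methods expect unsigned operands. A small sketch of the same trick with assumed example values:

from tinygrad import Tensor

x, y = Tensor([1, 2, 3, 4]), 2
print((x * (2 ** y)).tolist())   # x << y -> [4, 8, 12, 16]
print((x // (2 ** y)).tolist())  # x >> y -> [0, 0, 0, 1]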


@@ -50,6 +50,11 @@ class TestTorchBackend(unittest.TestCase):
np.testing.assert_equal(perm.cpu().numpy(), [[1,3],[2,4]])
np.testing.assert_equal(back.cpu().numpy(), [[1,2],[3,4]])
def test_shrink(self):
a = torch.Tensor([1,2,3,4]).to(device)
np.testing.assert_equal(a[:3].cpu().numpy(), [1,2,3])
np.testing.assert_equal(a[1:].cpu().numpy(), [2,3,4])
def test_plus_inplace(self):
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
@@ -66,15 +71,13 @@ class TestTorchBackend(unittest.TestCase):
a = torch.ones(4, device=device)
np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])
@unittest.skip("broken")
def test_eq(self):
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
c = a == b
print(c.cpu().numpy())
# TODO: why
@unittest.skip("broken")
@unittest.skip("meh")
def test_str(self):
a = torch.ones(4, device=device)
print(str(a))
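
The new test_shrink above covers plain slicing; a hypothetical companion test in the same style (not part of the commit) would cover step slicing, which typically lowers to aten::as_strided with a non-unit stride and so exercises the new general path:

def test_step_slice(self):
    a = torch.Tensor([1, 2, 3, 4]).to(device)
    np.testing.assert_equal(a[::2].cpu().numpy(), [1, 3])
    np.testing.assert_equal(a[1::2].cpu().numpy(), [2, 4])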


@@ -7,8 +7,95 @@
// register guard
namespace at {
namespace detail {
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
//C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
// NOTE: pytorch's no-op class throws error on backwards with events/streams
// TODO: why are there events in autograd?
struct CustomNoOpDeviceGuardImpl : public c10::impl::DeviceGuardImplInterface
{
static const DeviceType D = DeviceType::PrivateUse1;
CustomNoOpDeviceGuardImpl() = default;
DeviceType type() const override {
return D;
}
Device exchangeDevice(Device) const override {
return Device(D, 0); // no-op
}
Device getDevice() const override {
return Device(D, 0);
}
void setDevice(Device) const override {
// no-op
}
void uncheckedSetDevice(Device) const noexcept override {
// no-op
}
Stream getStream(Device) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, 0));
}
Stream getDefaultStream(Device) const override {
// no-op
return Stream(Stream::DEFAULT, Device(D, 0));
}
Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
const override {
// no-op
(void)isHighPriority;
return Stream(Stream::DEFAULT, Device(D, 0));
}
Stream getNewStream(Device, int priority = 0) const override {
// no-op
(void)priority;
return Stream(Stream::DEFAULT, Device(D, 0));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, 0));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
}
// Event-related functions
void record(
void** /*event*/,
const Stream& /*stream*/,
const DeviceIndex /*device_index*/,
const EventFlag /*flag*/) const override {
//TORCH_CHECK(false, D, " backend doesn't support events.");
}
void block(void* /*event*/, const Stream& /*stream*/) const override {
//TORCH_CHECK(false, D, " backend doesn't support events.")
}
bool queryEvent(void* /*event*/) const override {
//TORCH_CHECK(false, D, " backend doesn't support events.")
return true;
}
void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
const noexcept override {}
// Stream-related functions
bool queryStream(const Stream& /*stream*/) const override {
return true;
}
void synchronizeStream(const Stream& /*stream*/) const override {
// Don't wait for anything.
}
};
C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomNoOpDeviceGuardImpl);
}
template <typename OpaqueHandle>
struct TinyOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
TinyOpaqueTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::Device device,
OpaqueHandle opaque_handle,
c10::IntArrayRef sizes,
c10::IntArrayRef strides)
: OpaqueTensorImpl<OpaqueHandle>(key_set, data_type, device, opaque_handle, sizes)
{ this->sizes_and_strides_.set_strides(strides); }
};
}
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
@@ -26,17 +113,28 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
// TODO: we have to get the dtype and the shape from the tinygrad Tensor
std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
return at::detail::make_tensor<at::OpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
// Last dimension stride is 1 for contiguous row-major layout
std::vector<int64_t> strides(sizes.size());
if (sizes.size() >= 1) {
strides[sizes.size() - 1] = 1;
// Compute strides from right to left
for (int64_t i = sizes.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * sizes[i + 1];
}
}
return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
at::DispatchKeySet(at::DispatchKey::PrivateUse1),
c10::scalarTypeToTypeMeta(dtype),
at::Device(at::kPrivateUse1),
std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
sizes);
sizes, strides);
}
py::object unwrap_tensor(const at::Tensor &tensor) {
auto* impl = tensor.unsafeGetTensorImpl();
auto* opaque_impl = static_cast<at::OpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
auto* opaque_impl = static_cast<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
std::shared_ptr<c10::SafePyObject> tiny = opaque_impl->opaque_handle();
return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
}
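
For reference, the stride initialization added to wrap_tensor above assumes a contiguous row-major layout. The same right-to-left product, sketched in Python with an assumed example shape:

def contiguous_strides(sizes):
    # mirrors the C++ loop: strides[i] = strides[i+1] * sizes[i+1], last stride is 1
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides

print(contiguous_strides([2, 3, 4]))  # [12, 4, 1]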