diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 816cfa38a4..eb7a1642cd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -155,7 +155,8 @@ jobs: with: key: torch-backend-pillow-torchvision-et-pt deps: testing_minimal - pydeps: "pillow torchvision expecttest pytest" + pydeps: "pillow torchvision expecttest" + llvm: 'true' - name: Install ninja run: | sudo apt update || true @@ -169,11 +170,11 @@ jobs: - name: Test one op in torch tests run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32 - name: Test Ops with TINY_BACKEND (expect failure) - run: PYTHONPATH=. TINY_BACKEND=1 python3 -m pytest test/test_ops.py || true + run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true - name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure) run: PYTHONPATH=. TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true - name: Test some torch tests (expect failure) - run: PYTHONPATH=. pytest extra/torch_backend/torch_tests.py -v --tb=no || true + run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true tc: name: Tensor Core tests diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py index 2528a2ca8f..fd306156f3 100644 --- a/extra/to_movement_ops.py +++ b/extra/to_movement_ops.py @@ -6,15 +6,16 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast from tinygrad.helpers import prod, tqdm from tinygrad.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.ops import sym_infer, Node +from tinygrad.ops import sym_infer +from tinygrad.tensor import Tensor class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702 -def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker: +def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker: mop, arg = mop_arg if mop == MovementOps.RESHAPE: # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE - if arg == (-1,): return st.reshape((prod(st.views[-1].shape),)) + if arg == (-1,): return st.reshape((prod(st.shape),)) return st.reshape(arg) if mop == MovementOps.PERMUTE: return st.permute(arg) if mop == MovementOps.EXPAND: @@ -22,7 +23,9 @@ def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTrac return st.expand(arg) if mop == MovementOps.PAD: return st.pad(arg) if mop == MovementOps.SHRINK: return st.shrink(arg) - if mop == MovementOps.STRIDE: return st.stride(arg) + if mop == MovementOps.STRIDE: + assert all(x in [-1, 1] for x in arg) + return st.flip(tuple(i for i,x in enumerate(arg) if x == -1)) raise ValueError("invalid mop") def make_scratch_st(st: ShapeTracker) -> ShapeTracker: @@ -36,7 +39,7 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) real_real_shape = [s for s,st in zip(real_shape, v.strides) if st] - strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] + strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1 if i: buffer_size = prod(st.views[i-1].shape) - real_offset def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True) @@ -80,9 +83,12 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: if scratch_st in seen: ret = seen[scratch_st][:] else: - ret.append(mop_arg) + if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE: + ret[-1] = mop_arg + else: + if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),)) + ret.append(mop_arg) seen[scratch_st] = ret[:] - return ret def get_real_view(shape, strides, offset, mask): diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 70b71a6c8d..dc49b24837 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -1,5 +1,6 @@ from tinygrad import Tensor, dtypes from tinygrad.helpers import DEBUG, getenv, prod +import torch.lib TORCH_DEBUG = getenv("TORCH_DEBUG") import torch, pathlib, math, operator torch.autograd.grad_mode.set_multithreading_enabled(False) @@ -38,33 +39,19 @@ def masked_select(self, mask): # err, bad return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()])) +from tinygrad.shape.shapetracker import ShapeTracker, View +from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps @torch.library.impl("aten::as_strided", "privateuseone") def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None): - #return tensor.cpu().as_strided(size, stride).tiny() - if TORCH_DEBUG >= 1: print("** NOTE: this as_strided might be wrong", tensor.shape, size, stride, storage_offset) - - nz_strides = [st for s,st in zip(size, stride) if s != 1] - decending_strides = all(x>=y for x,y in zip(nz_strides[:-1], nz_strides[1:])) - - # this is reshape (squeeze/unsqueeze), strides must be in decending order - if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1) and decending_strides: - return tensor.reshape(size) - - # this is also expand, hit? - if tensor.numel() == 1: - assert all(x == 0 for x in stride) - return wrap(unwrap(tensor).reshape([1]*len(size)).expand(size)) - - # this is expand - if len(tensor.shape) == len(size) and all(x == y or x == 1 for x,y in zip(tensor.shape, size)) and decending_strides: - return wrap(unwrap(tensor).expand(size)) - - # this is permute because we are flipping strides - if len(tensor.shape) == 2 and tuple(tensor.shape)[::-1] == tuple(size) and stride == [0, 1]: - return wrap(unwrap(tensor).permute(1,0)) - - #print(tensor.cpu().numpy()) - raise NotImplementedError(f"fix as_strided {tensor.shape} -> {size} {stride} {storage_offset}") + # TODO: this is heavyweight + st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)]) + ret = unwrap(tensor) + if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size)) + if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st) + mops = to_movement_ops(st) + if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:] + for mo in mops: ret = apply_mop(ret, mo) + return wrap(ret) @torch.library.impl("aten::empty_strided", "privateuseone") def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False): @@ -85,6 +72,33 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di # TODO: this is wrong return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64))) +@torch.library.impl("aten::arange", "privateuseone") +def arange(end, dtype=None, device=None, pin_memory=None): + return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))) + +@torch.library.impl("aten::arange.start", "privateuseone") +def arange_start(start, end, dtype=None, device=None, pin_memory=None): + return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))) + +@torch.library.impl("aten::arange.start_step", "privateuseone") +def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None): + return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()))) + +@torch.library.impl("aten::topk", "privateuseone") +def topk(self, k, dim=-1, largest=True, sorted=True): + # TODO: move to tinygrad + t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted) + return torch.return_types.topk((t1.tiny(), t2.tiny())) + +@torch.library.impl("aten::_index_put_impl_", "privateuseone") +def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False): + # TODO: move to tinygrad + return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny() + +@torch.library.impl("aten::index.Tensor", "privateuseone") +def index_tensor(x, y): + return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny() + @torch.library.impl("aten::convolution_overrideable", "privateuseone") def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups): if TORCH_DEBUG >= 1: @@ -103,14 +117,15 @@ def _copy_from(src, dest): dest.copy_(torch.from_numpy(unwrap(src).numpy())) elif str(src.device) == "cpu" and str(dest.device) == "tiny": unwrap(dest).assign(Tensor(src.numpy())) + #if 0 in dest.stride(): + # print(dest.shape, dest.stride()) + # exit(0) else: raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}") @torch.library.impl("aten::cat.out", "privateuseone") -def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True) - -@torch.library.impl("aten::index.Tensor", "privateuseone") -def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()]) +def cat_out(tensors, dim=0, out=None): + unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True) # register some decompositions from torch._decomp import get_decompositions @@ -118,6 +133,7 @@ aten = torch.ops.aten decomps = { "post_autograd": [ aten.native_batch_norm, aten.native_batch_norm_backward, + aten.native_layer_norm_backward, aten.addmm, aten.addcmul, aten.addcdiv, @@ -142,6 +158,10 @@ decomps = { aten.nan_to_num, aten.logit, aten.rsub, + aten.index_select, + aten.native_dropout, aten.native_dropout_backward, + aten._softmax_backward_data, aten.embedding_dense_backward, + aten.linalg_vector_norm, # activations aten.hardswish, aten.hardswish_backward, aten.hardtanh, aten.hardtanh_backward, @@ -194,11 +214,13 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_ "aten.add.out": lambda input,other,alpha=1: input+alpha*other, "aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse "aten.mul.out": operator.mul, + "aten.bmm.out": operator.matmul, "aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad # NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods "aten.remainder.Tensor_out": Tensor.mod, "aten.pow.Tensor_Tensor_out": Tensor.pow, "aten.pow.Tensor_Scalar_out": Tensor.pow, + "aten.pow.Scalar_out": lambda x,y: x**y, "aten.bitwise_and.Tensor_out": Tensor.bitwise_and, "aten.bitwise_or.Tensor_out": Tensor.bitwise_or, "aten.bitwise_xor.Tensor_out": lambda x,y: x^y, # TODO: tinygrad lacks bitwise_xor, add it @@ -229,10 +251,15 @@ def wrap_out(f): tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.view": Tensor.reshape, + "aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care? + "aten.remainder.Scalar_Tensor": lambda x,y: x%y, "aten.floor_divide": lambda x,y: x//y, + "aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y), # TODO: use tinygrad methods, but they require x to be unsigned "aten.__lshift__.Scalar": lambda x,y: x*(2**y), + "aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)), "aten.__rshift__.Scalar": lambda x,y: x//(2**y), + "aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)), # relu doesn't have an out form? "aten.relu": Tensor.relu, "aten.relu_": lambda x: x.assign(x.relu()), @@ -251,7 +278,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)), # NOTE: axis=[] in torch means all, change tinygrad? "aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None: - out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True), + out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True), "aten.scatter.value": Tensor.scatter, "aten.gather": Tensor.gather, "aten.where.self": Tensor.where, @@ -266,11 +293,14 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ # these don't work in out form, they have size 0 "aten.abs": Tensor.abs, "aten.logical_not": Tensor.logical_not, + "aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)), + "aten.multinomial": Tensor.multinomial, }} def wrap_fxn(k,f): def nf(*args, **kwargs): - if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args], + if TORCH_DEBUG: + print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args], {k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}) args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args] kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()} @@ -300,8 +330,9 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs): tinygrad_tensors.append(param.data) for state_dict in optimizer.state.values(): for key, value in state_dict.items(): - if torch.is_tensor(value) and str(value.device) == "tiny": tinygrad_tensors.append(value) - Tensor.realize(*[unwrap(x) for x in tinygrad_tensors]) + if torch.is_tensor(value): tinygrad_tensors.append(value) + real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"] + if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors) _optimizer_init = torch.optim.Optimizer.__init__ def _optimizer_patched_init(self, *args, **kwargs): diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index 60ade1a2c6..88fa473741 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -50,6 +50,11 @@ class TestTorchBackend(unittest.TestCase): np.testing.assert_equal(perm.cpu().numpy(), [[1,3],[2,4]]) np.testing.assert_equal(back.cpu().numpy(), [[1,2],[3,4]]) + def test_shrink(self): + a = torch.Tensor([1,2,3,4]).to(device) + np.testing.assert_equal(a[:3].cpu().numpy(), [1,2,3]) + np.testing.assert_equal(a[1:].cpu().numpy(), [2,3,4]) + def test_plus_inplace(self): a = torch.ones(4, device=device) b = torch.ones(4, device=device) @@ -66,15 +71,13 @@ class TestTorchBackend(unittest.TestCase): a = torch.ones(4, device=device) np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True]) - @unittest.skip("broken") def test_eq(self): a = torch.ones(4, device=device) b = torch.ones(4, device=device) c = a == b print(c.cpu().numpy()) - # TODO: why - @unittest.skip("broken") + @unittest.skip("meh") def test_str(self): a = torch.ones(4, device=device) print(str(a)) diff --git a/extra/torch_backend/wrapped_tensor.cpp b/extra/torch_backend/wrapped_tensor.cpp index 3acde01a30..658dc41597 100644 --- a/extra/torch_backend/wrapped_tensor.cpp +++ b/extra/torch_backend/wrapped_tensor.cpp @@ -7,8 +7,95 @@ // register guard namespace at { namespace detail { -C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); +//C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); +// NOTE: pytorch's no-op class throws error on backwards with events/streams +// TODO: why are there events in autograd? +struct CustomNoOpDeviceGuardImpl : public c10::impl::DeviceGuardImplInterface +{ + static const DeviceType D = DeviceType::PrivateUse1; + CustomNoOpDeviceGuardImpl() = default; + DeviceType type() const override { + return D; + } + Device exchangeDevice(Device) const override { + return Device(D, 0); // no-op + } + Device getDevice() const override { + return Device(D, 0); + } + void setDevice(Device) const override { + // no-op + } + void uncheckedSetDevice(Device) const noexcept override { + // no-op + } + Stream getStream(Device) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, 0)); + } + Stream getDefaultStream(Device) const override { + // no-op + return Stream(Stream::DEFAULT, Device(D, 0)); + } + Stream getStreamFromGlobalPool(Device, bool isHighPriority = false) + const override { + // no-op + (void)isHighPriority; + return Stream(Stream::DEFAULT, Device(D, 0)); + } + Stream getNewStream(Device, int priority = 0) const override { + // no-op + (void)priority; + return Stream(Stream::DEFAULT, Device(D, 0)); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, 0)); + } + DeviceIndex deviceCount() const noexcept override { + return 1; + } + // Event-related functions + void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const EventFlag /*flag*/) const override { + //TORCH_CHECK(false, D, " backend doesn't support events."); + } + void block(void* /*event*/, const Stream& /*stream*/) const override { + //TORCH_CHECK(false, D, " backend doesn't support events.") + } + bool queryEvent(void* /*event*/) const override { + //TORCH_CHECK(false, D, " backend doesn't support events.") + return true; + } + void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept override {} + // Stream-related functions + bool queryStream(const Stream& /*stream*/) const override { + return true; + } + void synchronizeStream(const Stream& /*stream*/) const override { + // Don't wait for anything. + } +}; +C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomNoOpDeviceGuardImpl); } + +template +struct TinyOpaqueTensorImpl : public OpaqueTensorImpl { + TinyOpaqueTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + c10::Device device, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + c10::IntArrayRef strides) + : OpaqueTensorImpl(key_set, data_type, device, opaque_handle, sizes) + { this->sizes_and_strides_.set_strides(strides); } +}; } struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { @@ -26,17 +113,28 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) { // TODO: we have to get the dtype and the shape from the tinygrad Tensor std::vector sizes = py_obj.attr("shape").cast>(); - return at::detail::make_tensor>>( + // Last dimension stride is 1 for contiguous row-major layout + std::vector strides(sizes.size()); + if (sizes.size() >= 1) { + strides[sizes.size() - 1] = 1; + + // Compute strides from right to left + for (int64_t i = sizes.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + } + + return at::detail::make_tensor>>( at::DispatchKeySet(at::DispatchKey::PrivateUse1), c10::scalarTypeToTypeMeta(dtype), at::Device(at::kPrivateUse1), std::make_shared(py_obj.release().ptr(), getPyInterpreter()), - sizes); + sizes, strides); } py::object unwrap_tensor(const at::Tensor &tensor) { auto* impl = tensor.unsafeGetTensorImpl(); - auto* opaque_impl = static_cast>*>(impl); + auto* opaque_impl = static_cast>*>(impl); std::shared_ptr tiny = opaque_impl->opaque_handle(); return py::reinterpret_borrow(tiny->ptr(getPyInterpreter())); }