torch examples (#9290)

* torch, fix examples/mnist

* fix vae torch example

* where out
This commit is contained in:
George Hotz
2025-02-28 10:16:06 +08:00
committed by GitHub
parent c977781b3c
commit b32595dbbc

View File

@@ -72,6 +72,8 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
# TODO: supprt stride [] in tinygrad?
if stride is not None and len(stride) == 0: stride = None
# TODO: support return_indices in tinygrad
ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
# TODO: this is wrong
@@ -79,6 +81,7 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
if stride is not None and len(stride) == 0: stride = None
# TODO: utilize input indices once they are correct
grad_out, self = unwrap(grad_out), unwrap(self)
out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode)
@@ -181,6 +184,7 @@ decomps = [
aten.native_dropout, aten.native_dropout_backward,
aten._softmax_backward_data, aten.embedding_dense_backward,
aten.linalg_vector_norm,
aten.binary_cross_entropy, aten.binary_cross_entropy_backward,
# activations
aten.hardswish, aten.hardswish_backward,
aten.hardtanh, aten.hardtanh_backward,
@@ -253,6 +257,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
#"aten.arange.start_out": Tensor.arange,
"aten.lerp.Scalar_out": Tensor.lerp,
"aten.scatter.value_out": Tensor.scatter,
"aten.where.self_out": Tensor.where,
}}
# we add the "out" here
@@ -297,7 +302,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
"aten.scatter.value": Tensor.scatter,
"aten.gather": Tensor.gather,
"aten.where.self": Tensor.where,
"aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
"aten.random_": lambda self: