mirror of https://github.com/tinygrad/tinygrad.git
torch examples (#9290)
* torch, fix examples/mnist
* fix vae torch example
* where out
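The commit edits extra/torch_backend/backend.py, the shim that maps ATen ops onto tinygrad Tensors through the privateuseone dispatch key. A rough sketch of what the fixed examples exercise; the "tiny" device name and the import are assumptions from the repo layout, not something this diff states:

import torch
import extra.torch_backend.backend  # assumed: registers tinygrad as the "tiny" privateuseone device

x = torch.randn(1, 1, 28, 28, device="tiny")
y = torch.nn.functional.max_pool2d(x, kernel_size=2)  # dispatches into the pooling hunks below
print(y.shape)  # torch.Size([1, 1, 14, 14])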
@@ -72,6 +72,8 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
 
 @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
 def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+  # TODO: support stride [] in tinygrad?
+  if stride is not None and len(stride) == 0: stride = None
   # TODO: support return_indices in tinygrad
   ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
   # TODO: this is wrong
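Both pooling hunks add the same guard: ATen passes stride=[] when the caller omits it, while tinygrad's max_pool2d expects None there (meaning the stride defaults to kernel_size). A standalone sketch of just that normalization:

def normalize_stride(stride):
  # ATen encodes "not given" as []; tinygrad encodes it as None
  return None if stride is not None and len(stride) == 0 else stride

assert normalize_stride([]) is None
assert normalize_stride(None) is None
assert normalize_stride([2, 2]) == [2, 2]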
@@ -79,6 +81,7 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
 
 @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
 def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
+  if stride is not None and len(stride) == 0: stride = None
   # TODO: utilize input indices once they are correct
   grad_out, self = unwrap(grad_out), unwrap(self)
   out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode)
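Because the forward's returned indices are known-wrong (the "this is wrong" TODO), the backward overload ignores indices and recomputes the pooling output itself. A quick way to exercise this path, again assuming the hypothetical "tiny" device registration from the sketch above:

x = torch.randn(1, 1, 8, 8, device="tiny", requires_grad=True)
y = torch.nn.functional.max_pool2d(x, kernel_size=2)
y.sum().backward()   # dispatches to aten::max_pool2d_with_indices_backward
print(x.grad.shape)  # torch.Size([1, 1, 8, 8])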
@@ -181,6 +184,7 @@ decomps = [
   aten.native_dropout, aten.native_dropout_backward,
   aten._softmax_backward_data, aten.embedding_dense_backward,
   aten.linalg_vector_norm,
+  aten.binary_cross_entropy, aten.binary_cross_entropy_backward,
   # activations
   aten.hardswish, aten.hardswish_backward,
   aten.hardtanh, aten.hardtanh_backward,
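Listing aten.binary_cross_entropy and its backward in decomps lets PyTorch lower BCE into primitive ops the backend already implements, which is what the "fix vae torch example" item in the commit message needs. A minimal sketch of the mechanism, using torch's decomposition helper (variable names are illustrative):

import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
# maps each listed overload to a Python decomposition built from simpler aten ops
table = get_decompositions([aten.binary_cross_entropy, aten.binary_cross_entropy_backward])
print(sorted(str(k) for k in table))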
@@ -253,6 +257,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   #"aten.arange.start_out": Tensor.arange,
   "aten.lerp.Scalar_out": Tensor.lerp,
   "aten.scatter.value_out": Tensor.scatter,
+  "aten.where.self_out": Tensor.where,
 }}
 
 # we add the "out" here
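torch.where with out= dispatches to the separate overload aten.where.self_out, so the out table needs its own entry; this is the "where out" item in the commit message. An illustrative call, with the "tiny" device name again assumed:

cond = torch.tensor([True, False, True], device="tiny")
a, b = torch.ones(3, device="tiny"), torch.zeros(3, device="tiny")
out = torch.empty(3, device="tiny")
torch.where(cond, a, b, out=out)  # routes to aten.where.self_out, wrapped by wrap_out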
@@ -297,7 +302,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
     out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
   "aten.scatter.value": Tensor.scatter,
   "aten.gather": Tensor.gather,
-  "aten.where.self": Tensor.where,
+  "aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type
   "aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
   "aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
   "aten.random_": lambda self:
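The plain aten.where.self entry has to stay registered alongside the out variant because PyTorch selects the overload per call site: torch.where(c, a, b) hits the functional form, torch.where(..., out=t) hits the out form. The _softmax/_log_softmax lambdas simply drop the half_to_float flag and defer to tinygrad's own methods; a hedged check, once more assuming the "tiny" device:

x = torch.randn(4, 10, device="tiny")
p = torch.softmax(x, dim=-1)  # "aten._softmax" -> Tensor.softmax; half_to_float is ignored
assert torch.allclose(p.sum(-1).cpu(), torch.ones(4), atol=1e-5)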