num_batches_tracked long only if supported [pr] (#7582)

This commit is contained in:
Ahmed Harmouche
2024-11-08 12:28:21 +01:00
committed by GitHub
parent a1dfd288bb
commit d4e91b0de7

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import math
from typing import Optional, Union, Tuple, List
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import prod, make_tuple, flatten
from tinygrad.nn import optim, state, datasets # noqa: F401
@@ -36,7 +37,7 @@ class BatchNorm:
self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None
self.num_batches_tracked = Tensor.zeros(1, dtype='long', requires_grad=False)
self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]: